diff options
Diffstat (limited to 'diag/selftest_wl.py')
| -rw-r--r-- | diag/selftest_wl.py | 53 |
1 files changed, 53 insertions, 0 deletions
diff --git a/diag/selftest_wl.py b/diag/selftest_wl.py new file mode 100644 index 0000000..6c08310 --- /dev/null +++ b/diag/selftest_wl.py @@ -0,0 +1,53 @@ +"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition. +Run: PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py +""" +import numpy as np +from diag import wl, datasets + + +def run(): + P = datasets.canonical_pairs() + names = list(P.keys()) + adjs = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in names] + node_rounds, ghist_rounds, conv = wl.wl_refine(adjs) + color = {names[i]: ghist_rounds[conv][i] for i in range(len(names))} + print("converged round:", conv) + for k in names: + print(f" {k:5s} tri={P[k]['tri']} wlcolor={color[k]}") + + # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2 + assert color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal" + assert P['C6']['tri'] != P['2C3']['tri'] + # (2) P4 != K1,3 (different degree multiset) + assert color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct" + print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3") + + # (3) regression H2 floor on {C6,2C3} == target variance + sub = [wl.edges_to_adj(P['C6']['n'], P['C6']['edge_index']), + wl.edges_to_adj(P['2C3']['n'], P['2C3']['edge_index'])] + y = [P['C6']['tri'], P['2C3']['tri']] + _, gh, cv = wl.wl_refine(sub) + dec = wl.decompose_regression(gh, cv, L=10, y=y, train_idx=[0, 1], eval_idx=[0, 1]) + print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {dec['mse_floor_oracle_H2']:.4f} " + f"(target var = {np.var(y):.4f})") + assert abs(dec['mse_floor_oracle_H2'] - np.var(y)) < 1e-9 + + # (4) CSL: 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2 + csl = datasets.build_csl(n_per_class=15, seed=0) + adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in csl] + nr, gh, cv = wl.wl_refine(adjs) + n_node_colors = len(set(nr[cv][0].tolist())) + n_graph_colors = len(set(gh[cv])) + y = [d['y'] for d in csl] + idx = list(range(len(csl))) + att = wl.attribute_classification(gh, cv, L=4, y=y, train_idx=idx, eval_idx=idx) + print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, " + f"WL-optimal acc={att['wl_optimal_acc_converged']:.3f} (chance 0.1), buckets={att['counts']}") + assert n_node_colors == 1 and n_graph_colors == 1 + assert abs(att['wl_optimal_acc_converged'] - 0.1) < 1e-6 + assert att['counts'].get('H2', 0) == len(csl) + print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.") + + +if __name__ == "__main__": + run() |
