"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition.

Run:
    PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py

Every check raises ``AssertionError`` on failure. The checks use an explicit
``raise`` (via ``_require``) instead of ``assert`` statements so they still run
under ``python -O``, which strips ``assert``; otherwise an optimized run would
print "Instrument VALIDATED" without having checked anything.
"""
import numpy as np

from diag import wl, datasets


def _require(cond, msg):
    """Raise ``AssertionError(msg)`` unless *cond* is truthy; unlike ``assert``, survives ``-O``."""
    if not cond:
        raise AssertionError(msg)


def _check_canonical(P):
    """Check 1-WL (in)equivalences on the canonical pairs dict *P*; print per-pair colours."""
    names = list(P)
    adjs = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in names]
    _node_rounds, ghist_rounds, conv = wl.wl_refine(adjs)
    # Graph-level WL colour at the converged round, keyed by pair name
    # (ghist_rounds[conv] is indexed in the same order as `adjs`).
    color = dict(zip(names, ghist_rounds[conv]))

    print("converged round:", conv)
    for k in names:
        print(f" {k:5s} tri={P[k]['tri']} wlcolor={color[k]}")

    # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2
    _require(color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal")
    _require(P['C6']['tri'] != P['2C3']['tri'],
             "C6 vs 2C3 should differ in triangle count")
    # (2) P4 != K1,3 (different degree multiset)
    _require(color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct")
    print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3")


def _check_h2_floor(P):
    """Check (3): the regression H2 floor on {C6, 2C3} equals the target variance."""
    sub = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in ('C6', '2C3')]
    y = [P['C6']['tri'], P['2C3']['tri']]
    _, gh, cv = wl.wl_refine(sub)
    dec = wl.decompose_regression(gh, cv, L=10, y=y, train_idx=[0, 1], eval_idx=[0, 1])
    floor = dec['mse_floor_oracle_H2']
    target_var = np.var(y)
    print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {floor:.4f} "
          f"(target var = {target_var:.4f})")
    # Both graphs share one WL colour, so the best colour-wise predictor is the
    # mean target and its MSE must equal the (population) variance of y.
    _require(abs(floor - target_var) < 1e-9,
             f"H2 floor {floor!r} should equal target variance {target_var!r}")


def _check_csl():
    """Check (4): CSL is fully 1-WL-collapsed, so every failure is attributed to H2."""
    # 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2
    csl = datasets.build_csl(n_per_class=15, seed=0)
    adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in csl]
    nr, gh, cv = wl.wl_refine(adjs)
    # Node colours are counted on the first graph only; the single-graph-colour
    # check below guards that the remaining graphs are WL-identical to it.
    n_node_colors = len(set(nr[cv][0].tolist()))
    n_graph_colors = len(set(gh[cv]))
    y = [d['y'] for d in csl]
    idx = list(range(len(csl)))
    att = wl.attribute_classification(gh, cv, L=4, y=y, train_idx=idx, eval_idx=idx)
    acc = att['wl_optimal_acc_converged']
    print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, "
          f"WL-optimal acc={acc:.3f} (chance 0.1), buckets={att['counts']}")
    _require(n_node_colors == 1, f"CSL should have 1 node colour, got {n_node_colors}")
    _require(n_graph_colors == 1, f"CSL should have 1 graph colour, got {n_graph_colors}")
    _require(abs(acc - 0.1) < 1e-6, f"CSL WL-optimal acc should be chance 0.1, got {acc}")
    n_h2 = att['counts'].get('H2', 0)
    _require(n_h2 == len(csl), f"all {len(csl)} CSL graphs should be H2, got {n_h2}")
    print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.")


def run():
    """Run every self-check in order; raise ``AssertionError`` on the first failure."""
    P = datasets.canonical_pairs()
    _check_canonical(P)
    _check_h2_floor(P)
    _check_csl()


if __name__ == "__main__":
    run()