summaryrefslogtreecommitdiff
path: root/diag/selftest_wl.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/selftest_wl.py')
-rw-r--r--diag/selftest_wl.py53
1 files changed, 53 insertions, 0 deletions
diff --git a/diag/selftest_wl.py b/diag/selftest_wl.py
new file mode 100644
index 0000000..6c08310
--- /dev/null
+++ b/diag/selftest_wl.py
@@ -0,0 +1,53 @@
+"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition.
+Run: PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py
+"""
+import numpy as np
+from diag import wl, datasets
+
+
+def run():
+ P = datasets.canonical_pairs()
+ names = list(P.keys())
+ adjs = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in names]
+ node_rounds, ghist_rounds, conv = wl.wl_refine(adjs)
+ color = {names[i]: ghist_rounds[conv][i] for i in range(len(names))}
+ print("converged round:", conv)
+ for k in names:
+ print(f" {k:5s} tri={P[k]['tri']} wlcolor={color[k]}")
+
+ # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2
+ assert color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal"
+ assert P['C6']['tri'] != P['2C3']['tri']
+ # (2) P4 != K1,3 (different degree multiset)
+ assert color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct"
+ print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3")
+
+ # (3) regression H2 floor on {C6,2C3} == target variance
+ sub = [wl.edges_to_adj(P['C6']['n'], P['C6']['edge_index']),
+ wl.edges_to_adj(P['2C3']['n'], P['2C3']['edge_index'])]
+ y = [P['C6']['tri'], P['2C3']['tri']]
+ _, gh, cv = wl.wl_refine(sub)
+ dec = wl.decompose_regression(gh, cv, L=10, y=y, train_idx=[0, 1], eval_idx=[0, 1])
+ print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {dec['mse_floor_oracle_H2']:.4f} "
+ f"(target var = {np.var(y):.4f})")
+ assert abs(dec['mse_floor_oracle_H2'] - np.var(y)) < 1e-9
+
+ # (4) CSL: 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2
+ csl = datasets.build_csl(n_per_class=15, seed=0)
+ adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in csl]
+ nr, gh, cv = wl.wl_refine(adjs)
+ n_node_colors = len(set(nr[cv][0].tolist()))
+ n_graph_colors = len(set(gh[cv]))
+ y = [d['y'] for d in csl]
+ idx = list(range(len(csl)))
+ att = wl.attribute_classification(gh, cv, L=4, y=y, train_idx=idx, eval_idx=idx)
+ print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, "
+ f"WL-optimal acc={att['wl_optimal_acc_converged']:.3f} (chance 0.1), buckets={att['counts']}")
+ assert n_node_colors == 1 and n_graph_colors == 1
+ assert abs(att['wl_optimal_acc_converged'] - 0.1) < 1e-6
+ assert att['counts'].get('H2', 0) == len(csl)
+ print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.")
+
+
+if __name__ == "__main__":
+ run()