1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
|
"""Validate the 1-WL instrument on canonical graphs BEFORE trusting any decomposition.

Run with the repository root (the directory containing ``diag/``) on PYTHONPATH, e.g.:
    PYTHONPATH=/home/yurenh2/rrog python3 /home/yurenh2/rrog/diag/selftest_wl.py
"""
import numpy as np
from diag import wl, datasets
def _check(cond, msg=""):
    """Raise AssertionError with *msg* unless *cond* is truthy.

    A bare ``assert`` is removed when Python runs with ``-O``; this helper is
    not, so the self-test can never report success with its checks disabled.
    """
    if not cond:
        raise AssertionError(msg)


def _check_canonical(P):
    """Checks (1)-(2): 1-WL merges C6/2C3 and separates P4/K1,3."""
    names = list(P.keys())
    adjs = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in names]
    _node_rounds, ghist_rounds, conv = wl.wl_refine(adjs)
    # Graph-level colour of each canonical graph at the converged round.
    color = dict(zip(names, ghist_rounds[conv]))
    print("converged round:", conv)
    for k in names:
        print(f" {k:5s} tri={P[k]['tri']} wlcolor={color[k]}")
    # (1) C6 == 2C3 under 1-WL (both 2-regular) yet differ in triangles -> counting is H2
    _check(color['C6'] == color['2C3'], "C6 vs 2C3 should be 1-WL-equal")
    _check(P['C6']['tri'] != P['2C3']['tri'],
           "C6 vs 2C3 should differ in triangle count")
    # (2) P4 != K1,3 (different degree multiset)
    _check(color['P4'] != color['K1,3'], "P4 vs K1,3 should be 1-WL-distinct")
    print("OK canonical: C6==2C3 (WL blind to triangles), P4!=K1,3")


def _check_triangle_h2_floor(P):
    """Check (3): regression H2 floor on {C6, 2C3} equals the target variance."""
    pair = ('C6', '2C3')
    sub = [wl.edges_to_adj(P[k]['n'], P[k]['edge_index']) for k in pair]
    y = [P[k]['tri'] for k in pair]
    _, gh, cv = wl.wl_refine(sub)
    dec = wl.decompose_regression(gh, cv, L=10, y=y, train_idx=[0, 1], eval_idx=[0, 1])
    floor = dec['mse_floor_oracle_H2']
    # np.var defaults to ddof=0 (population variance). The check below expects the
    # oracle floor on two 1-WL-equal graphs to match it exactly.
    target_var = np.var(y)
    print(f" triangle-count H2 floor MSE on {{C6,2C3}} = {floor:.4f} "
          f"(target var = {target_var:.4f})")
    _check(abs(floor - target_var) < 1e-9,
           f"H2 floor {floor!r} should equal target variance {target_var!r}")


def _check_csl():
    """Check (4): CSL is fully 1-WL-collapsed, so every failure is bucketed H2."""
    # CSL: 4-regular -> 1 node color, 1 graph color, WL-optimal acc = chance (0.1) -> 100% H2
    csl = datasets.build_csl(n_per_class=15, seed=0)
    adjs = [wl.edges_to_adj(d['n'], d['edge_index']) for d in csl]
    nr, gh, cv = wl.wl_refine(adjs)
    # Node colours are inspected on the first graph only.
    n_node_colors = len(set(nr[cv][0].tolist()))
    n_graph_colors = len(set(gh[cv]))
    y = [d['y'] for d in csl]
    idx = list(range(len(csl)))
    att = wl.attribute_classification(gh, cv, L=4, y=y, train_idx=idx, eval_idx=idx)
    acc = att['wl_optimal_acc_converged']
    print(f" CSL: node-colors={n_node_colors}, distinct graph-colors={n_graph_colors}, "
          f"WL-optimal acc={acc:.3f} (chance 0.1), buckets={att['counts']}")
    _check(n_node_colors == 1 and n_graph_colors == 1,
           f"CSL should collapse to 1 node/graph color, got "
           f"{n_node_colors}/{n_graph_colors}")
    _check(abs(acc - 0.1) < 1e-6, f"CSL WL-optimal acc should be chance (0.1), got {acc!r}")
    _check(att['counts'].get('H2', 0) == len(csl),
           f"all {len(csl)} CSL graphs should be H2, got buckets={att['counts']}")
    print("OK CSL: fully 1-WL-collapsed -> 100% of failures are H2. Instrument VALIDATED.")


def run():
    """Run every 1-WL instrument self-check in order.

    Prints a report for each check. Raises AssertionError at the first check that
    fails, including when Python runs with ``-O``.
    """
    P = datasets.canonical_pairs()
    _check_canonical(P)
    _check_triangle_h2_floor(P)
    _check_csl()
# Script entry point: run the self-test only when executed directly, not on import.
if __name__ == "__main__":
    run()
|