summaryrefslogtreecommitdiff
path: root/diag/wl.py
diff options
context:
space:
mode:
Diffstat (limited to 'diag/wl.py')
-rw-r--r--diag/wl.py166
1 files changed, 166 insertions, 0 deletions
diff --git a/diag/wl.py b/diag/wl.py
new file mode 100644
index 0000000..26ab0a3
--- /dev/null
+++ b/diag/wl.py
@@ -0,0 +1,166 @@
+"""1-WL color-refinement instrument for diagnosing GNN failures (H1 vs H2).
+
+A GIN with L layers == L rounds of 1-WL refinement (injective sum aggregation).
+A failure on sample i is attributed by label purity of its WL color classes:
+
+ converged-WL class IMPURE (train labels conflict under same color)
+ -> H2 : 1-WL ceiling. No MPNN at ANY depth separates -> needs >1-WL (noise).
+ converged pure, but L-round class impure
+ -> H1a_depth : separable only with MORE rounds -> deterministic RR-on-graph / depth helps.
+ L-round class pure (info present at depth L) but model wrong
+ -> H1b_opt : optimization / capacity. Train better.
+
+Refinement is dataset-global (shared per-round signature->label map) so node colors and
+graph-color histograms are comparable across graphs.
+"""
+from collections import Counter, defaultdict
+import numpy as np
+
+
def edges_to_adj(n, edge_index):
    """Build out-neighbour adjacency lists for an ``n``-node graph.

    ``edge_index`` is anything ``np.asarray`` turns into an array whose row 0
    holds source nodes and row 1 holds target nodes. For every column
    ``(a, b)`` the target ``b`` is appended to ``adj[a]``; the reverse edge is
    NOT added, so an undirected graph must list both directions itself.
    Duplicate edges are kept, in input order.

    Returns a list of ``n`` lists of plain Python ints.
    """
    endpoints = np.asarray(edge_index)
    sources = endpoints[0].tolist()
    targets = endpoints[1].tolist()

    neighbours = [[] for _ in range(n)]
    for src, dst in zip(sources, targets):
        neighbours[src].append(dst)
    return neighbours
+
+
def wl_refine(adjs, inits=None, max_rounds=None):
    """Run dataset-level 1-WL colour refinement over a collection of graphs.

    All graphs share ONE signature->label table per round, so a colour id
    denotes the same signature in every graph of that round. Ids are handed
    out in first-encounter order and restart at 0 every round, so ids are only
    comparable within a round, never across rounds.

    Args:
        adjs: one adjacency list per graph; ``adjs[g][v]`` lists the
            neighbours of node ``v`` (the format ``edges_to_adj`` returns).
        inits: optional initial integer node colours, one sequence per graph.
            Defaults to a single uniform colour.
        max_rounds: cap on the number of refinement rounds. Defaults to
            (total number of nodes + 1): every non-stable round adds at least
            one colour class and there can be at most one class per node, so
            this always leaves room to observe the first non-refining round.
            A per-graph bound (largest graph size) is NOT sufficient for the
            joint partition across graphs.

    Returns:
        ``(node_rounds, ghist_rounds, conv_round)`` where
        ``node_rounds[r][g]`` is the int64 colour array of graph ``g`` after
        ``r`` rounds, ``ghist_rounds[r][g]`` is that graph's hashable colour
        histogram, and ``conv_round`` is the index of the last round computed
        (the round at which the global class count stopped growing, or the
        cap if ``max_rounds`` was reached first).

    Raises:
        ValueError: if ``inits`` does not provide exactly one colour per node
            of every graph.
    """
    adjs = list(adjs)
    if inits is None:
        inits = [np.zeros(len(adj), dtype=np.int64) for adj in adjs]
    else:
        inits = [np.asarray(x, dtype=np.int64) for x in inits]
        if len(inits) != len(adjs):
            raise ValueError(
                f"inits has {len(inits)} entries but adjs has {len(adjs)} graphs"
            )
        for g, (init, adj) in enumerate(zip(inits, adjs)):
            if len(init) != len(adj):
                raise ValueError(
                    f"graph {g}: {len(init)} initial colours for {len(adj)} nodes"
                )
    if max_rounds is None:
        max_rounds = sum(len(adj) for adj in adjs) + 1

    def relabel(table, signatures):
        # First-encounter numbering: an unseen signature gets the next free id.
        return np.array(
            [table.setdefault(sig, len(table)) for sig in signatures],
            dtype=np.int64,
        )

    table = {}
    node_rounds = [
        [relabel(table, (('i', c) for c in init.tolist())) for init in inits]
    ]
    n_classes = len(table)

    for _ in range(max_rounds):
        table = {}  # fresh table per round, shared by every graph
        refined = []
        for adj, prev in zip(adjs, node_rounds[-1]):
            prev = prev.tolist()
            refined.append(relabel(table, (
                # own colour + sorted multiset of neighbour colours
                (prev[v], tuple(sorted(prev[u] for u in nbrs)))
                for v, nbrs in enumerate(adj)
            )))
        node_rounds.append(refined)
        # Refinement only ever splits classes, so an unchanged class count
        # means the partition itself is unchanged.
        converged = len(table) == n_classes
        n_classes = len(table)
        if converged:
            break

    conv_round = len(node_rounds) - 1
    ghist_rounds = [[_hist(colors) for colors in rnd] for rnd in node_rounds]
    return node_rounds, ghist_rounds, conv_round


def _hist(colors):
    """Return the colour histogram of one graph as a sorted, hashable tuple
    of ``(colour, count)`` pairs."""
    return tuple(sorted(Counter(colors.tolist()).items()))
+
+
def graph_colors_at(ghist_rounds, conv_round, L):
    """Return the per-graph colour histograms seen by an ``L``-round model.

    Rounds beyond ``conv_round`` were never computed, so the depth is clamped
    to ``conv_round``.
    """
    depth = conv_round if conv_round < L else L
    return ghist_rounds[depth]
+
+
+# ---------- classification attribution ----------
def attribute_classification(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """Attribute every evaluation graph to a 1-WL failure bucket.

    A colour class is "pure" when all TRAIN graphs carrying that colour share
    one label. Each graph in ``eval_idx`` is bucketed, first match wins:

        'novel'      its converged colour never occurs in the train set
        'H2'         its converged colour class is impure
        'H1a_depth'  converged class pure, but the depth-L class is impure
        'H1b_opt'    depth-L class pure

    Args:
        ghist_rounds: per-round, per-graph colour histograms (hashable).
        conv_round: index into ``ghist_rounds`` of the converged round.
        L: model depth; clamped to ``conv_round``.
        y: integer class labels, indexable by graph index.
        train_idx: iterable of graph indices used to build the colour->label
            tables.
        eval_idx: iterable of graph indices to attribute.

    Returns:
        dict with ``buckets`` (index -> bucket name), ``counts`` (bucket name
        -> count), the accuracy of predicting the train-majority label of the
        converged colour (``wl_optimal_acc_converged``) and of the depth-L
        colour (``wl_optimal_acc_Ldepth``), plus ``L_used`` and
        ``conv_round``. Both accuracies are NaN when ``eval_idx`` is empty.
        Majority ties go to the label seen first in ``train_idx`` order.
    """
    y = np.asarray(y)
    # Materialise once: the indices are iterated AND counted below.
    eval_idx = list(eval_idx)
    depth = min(L, conv_round)
    conv_colors = ghist_rounds[conv_round]
    depth_colors = ghist_rounds[depth]

    # colour -> Counter of train labels (insertion order = first-seen order,
    # which is what most_common() uses to break ties).
    conv_labels, depth_labels = {}, {}
    for i in train_idx:
        label = int(y[i])
        conv_labels.setdefault(conv_colors[i], Counter())[label] += 1
        depth_labels.setdefault(depth_colors[i], Counter())[label] += 1

    def majority(seen):
        return seen.most_common(1)[0][0] if seen else None

    buckets = {}
    conv_hits = depth_hits = 0
    for i in eval_idx:
        target = int(y[i])
        conv_seen = conv_labels.get(conv_colors[i])
        depth_seen = depth_labels.get(depth_colors[i])

        if conv_seen is None:
            buckets[i] = 'novel'
        elif len(conv_seen) != 1:
            buckets[i] = 'H2'
        elif depth_seen is None or len(depth_seen) != 1:
            buckets[i] = 'H1a_depth'
        else:
            buckets[i] = 'H1b_opt'

        # An unseen colour has no majority label and therefore never scores.
        conv_hits += majority(conv_seen) == target
        depth_hits += majority(depth_seen) == target

    n_eval = len(eval_idx)
    undefined = float('nan')
    return {
        'buckets': buckets,
        'counts': dict(Counter(buckets.values())),
        # best ANY MPNN can do
        'wl_optimal_acc_converged': conv_hits / n_eval if n_eval else undefined,
        # best L-layer MPNN can do
        'wl_optimal_acc_Ldepth': depth_hits / n_eval if n_eval else undefined,
        'L_used': depth, 'conv_round': conv_round,
    }
+
+
+# ---------- regression decomposition ----------
def decompose_regression(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """Decompose regression error on ``eval_idx`` into 1-WL-attributable floors.

    H2 floor = ORACLE within-colour variance on FULL data (best possible
    function of the converged WL colour: do same-colour graphs share the
    target?). It is fitted on every graph, evaluation graphs included, so it
    is not confounded by train/test coverage. The train-fitted floors are
    reported as well to expose how much apparent error is really novel-colour
    generalisation, together with coverage fractions.

    Args:
        ghist_rounds: per-round, per-graph colour histograms (hashable).
        conv_round: index into ``ghist_rounds`` of the converged round.
        L: model depth; clamped to ``conv_round``.
        y: regression targets, one per graph (cast to float64).
        train_idx: iterable of graph indices used for the train-fitted means.
        eval_idx: iterable of graph indices the errors are measured on.

    Returns:
        dict with the three MSE floors, ``frac_test_unseen_color``,
        ``frac_test_singleton_color``, ``L_used``, ``conv_round`` and
        ``var_target_eval``. Averages over an empty ``eval_idx`` are NaN; with
        an empty ``train_idx`` the unseen-colour fallback is NaN.
    """
    y = np.asarray(y, dtype=np.float64)
    # Materialise once: both index sets are traversed several times below, so
    # a one-shot iterable would otherwise be exhausted after the first pass.
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)
    depth = min(L, conv_round)
    conv = ghist_rounds[conv_round]
    shallow = ghist_rounds[depth]
    everything = range(len(y))
    targets = y[eval_idx]

    # oracle: best constant per converged colour over ALL data
    # -> irreducible by any MPNN
    oracle_mean = _group_mean(conv, y, everything)
    oracle_pred = [oracle_mean[conv[i]] for i in eval_idx]

    # train-fitted (achievable with this split); unseen colours fall back to
    # the global train mean
    conv_mean_tr = _group_mean(conv, y, train_idx)
    depth_mean_tr = _group_mean(shallow, y, train_idx)
    fallback = float(y[train_idx].mean()) if train_idx else float('nan')
    conv_pred = [conv_mean_tr.get(conv[i], fallback) for i in eval_idx]
    depth_pred = [depth_mean_tr.get(shallow[i], fallback) for i in eval_idx]

    def mse(pred):
        return _mean_or_nan((np.asarray(pred, dtype=np.float64) - targets) ** 2)

    population = Counter(conv[i] for i in everything)
    return {
        'mse_floor_oracle_H2': mse(oracle_pred),  # TRUE 1-WL ceiling
        'mse_floor_converged_train': mse(conv_pred),
        'mse_floor_Ldepth_train': mse(depth_pred),
        'frac_test_unseen_color': _mean_or_nan(
            [conv[i] not in conv_mean_tr for i in eval_idx]),
        'frac_test_singleton_color': _mean_or_nan(
            [population[conv[i]] == 1 for i in eval_idx]),
        'L_used': depth, 'conv_round': conv_round,
        'var_target_eval': float(targets.var()) if eval_idx else float('nan'),
    }


def _mean_or_nan(values):
    """Return the mean of ``values`` as a float, or NaN when it is empty."""
    values = np.asarray(values, dtype=np.float64)
    return float(values.mean()) if values.size else float('nan')


def _group_mean(colors, y, idx):
    """Return ``{colour: mean of y[i]}`` over the indices ``i`` in ``idx``,
    grouping by ``colors[i]``."""
    grouped = defaultdict(list)
    for i in idx:
        grouped[colors[i]].append(float(y[i]))
    return {color: float(np.mean(vals)) for color, vals in grouped.items()}