"""1-WL color-refinement instrument for diagnosing GNN failures (H1 vs H2).

A GIN with L layers == L rounds of 1-WL refinement (injective sum
aggregation). A failure on sample i is attributed by label purity of its WL
color classes:

  converged-WL class IMPURE (train labels conflict under same color)
      -> H2        : 1-WL ceiling. No MPNN at ANY depth separates
                     -> needs >1-WL (noise).
  converged pure, but L-round class impure
      -> H1a_depth : separable only with MORE rounds
                     -> deterministic RR-on-graph / depth helps.
  L-round class pure (info present at depth L) but model wrong
      -> H1b_opt   : optimization / capacity. Train better.

Refinement is dataset-global (shared per-round signature->label map) so node
colors and graph-color histograms are comparable across graphs.
"""
from collections import Counter, defaultdict

import numpy as np


def edges_to_adj(n, edge_index):
    """Build per-node adjacency lists from a 2-row edge index.

    Args:
        n: number of nodes; the result has exactly ``n`` lists.
        edge_index: array-like whose row 0 holds source nodes and row 1 holds
            target nodes. Each column ``(a, b)`` appends ``b`` to ``adj[a]``
            only, so an undirected graph must list both directions.

    Returns:
        A list of ``n`` lists of neighbor indices. An empty ``edge_index``
        yields ``n`` empty lists.
    """
    adj = [[] for _ in range(n)]
    ei = np.asarray(edge_index)
    if ei.size == 0:
        # No edges at all (e.g. ``[]`` or shape (2, 0)): nothing to add, and
        # ``ei[0]`` would fail on a 1-D empty array.
        return adj
    for a, b in zip(ei[0].tolist(), ei[1].tolist()):
        adj[a].append(b)
    return adj


def _relabel(keys_per_graph):
    """Map hashable keys to dense ints shared across all graphs.

    Ids are assigned in first-occurrence order while scanning the graphs in
    order. Returns ``(list of int64 arrays, number of distinct keys)``.
    """
    table = {}
    arrays = []
    for keys in keys_per_graph:
        # ``len(table)`` is evaluated before insertion, so a new key gets the
        # next free id.
        ids = [table.setdefault(k, len(table)) for k in keys]
        arrays.append(np.array(ids, dtype=np.int64))
    return arrays, len(table)


def wl_refine(adjs, inits=None, max_rounds=None):
    """Dataset-level 1-WL. Returns (node_rounds, ghist_rounds, conv_round).

    node_rounds[r][g]  = int color array (global labels) of graph g after r rounds.
    ghist_rounds[r][g] = canonical color histogram (hashable) of graph g after r rounds.
    conv_round         = round index at which the global partition stabilized
                         (or the last round computed if ``max_rounds`` was hit
                         first).

    Args:
        adjs: one adjacency list (list of neighbor lists) per graph.
        inits: optional initial integer node colors, one sequence per graph,
            each as long as the graph's adjacency list. Defaults to a single
            uniform color.
        max_rounds: cap on refinement rounds. Defaults to ``2 * max_n + 2``
            where ``max_n`` is the largest graph size.

    Raises:
        ValueError: if ``inits`` does not match ``adjs`` in shape.

    Color labels are re-numbered from 0 in every round, so they are comparable
    across graphs within a round but not across rounds.
    """
    adjs = list(adjs)
    if inits is None:
        inits = [np.zeros(len(a), dtype=np.int64) for a in adjs]
    else:
        inits = [np.asarray(x, dtype=np.int64) for x in inits]
        if len(inits) != len(adjs):
            raise ValueError(
                f"inits has {len(inits)} entries but adjs has {len(adjs)} graphs"
            )
        for g, (init, adj) in enumerate(zip(inits, adjs)):
            if len(init) != len(adj):
                raise ValueError(
                    f"graph {g}: {len(init)} initial colors for {len(adj)} nodes"
                )
    if max_rounds is None:
        # The partition is global, so nodes of two DIFFERENT graphs must be
        # told apart too. That is refinement on their disjoint union (up to
        # 2 * max_n nodes) and can need close to 2 * max_n rounds, not max_n.
        # The loop below still breaks as soon as the partition is stable.
        max_n = max((len(a) for a in adjs), default=0)
        max_rounds = 2 * max_n + 2

    cur, n_classes = _relabel(init.tolist() for init in inits)
    node_rounds = [cur]
    for _ in range(max_rounds):
        signatures = []
        for adj, colors in zip(adjs, node_rounds[-1]):
            c = colors.tolist()  # plain ints: cheap to index and hash
            # Signature = own color + sorted multiset of neighbor colors.
            signatures.append(
                [(c[v], tuple(sorted(c[u] for u in nbrs)))
                 for v, nbrs in enumerate(adj)]
            )
        nxt, n_next = _relabel(signatures)
        node_rounds.append(nxt)
        # Each signature contains the node's previous color, so classes can
        # only split; an unchanged class count means the partition is stable.
        converged = n_next == n_classes
        n_classes = n_next
        if converged:
            break
    conv_round = len(node_rounds) - 1
    ghist_rounds = [[_hist(c) for c in nr] for nr in node_rounds]
    return node_rounds, ghist_rounds, conv_round


def _hist(colors):
    """Return a hashable, sorted ``((color, count), ...)`` histogram of an int array."""
    return tuple(sorted(Counter(colors.tolist()).items()))


def graph_colors_at(ghist_rounds, conv_round, L):
    """Return the per-graph color histograms after ``min(L, conv_round)`` rounds."""
    return ghist_rounds[min(L, conv_round)]


# ---------- classification attribution ----------
def _label_summary(groups):
    """Map each color to ``(is_pure, majority_label)`` of its training labels.

    Ties for the majority go to the label seen first, as with
    ``Counter.most_common``.
    """
    summary = {}
    for color, labels in groups.items():
        counts = Counter(labels)
        summary[color] = (len(counts) == 1, counts.most_common(1)[0][0])
    return summary


def attribute_classification(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """Attribute every eval graph to a failure bucket by WL-color label purity.

    Args:
        ghist_rounds: per-round, per-graph color histograms (see ``wl_refine``).
        conv_round: index of the converged round in ``ghist_rounds``.
        L: model depth; the round used is ``min(L, conv_round)``.
        y: integer class label per graph.
        train_idx: indices of training graphs (any iterable).
        eval_idx: indices of graphs to attribute (any iterable).

    Returns:
        Dict with:
          'buckets': eval index -> 'novel' (converged color unseen in train),
              'H2' (converged color has conflicting train labels),
              'H1a_depth' (converged color pure but depth-L color is not), or
              'H1b_opt' (depth-L color pure).
          'counts': bucket name -> number of bucketed indices.
          'wl_optimal_acc_converged' / 'wl_optimal_acc_Ldepth': eval accuracy
              of predicting the train-majority label of the converged / depth-L
              color (unseen colors count as wrong); NaN if ``eval_idx`` is
              empty.
          'L_used', 'conv_round': rounds actually used.
    """
    y = np.asarray(y)
    # Materialize once: both index sets are traversed and measured below, so a
    # one-shot iterator would otherwise be exhausted or have no len().
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)

    conv = ghist_rounds[conv_round]
    Lr = min(L, conv_round)
    lr = ghist_rounds[Lr]

    conv_train, lr_train = defaultdict(list), defaultdict(list)
    for i in train_idx:
        label = int(y[i])
        conv_train[conv[i]].append(label)
        lr_train[lr[i]].append(label)
    # Summarize each color once instead of rescanning its labels per sample.
    conv_stats = _label_summary(conv_train)
    lr_stats = _label_summary(lr_train)

    buckets = {}
    wl_opt = lr_opt = 0
    for i in eval_idx:
        conv_info = conv_stats.get(conv[i])
        lr_info = lr_stats.get(lr[i])
        if conv_info is None:
            buckets[i] = 'novel'
        elif not conv_info[0]:
            buckets[i] = 'H2'
        elif lr_info is None or not lr_info[0]:
            buckets[i] = 'H1a_depth'
        else:
            buckets[i] = 'H1b_opt'
        target = int(y[i])
        if conv_info is not None and conv_info[1] == target:
            wl_opt += 1
        if lr_info is not None and lr_info[1] == target:
            lr_opt += 1

    n = len(eval_idx)
    nan = float('nan')  # accuracy is undefined on an empty eval set
    return {
        'buckets': buckets,
        'counts': dict(Counter(buckets.values())),
        'wl_optimal_acc_converged': wl_opt / n if n else nan,  # best ANY MPNN can do
        'wl_optimal_acc_Ldepth': lr_opt / n if n else nan,  # best L-layer MPNN can do
        'L_used': Lr,
        'conv_round': conv_round,
    }


# ---------- regression decomposition ----------
def decompose_regression(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """H2 floor = ORACLE within-color variance on FULL data (best possible function
    of the WL color: do same-color graphs share the target?). This is the true
    information ceiling and is NOT confounded by train/test coverage. The
    train-fitted floors are also reported to expose how much apparent error is
    really novel-color generalization, plus coverage fractions.

    ``train_idx`` and ``eval_idx`` may be any iterables of graph indices. With
    an empty ``eval_idx`` (or empty ``train_idx``) the affected statistics are
    NaN, as produced by numpy's mean/var of an empty array.
    """
    y = np.asarray(y, dtype=np.float64)
    # Materialize once: both index sets are traversed several times below.
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)

    conv = ghist_rounds[conv_round]
    Lr = min(L, conv_round)
    lr = ghist_rounds[Lr]
    full_idx = list(range(len(y)))

    # oracle: best constant per converged color over ALL data -> irreducible by any MPNN
    conv_mean_full = _group_mean(conv, y, full_idx)
    e_oracle = np.array([conv_mean_full[conv[i]] - y[i] for i in eval_idx])

    # train-fitted (achievable with this split); fallback to global mean on unseen colors
    conv_mean_tr = _group_mean(conv, y, train_idx)
    lr_mean_tr = _group_mean(lr, y, train_idx)
    gmean = float(y[train_idx].mean())
    e_conv_tr = np.array([conv_mean_tr.get(conv[i], gmean) - y[i] for i in eval_idx])
    e_lr_tr = np.array([lr_mean_tr.get(lr[i], gmean) - y[i] for i in eval_idx])

    conv_count = Counter(conv[i] for i in full_idx)
    train_colors = {conv[i] for i in train_idx}
    frac_unseen = float(np.mean([conv[i] not in train_colors for i in eval_idx]))
    frac_singleton = float(np.mean([conv_count[conv[i]] == 1 for i in eval_idx]))

    return {
        'mse_floor_oracle_H2': float((e_oracle ** 2).mean()),  # TRUE 1-WL ceiling
        'mse_floor_converged_train': float((e_conv_tr ** 2).mean()),
        'mse_floor_Ldepth_train': float((e_lr_tr ** 2).mean()),
        'frac_test_unseen_color': frac_unseen,
        'frac_test_singleton_color': frac_singleton,
        'L_used': Lr,
        'conv_round': conv_round,
        'var_target_eval': float(y[eval_idx].var()),
    }


def _group_mean(colors, y, idx):
    """Return ``{color: mean of y[i]}`` over the indices ``i`` in ``idx``."""
    acc = defaultdict(list)
    for i in idx:
        acc[colors[i]].append(float(y[i]))
    return {k: float(np.mean(v)) for k, v in acc.items()}