diff options
Diffstat (limited to 'diag/wl.py')
| -rw-r--r-- | diag/wl.py | 166 |
1 files changed, 166 insertions, 0 deletions
"""1-WL color-refinement instrument for diagnosing GNN failures (H1 vs H2).

A GIN with L layers == L rounds of 1-WL refinement (injective sum aggregation).
A failure on sample i is attributed by label purity of its WL color classes:

  converged-WL class IMPURE (train labels conflict under same color)
      -> H2        : 1-WL ceiling. No MPNN at ANY depth separates -> needs >1-WL (noise).
  converged pure, but L-round class impure
      -> H1a_depth : separable only with MORE rounds -> deterministic RR-on-graph / depth helps.
  L-round class pure (info present at depth L) but model wrong
      -> H1b_opt   : optimization / capacity. Train better.

Refinement is dataset-global (shared per-round signature->label map) so node colors and
graph-color histograms are comparable across graphs.
"""
from collections import Counter, defaultdict

import numpy as np


def edges_to_adj(n, edge_index):
    """Build per-node neighbor lists from a 2-row edge index.

    Each column ``(a, b)`` of ``edge_index`` appends ``b`` to ``adj[a]``; the reverse
    direction is only present if ``edge_index`` also contains ``(b, a)``.

    Args:
        n: number of nodes.
        edge_index: array-like with source nodes in row 0 and target nodes in row 1.
            An empty value (e.g. ``[]`` or a ``(2, 0)`` array) yields no edges.

    Returns:
        A list of ``n`` lists of neighbor indices.
    """
    adj = [[] for _ in range(n)]
    ei = np.asarray(edge_index)
    if ei.size == 0:
        # Covers the 1-D empty case, where ei[0] would raise IndexError.
        return adj
    for a, b in zip(ei[0].tolist(), ei[1].tolist()):
        adj[a].append(b)
    return adj


def wl_refine(adjs, inits=None, max_rounds=None):
    """Dataset-level 1-WL. Returns (node_rounds, ghist_rounds, conv_round).

    Args:
        adjs: one adjacency list (list of neighbor lists) per graph.
        inits: optional initial integer node colors, one sequence per graph;
            defaults to a single uniform color.
        max_rounds: cap on the number of refinement rounds. Defaults to the total
            node count + 2. The partition is dataset-global, so the number of
            strictly-refining rounds is bounded by the total number of nodes; a
            bound based on the largest single graph is not guaranteed to reach
            the stable partition.

    Returns:
        node_rounds[r][g]  = int color array (global labels) of graph g after r rounds.
        ghist_rounds[r][g] = canonical color histogram (hashable) of graph g after r rounds.
        conv_round         = index of the last computed round. When refinement converged
                             this round carries the stable partition (its class count
                             equals the previous round's); if ``max_rounds`` ran out
                             first it is simply the last round.

    Raises:
        ValueError: if ``inits`` does not match ``adjs`` in length or per-graph size.
    """
    adjs = list(adjs)
    if inits is None:
        inits = [np.zeros(len(a), dtype=np.int64) for a in adjs]
    else:
        inits = [np.asarray(x, dtype=np.int64) for x in inits]
        if len(inits) != len(adjs):
            raise ValueError(
                f"inits has {len(inits)} entries but adjs has {len(adjs)} graphs")
        for g, (adj, init) in enumerate(zip(adjs, inits)):
            if len(init) != len(adj):
                raise ValueError(
                    f"graph {g}: {len(init)} initial colors for {len(adj)} nodes")
    if max_rounds is None:
        max_rounds = sum(len(a) for a in adjs) + 2

    # Round 0: compress the raw initial colors into dense global ids.
    palette = {}
    start = [
        np.array([palette.setdefault(('i', int(c)), len(palette)) for c in init],
                 dtype=np.int64)
        for init in inits
    ]
    node_rounds = [start]
    n_classes = len(palette)

    for _ in range(max_rounds):
        # Fresh signature->id map per round, shared by every graph so ids are comparable.
        palette = {}
        refined = []
        for adj, prev in zip(adjs, node_rounds[-1]):
            colors = prev.tolist()
            arr = np.empty(len(adj), dtype=np.int64)
            for v, nbrs in enumerate(adj):
                sig = (colors[v], tuple(sorted(colors[u] for u in nbrs)))
                arr[v] = palette.setdefault(sig, len(palette))
            refined.append(arr)
        node_rounds.append(refined)
        # A signature includes the node's own color, so classes can only split;
        # an unchanged global class count therefore means the partition is stable.
        stable = len(palette) == n_classes
        n_classes = len(palette)
        if stable:
            break

    conv_round = len(node_rounds) - 1
    ghist_rounds = [[_hist(c) for c in nr] for nr in node_rounds]
    return node_rounds, ghist_rounds, conv_round


def _hist(colors):
    """Return the sorted ``(color, count)`` tuple of an int color array (hashable)."""
    return tuple(sorted(Counter(colors.tolist()).items()))


def graph_colors_at(ghist_rounds, conv_round, L):
    """Return the per-graph color histograms after ``min(L, conv_round)`` rounds."""
    return ghist_rounds[min(L, conv_round)]


# ---------- classification attribution ----------
def attribute_classification(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """Bucket each eval sample by the label purity of its WL color class on train data.

    Buckets (purity is judged on TRAIN labels only):
        'novel'     : converged color never seen in train.
        'H2'        : converged color seen in train with conflicting labels.
        'H1a_depth' : converged color pure, but the L-round color is not.
        'H1b_opt'   : L-round color pure as well.

    Args:
        ghist_rounds: per-round, per-graph hashable colors (as from ``wl_refine``).
        conv_round: round index treated as the converged partition.
        L: model depth; clipped to ``conv_round``.
        y: integer class labels indexed by graph.
        train_idx, eval_idx: iterables of graph indices (may be one-shot iterators).

    Returns:
        dict with 'buckets' (eval index -> bucket name), 'counts', the accuracies of
        the train-majority-per-color predictor at the converged and L-round colors
        (NaN when ``eval_idx`` is empty), 'L_used' and 'conv_round'.
    """
    y = np.asarray(y)
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)  # materialize: len() is needed and iterators have none
    conv = ghist_rounds[conv_round]
    Lr = min(L, conv_round)
    lr = ghist_rounds[Lr]

    conv_train, lr_train = defaultdict(list), defaultdict(list)
    for i in train_idx:
        label = int(y[i])
        conv_train[conv[i]].append(label)
        lr_train[lr[i]].append(label)

    def pure(dct, key):
        labs = dct.get(key)
        return labs is not None and len(set(labs)) == 1

    def majority(dct, key):
        labs = dct.get(key)
        return Counter(labs).most_common(1)[0][0] if labs else None

    buckets = {}
    wl_opt = lr_opt = 0
    for i in eval_idx:
        if conv[i] not in conv_train:
            buckets[i] = 'novel'
        elif not pure(conv_train, conv[i]):
            buckets[i] = 'H2'
        elif not pure(lr_train, lr[i]):
            buckets[i] = 'H1a_depth'
        else:
            buckets[i] = 'H1b_opt'
        target = int(y[i])
        if majority(conv_train, conv[i]) == target:
            wl_opt += 1
        if majority(lr_train, lr[i]) == target:
            lr_opt += 1

    n = len(eval_idx)
    nan = float('nan')
    return {
        'buckets': buckets,
        'counts': dict(Counter(buckets.values())),
        'wl_optimal_acc_converged': wl_opt / n if n else nan,  # best ANY MPNN can do
        'wl_optimal_acc_Ldepth': lr_opt / n if n else nan,     # best L-layer MPNN can do
        'L_used': Lr, 'conv_round': conv_round,
    }


# ---------- regression decomposition ----------
def decompose_regression(ghist_rounds, conv_round, L, y, train_idx, eval_idx):
    """H2 floor = ORACLE within-color variance on FULL data (best possible function of the WL
    color: do same-color graphs share the target?). This is the true information ceiling and is
    NOT confounded by train/test coverage. The train-fitted floors are also reported to expose
    how much apparent error is really novel-color generalization, plus coverage fractions.

    ``train_idx`` / ``eval_idx`` may be any iterables of graph indices (they are
    materialized once). Means over an empty split are reported as NaN.
    """
    y = np.asarray(y, dtype=np.float64)
    # Both splits are traversed several times below; a one-shot iterator would be
    # silently empty on the second pass.
    train_idx = list(train_idx)
    eval_idx = list(eval_idx)
    conv = ghist_rounds[conv_round]
    Lr = min(L, conv_round)
    lr = ghist_rounds[Lr]
    full_idx = range(len(y))

    # oracle: best constant per converged color over ALL data -> irreducible by any MPNN
    conv_mean_full = _group_mean(conv, y, full_idx)
    e_oracle = np.array([conv_mean_full[conv[i]] - y[i] for i in eval_idx],
                        dtype=np.float64)

    # train-fitted (achievable with this split); fallback to global mean on unseen colors
    conv_mean_tr = _group_mean(conv, y, train_idx)
    lr_mean_tr = _group_mean(lr, y, train_idx)
    gmean = _safe_mean(y[np.asarray(train_idx, dtype=np.intp)])
    e_conv_tr = np.array([conv_mean_tr.get(conv[i], gmean) - y[i] for i in eval_idx],
                         dtype=np.float64)
    e_lr_tr = np.array([lr_mean_tr.get(lr[i], gmean) - y[i] for i in eval_idx],
                       dtype=np.float64)

    conv_count = Counter(conv[i] for i in full_idx)
    train_colors = {conv[i] for i in train_idx}
    frac_unseen = _safe_mean([conv[i] not in train_colors for i in eval_idx])
    frac_singleton = _safe_mean([conv_count[conv[i]] == 1 for i in eval_idx])
    eval_targets = y[np.asarray(eval_idx, dtype=np.intp)]
    return {
        'mse_floor_oracle_H2': _safe_mean(e_oracle ** 2),  # TRUE 1-WL ceiling
        'mse_floor_converged_train': _safe_mean(e_conv_tr ** 2),
        'mse_floor_Ldepth_train': _safe_mean(e_lr_tr ** 2),
        'frac_test_unseen_color': frac_unseen,
        'frac_test_singleton_color': frac_singleton,
        'L_used': Lr, 'conv_round': conv_round,
        'var_target_eval': float(eval_targets.var()) if eval_targets.size else float('nan'),
    }


def _safe_mean(values):
    """Mean of ``values`` as a Python float; NaN (without a numpy warning) when empty."""
    arr = np.asarray(values, dtype=np.float64)
    return float(arr.mean()) if arr.size else float('nan')


def _group_mean(colors, y, idx):
    """Map each color occurring among ``idx`` to the mean of ``y`` over those indices."""
    acc = defaultdict(list)
    for i in idx:
        acc[colors[i]].append(float(y[i]))
    return {k: float(np.mean(v)) for k, v in acc.items()}
