From bd9333eda60a9029a198acaeacb1eca4312bd1e8 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 4 May 2026 23:05:16 -0500 Subject: =?UTF-8?q?Initial=20release:=20GRAFT=20(KAFT)=20=E2=80=94=20NeurI?= =?UTF-8?q?PS=202026=20submission=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%. --- experiments/run_combo_20seeds.py | 454 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 454 insertions(+) create mode 100644 experiments/run_combo_20seeds.py (limited to 'experiments/run_combo_20seeds.py') diff --git a/experiments/run_combo_20seeds.py b/experiments/run_combo_20seeds.py new file mode 100644 index 0000000..1598964 --- /dev/null +++ b/experiments/run_combo_20seeds.py @@ -0,0 +1,454 @@ +#!/usr/bin/env python3 +"""Task 2ceadaa7: GRAFT + Forward Tricks combo experiments (20 seeds). + +Combos: GRAFT+ResGCN, GRAFT+DropEdge, GRAFT+PairNorm, GRAFT+JKNet +Each compared to: BP, forward_trick_only, GRAFT_only, combo +""" + +import torch +import torch.nn.functional as F +import numpy as np +import json +import os +from scipy import stats as scipy_stats +from src.data import load_dataset, spmm, build_normalized_adj +from src.trainers import BPTrainer, GraphGrAPETrainer, _FeedbackTrainerBase +from run_deep_baselines import ResGCNTrainer, JKNetTrainer +from run_dropedge import BPDropEdgeTrainer +from run_pairnorm_baseline import BPPairNormTrainer, pairnorm +from run_dblp_depth import load_dblp + +device = 'cuda:0' +SEEDS = list(range(20)) +EPOCHS = 200 +OUT_DIR = 'results/combo_20seeds' + +grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10, + lr_feedback=0.5, num_probes=64, topo_mode='fixed_A') + + +# ═══════════════════════════════════════════════════════════════════════════ +# GRAFT + ResGCN combo (fixed version) +# ═══════════════════════════════════════════════════════════════════════════ +class GRAFTResGCN(GraphGrAPETrainer): + """GRAFT backward + ResGCN forward (skip connections).""" + + def forward(self): + X = self.data['X'] + H = X + H0 = None + Hs, Zs = [], [] + + for l in range(self.num_layers): + Z = self._graph_conv(H, self.weights[l], l) + Zs.append(Z) + if l < self.num_layers - 1: + H_new = F.relu(Z) + if H_new.size(1) == H.size(1): + H = H + H_new + else: + H = H_new + Hs.append(H) + if l == 0: + H0 = H + else: + return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0} + return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0} + + +# ═══════════════════════════════════════════════════════════════════════════ +# GRAFT + DropEdge combo +# ═══════════════════════════════════════════════════════════════════════════ +class GRAFTDropEdge(GraphGrAPETrainer): + """GRAFT backward + DropEdge forward (random edge dropping).""" + + def __init__(self, *args, drop_rate=0.5, **kwargs): + super().__init__(*args, **kwargs) + self.drop_rate = drop_rate + self._A_hat_orig = self.data['A_hat'] + self._edge_index_orig = self._A_hat_orig.indices() + self._edge_values_orig = self._A_hat_orig.values() + self._N = self._A_hat_orig.size(0) + + def _drop_edges(self): + if not self._training or self.drop_rate <= 0: + return self._A_hat_orig + mask = torch.rand(self._edge_values_orig.size(0), + device=self._edge_values_orig.device) > self.drop_rate + new_vals = self._edge_values_orig * mask.float() / (1 - self.drop_rate) + return torch.sparse_coo_tensor( + self._edge_index_orig, new_vals, (self._N, self._N) + ).coalesce() + + def forward(self): + # DropEdge only in forward pass, GRAFT backward uses original A_hat + self.data['A_hat'] = self._drop_edges() + result = super().forward() # uses GraphGrAPETrainer.forward() + self.data['A_hat'] = self._A_hat_orig + return result + + def evaluate(self, mask_name='test_mask'): + self.data['A_hat'] = self._A_hat_orig + return super().evaluate(mask_name) + + +# ═══════════════════════════════════════════════════════════════════════════ +# GRAFT + PairNorm combo +# ═══════════════════════════════════════════════════════════════════════════ +class GRAFTPairNorm(GraphGrAPETrainer): + """GRAFT backward + PairNorm forward (center & scale normalization).""" + + def __init__(self, *args, pn_scale=1.0, **kwargs): + super().__init__(*args, **kwargs) + self.pn_scale = pn_scale + + def forward(self): + X = self.data['X'] + H = X + H0 = None + Hs, Zs = [], [] + + if self.backbone == 'appnp': + for l in range(self.num_layers): + Z = H @ self.weights[l] + Zs.append(Z) + if l < self.num_layers - 1: + H = F.relu(Z) + H = pairnorm(H, self.pn_scale) + Hs.append(H) + if l == 0: + H0 = H + else: + Z = self._appnp_propagate(Z) + Zs[-1] = Z + return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0} + + for l in range(self.num_layers): + Z = self._graph_conv(H, self.weights[l], l) + Zs.append(Z) + if l < self.num_layers - 1: + H = F.relu(Z) + H = pairnorm(H, self.pn_scale) + Hs.append(H) + if l == 0: + H0 = H + else: + return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0} + return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0} + + +# ═══════════════════════════════════════════════════════════════════════════ +# GRAFT + JKNet combo +# ═══════════════════════════════════════════════════════════════════════════ +class GRAFTJKNet(GraphGrAPETrainer): + """GRAFT backward + JKNet forward (jumping knowledge max-pool). + + Note: JKNet changes the output architecture. We max-pool hidden layers + and project to num_classes. GRAFT backward operates on hidden layers + as usual; the JK projection is treated as the output layer. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # JK projection: hidden_dim -> num_classes + self.jk_proj = torch.randn(self.hidden_dim, self.d_out, + device=self.device) * 0.01 + # Add to Adam state + self._adam.append({'m': torch.zeros_like(self.jk_proj), + 'v': torch.zeros_like(self.jk_proj)}) + + def forward(self): + X = self.data['X'] + H = X + H0 = None + Hs, Zs = [], [] + + for l in range(self.num_layers): + Z = self._graph_conv(H, self.weights[l], l) + Zs.append(Z) + if l < self.num_layers - 1: + H = F.relu(Z) + Hs.append(H) + if l == 0: + H0 = H + + # JK max-pool over hidden layers + if Hs: + stacked = torch.stack(Hs, dim=0) # (L-1, N, d) + jk_repr = stacked.max(dim=0)[0] # (N, d) + Z_out = jk_repr @ self.jk_proj + # Override Hs[-1] for backward compatibility with _update_weights + # which uses Hs[-1] as input to output layer + Hs_for_backward = list(Hs) + Hs_for_backward[-1] = jk_repr + return Z_out, {'Hs': Hs_for_backward, 'Zs': Zs, 'H0': H0} + else: + return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0} + + def _update_weights(self, inter, E0, deltas): + """Override to handle JK projection separately.""" + # Update hidden layers using GRAFT feedback as usual + X = self.data['X'] + Hs = inter['Hs'] + H0 = inter['H0'] + + grads = [] + for l in range(self.num_layers): + if l == self.num_layers - 1: + # Skip the original output layer — JK projection replaces it + # But still compute gradient for W_L (unused in JK, but keeps indexing) + H_prev = Hs[-1] if Hs else X + g = H_prev.t() @ self._graph_conv_T(E0, l) + else: + if l == 0: + H_in = X + else: + H_prev = Hs[l - 1] + if self.residual_alpha > 0 and H0 is not None: + H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * H0 + else: + H_in = H_prev + g = H_in.t() @ self._graph_conv_T(deltas[l], l) + grads.append(g) + + # Update JK projection: grad = jk_repr.T @ E0 + jk_repr = Hs[-1] if Hs else X + jk_grad = jk_repr.t() @ E0 + + if self._use_adam: + self._adam_t += 1 + for i in range(self.num_layers): + self.weights[i] = self.weights[i] - self._adam_step(i, grads[i]) + # Update jk_proj with Adam (use last index) + jk_idx = len(self._adam) - 1 + s = self._adam[jk_idx] + b1, b2, eps = self._adam_beta1, self._adam_beta2, self._adam_eps + t = self._adam_t + s['m'] = b1 * s['m'] + (1 - b1) * jk_grad + s['v'] = b2 * s['v'] + (1 - b2) * jk_grad ** 2 + m_hat = s['m'] / (1 - b1 ** t) + v_hat = s['v'] / (1 - b2 ** t) + self.jk_proj = self.jk_proj - self.lr * (m_hat / (v_hat.sqrt() + eps) + self.wd * self.jk_proj) + else: + for i in range(self.num_layers): + self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i]) + self.jk_proj = self.jk_proj - self.lr * (jk_grad + self.wd * self.jk_proj) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Training +# ═══════════════════════════════════════════════════════════════════════════ + +def train_one(cls, common, extra, seed): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + t = cls(**common, **extra) + if hasattr(t, 'align_mode'): + t.align_mode = 'chain_norm' + bv, bt = 0, 0 + for ep in range(EPOCHS): + t.train_step() + if ep % 5 == 0: + v = t.evaluate('val_mask') + te = t.evaluate('test_mask') + if v > bv: bv, bt = v, te + del t; torch.cuda.empty_cache() + return bt + + +def main(): + os.makedirs(OUT_DIR, exist_ok=True) + per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json') + if os.path.exists(per_seed_file): + with open(per_seed_file) as f: + per_seed_data = json.load(f) + else: + per_seed_data = {} + + # Reuse existing per-seed data from other experiments + # BP, ResGCN, GRAFT from resgcn_20seeds + try: + with open('results/resgcn_20seeds/per_seed_data.json') as f: + resgcn_cache = json.load(f) + except: + resgcn_cache = {} + + # DropEdge from dropedge_20seeds + try: + with open('results/dropedge_20seeds/per_seed_data.json') as f: + de_cache = json.load(f) + except: + de_cache = {} + + # PairNorm from pairnorm_extended + try: + with open('results/pairnorm_extended/per_seed_data.json') as f: + pn_cache = json.load(f) + except: + pn_cache = {} + + METHODS = { + 'BP': (BPTrainer, {}), + 'ResGCN': (ResGCNTrainer, {}), + 'DropEdge': (BPDropEdgeTrainer, {'drop_rate': 0.5}), + 'PairNorm': (BPPairNormTrainer, {'pn_scale': 1.0}), + 'JKNet': (JKNetTrainer, {}), + 'GRAFT': (GraphGrAPETrainer, grape_extra), + 'GRAFT+ResGCN': (GRAFTResGCN, grape_extra), + 'GRAFT+DropEdge': (GRAFTDropEdge, {**grape_extra, 'drop_rate': 0.5}), + 'GRAFT+PairNorm': (GRAFTPairNorm, {**grape_extra, 'pn_scale': 1.0}), + 'GRAFT+JKNet': (GRAFTJKNet, grape_extra), + } + + datasets_cfg = { + 'Cora': lambda: load_dataset('Cora', device=device), + 'CiteSeer': lambda: load_dataset('CiteSeer', device=device), + 'DBLP': lambda: load_dblp(), + } + + for ds_name, loader in datasets_cfg.items(): + data = loader() + common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4, + num_layers=6, residual_alpha=0.0, backbone='gcn') + + for mname, (cls, extra) in METHODS.items(): + key = f"{ds_name}_{mname}" + + if key not in per_seed_data: + per_seed_data[key] = {} + + print(f"\n=== {key} (20 seeds) ===", flush=True) + + for seed in SEEDS: + sk = str(seed) + if sk in per_seed_data[key]: + print(f" seed {seed}: cached", flush=True) + continue + + # Try to pull from existing caches + cached = None + if mname == 'BP' and f"{ds_name}_BP" in resgcn_cache: + cached = resgcn_cache[f"{ds_name}_BP"].get(sk) + elif mname == 'ResGCN' and f"{ds_name}_ResGCN" in resgcn_cache: + cached = resgcn_cache[f"{ds_name}_ResGCN"].get(sk) + elif mname == 'GRAFT' and f"{ds_name}_GRAFT" in resgcn_cache: + cached = resgcn_cache[f"{ds_name}_GRAFT"].get(sk) + elif mname == 'DropEdge': + de_key = f"{ds_name}_gcn_L6" + if de_key in de_cache and sk in de_cache[de_key]: + cached = de_cache[de_key][sk].get('de05') + elif mname == 'PairNorm': + pn_key = f"{ds_name}_gcn_L6_PN" + if pn_key in pn_cache and sk in pn_cache[pn_key]: + cached = pn_cache[pn_key][sk] + + if cached is not None: + per_seed_data[key][sk] = cached + print(f" seed {seed}: from cache ({cached*100:.1f}%)", flush=True) + else: + try: + acc = train_one(cls, common, extra, seed) + per_seed_data[key][sk] = acc + print(f" seed {seed}: {acc*100:.1f}%", flush=True) + except Exception as e: + print(f" seed {seed}: FAILED - {e}", flush=True) + per_seed_data[key][sk] = 0.0 + + # Save after each seed + with open(per_seed_file, 'w') as f: + json.dump(per_seed_data, f, indent=2) + + del data; torch.cuda.empty_cache() + + # ═══════════════════════════════════════════════════════════════════════ + # Analysis + # ═══════════════════════════════════════════════════════════════════════ + results = {} + print("\n" + "=" * 120) + print("FULL RESULTS TABLE") + print("=" * 120) + + for ds in ['Cora', 'CiteSeer', 'DBLP']: + print(f"\n--- {ds} GCN L=6 lr=0.01 ---") + print(f"{'Method':<18} {'Mean±Std':>12} {'vs GRAFT':>18} {'vs FwdTrick':>18}") + print("-" * 70) + + # Get GRAFT accs for comparison + gr_accs = np.array([per_seed_data[f"{ds}_GRAFT"][str(s)] for s in SEEDS]) * 100 + + for mname in ['BP', 'ResGCN', 'DropEdge', 'PairNorm', 'JKNet', + 'GRAFT', 'GRAFT+ResGCN', 'GRAFT+DropEdge', 'GRAFT+PairNorm', 'GRAFT+JKNet']: + key = f"{ds}_{mname}" + if key not in per_seed_data or len(per_seed_data[key]) < 20: + print(f" {mname:<16} MISSING ({len(per_seed_data.get(key, {}))} seeds)") + continue + + accs = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100 + m, s = accs.mean(), accs.std() + + results[key] = {'mean': float(m), 'std': float(s), 'accs': accs.tolist()} + + # Paired t-test vs GRAFT + if mname != 'GRAFT': + t_stat, p_val = scipy_stats.ttest_rel(accs, gr_accs) + delta = m - gr_accs.mean() + sig = '***' if p_val < 0.001 else ('**' if p_val < 0.01 else ('*' if p_val < 0.05 else 'ns')) + vs_graft = f"Δ{delta:+.1f} p={p_val:.4f}{sig}" + results[key]['vs_GRAFT'] = { + 'delta': float(delta), 'p_value': float(p_val), + 'significant': bool(p_val < 0.05) + } + else: + vs_graft = "—" + + # Paired t-test vs forward trick only + fwd_map = { + 'GRAFT+ResGCN': 'ResGCN', 'GRAFT+DropEdge': 'DropEdge', + 'GRAFT+PairNorm': 'PairNorm', 'GRAFT+JKNet': 'JKNet' + } + if mname in fwd_map: + fwd_key = f"{ds}_{fwd_map[mname]}" + if fwd_key in per_seed_data and len(per_seed_data[fwd_key]) >= 20: + fwd_accs = np.array([per_seed_data[fwd_key][str(s)] for s in SEEDS]) * 100 + t2, p2 = scipy_stats.ttest_rel(accs, fwd_accs) + d2 = m - fwd_accs.mean() + s2 = '***' if p2 < 0.001 else ('**' if p2 < 0.01 else ('*' if p2 < 0.05 else 'ns')) + vs_fwd = f"Δ{d2:+.1f} p={p2:.4f}{s2}" + results[key][f'vs_{fwd_map[mname]}'] = { + 'delta': float(d2), 'p_value': float(p2), + 'significant': bool(p2 < 0.05) + } + else: + vs_fwd = "N/A" + else: + vs_fwd = "" + + print(f" {mname:<16} {m:>5.1f}±{s:<5.1f} {vs_graft:>18} {vs_fwd:>18}") + + # Summary: which combos are additive? + print("\n" + "=" * 80) + print("COMBO ADDITIVITY SUMMARY") + print("=" * 80) + for ds in ['Cora', 'CiteSeer', 'DBLP']: + print(f"\n{ds}:") + gr_m = results.get(f"{ds}_GRAFT", {}).get('mean', 0) + for combo, fwd in [('GRAFT+ResGCN', 'ResGCN'), ('GRAFT+DropEdge', 'DropEdge'), + ('GRAFT+PairNorm', 'PairNorm'), ('GRAFT+JKNet', 'JKNet')]: + ck = f"{ds}_{combo}" + fk = f"{ds}_{fwd}" + if ck in results and fk in results: + c_m = results[ck]['mean'] + f_m = results[fk]['mean'] + vs_gr = results[ck].get('vs_GRAFT', {}) + vs_fw = results[ck].get(f'vs_{fwd}', {}) + better_than_both = c_m > gr_m and c_m > f_m + marker = "✓ ADDITIVE" if better_than_both else "✗ not additive" + print(f" {combo}: {c_m:.1f} | GRAFT={gr_m:.1f} | {fwd}={f_m:.1f} → {marker}") + + # Save + with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved to {OUT_DIR}/results.json") + + +if __name__ == '__main__': + main() -- cgit v1.2.3