#!/usr/bin/env python3
"""Task 2ceadaa7: KAFT + Forward Tricks combo experiments (20 seeds).

Combos: KAFT+ResGCN, KAFT+DropEdge, KAFT+PairNorm, KAFT+JKNet
Each compared to: BP, forward_trick_only, KAFT_only, combo
"""
import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from scipy import stats as scipy_stats

from src.data import load_dataset, spmm, build_normalized_adj
from src.trainers import BPTrainer, KAFTTrainer, _FeedbackTrainerBase
from run_deep_baselines import ResGCNTrainer, JKNetTrainer
from run_dropedge import BPDropEdgeTrainer
from run_pairnorm_baseline import BPPairNormTrainer, pairnorm
from run_dblp_depth import load_dblp

device = 'cuda:0'
SEEDS = list(range(20))
EPOCHS = 200
OUT_DIR = 'results/combo_20seeds'

# Extra constructor kwargs passed to every KAFT-based trainer in this script.
grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10, lr_feedback=0.5,
                   num_probes=64, topo_mode='fixed_A')


# ═══════════════════════════════════════════════════════════════════════════
# KAFT + ResGCN combo (fixed version)
# ═══════════════════════════════════════════════════════════════════════════

class GRAFTResGCN(KAFTTrainer):
    """KAFT backward + ResGCN forward (skip connections).

    Each hidden layer's ReLU output is added to the incoming representation
    when their widths match (identity skip); otherwise it replaces it.
    The weight-update rule is inherited unchanged from ``KAFTTrainer``.
    """

    def forward(self):
        """Run the residual forward pass.

        Returns:
            ``(Z, inter)`` where ``Z`` is the last layer's output and ``inter``
            holds ``'Hs'`` (post-skip hidden states), ``'Zs'`` (pre-activation
            outputs of every layer) and ``'H0'`` (first hidden state).
        """
        X = self.data['X']
        H = X
        H0 = None
        Hs, Zs = [], []
        for l in range(self.num_layers):
            Z = self._graph_conv(H, self.weights[l], l)
            Zs.append(Z)
            if l < self.num_layers - 1:
                H_new = F.relu(Z)
                # The skip is only possible when the widths agree (e.g. it is
                # skipped when the input feature width differs from hidden_dim).
                if H_new.size(1) == H.size(1):
                    H = H + H_new
                else:
                    H = H_new
                Hs.append(H)
                if l == 0:
                    H0 = H
            else:
                # Last layer: no activation, no skip.
                return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0}
        return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0}


# ═══════════════════════════════════════════════════════════════════════════
# KAFT + DropEdge combo
# ═══════════════════════════════════════════════════════════════════════════

class GRAFTDropEdge(KAFTTrainer):
    """KAFT backward + DropEdge forward (random edge dropping).

    During training-mode forward passes, a random ``drop_rate`` fraction of the
    stored non-zero entries of ``A_hat`` is zeroed and the survivors are scaled
    by ``1 / (1 - drop_rate)``. The original ``A_hat`` is always restored after
    the forward pass, so the KAFT update and evaluation use the full graph.

    Args:
        drop_rate: Fraction of edges to drop; values ``<= 0`` disable dropping.
            Must be ``< 1`` (``1`` would divide by zero and yield NaN weights).

    Raises:
        ValueError: If ``drop_rate >= 1``.
    """

    def __init__(self, *args, drop_rate=0.5, **kwargs):
        if drop_rate >= 1:
            raise ValueError(f"drop_rate must be < 1, got {drop_rate}")
        super().__init__(*args, **kwargs)
        self.drop_rate = drop_rate
        self._A_hat_orig = self.data['A_hat']
        # .indices()/.values() require A_hat to be a coalesced sparse COO
        # tensor — assumed to be how the data loader builds it.
        self._edge_index_orig = self._A_hat_orig.indices()
        self._edge_values_orig = self._A_hat_orig.values()
        self._N = self._A_hat_orig.size(0)

    def _drop_edges(self):
        """Return a randomly thinned copy of ``A_hat``, or the original.

        The original is returned when not training (``self._training`` is
        presumably maintained by the base trainer — not visible here) or when
        dropping is disabled.
        """
        if not self._training or self.drop_rate <= 0:
            return self._A_hat_orig
        mask = torch.rand(self._edge_values_orig.size(0),
                          device=self._edge_values_orig.device) > self.drop_rate
        # Inverted-dropout rescaling keeps the expected edge weight unchanged.
        new_vals = self._edge_values_orig * mask.float() / (1 - self.drop_rate)
        return torch.sparse_coo_tensor(
            self._edge_index_orig, new_vals, (self._N, self._N)
        ).coalesce()

    def forward(self):
        """Forward pass on a DropEdge-thinned graph (KAFTTrainer.forward)."""
        # DropEdge only in the forward pass; the KAFT backward uses the
        # original A_hat. try/finally guarantees restoration even if the
        # forward raises, so a thinned graph never leaks into later steps.
        self.data['A_hat'] = self._drop_edges()
        try:
            return super().forward()
        finally:
            self.data['A_hat'] = self._A_hat_orig

    def evaluate(self, mask_name='test_mask'):
        """Evaluate on the full (un-dropped) graph."""
        self.data['A_hat'] = self._A_hat_orig
        return super().evaluate(mask_name)


# ═══════════════════════════════════════════════════════════════════════════
# KAFT + PairNorm combo
# ═══════════════════════════════════════════════════════════════════════════

class GRAFTPairNorm(KAFTTrainer):
    """KAFT backward + PairNorm forward (center & scale normalization).

    Every hidden activation is passed through ``pairnorm(H, pn_scale)``.
    Supports the ``'appnp'`` backbone (dense per-layer transform followed by
    one APPNP propagation of the final output) and graph-conv backbones.

    Args:
        pn_scale: Scale argument forwarded to ``pairnorm``.
    """

    def __init__(self, *args, pn_scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.pn_scale = pn_scale

    def forward(self):
        """Run the PairNorm forward pass.

        Returns:
            ``(Z, inter)`` with ``inter`` containing ``'Hs'``, ``'Zs'`` and
            ``'H0'``. For APPNP, ``Zs[-1]`` is the propagated output.
        """
        use_appnp = self.backbone == 'appnp'
        X = self.data['X']
        H = X
        H0 = None
        Hs, Zs = [], []
        for l in range(self.num_layers):
            if use_appnp:
                # APPNP: layers are plain dense transforms; propagation happens
                # once on the final output below.
                Z = H @ self.weights[l]
            else:
                Z = self._graph_conv(H, self.weights[l], l)
            Zs.append(Z)
            if l < self.num_layers - 1:
                H = pairnorm(F.relu(Z), self.pn_scale)
                Hs.append(H)
                if l == 0:
                    H0 = H
            elif use_appnp:
                Z = self._appnp_propagate(Z)
                Zs[-1] = Z
        return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0}


# ═══════════════════════════════════════════════════════════════════════════
# KAFT + JKNet combo
# ═══════════════════════════════════════════════════════════════════════════

class GRAFTJKNet(KAFTTrainer):
    """KAFT backward + JKNet forward (jumping knowledge max-pool).

    Note: JKNet changes the output architecture. We max-pool hidden layers
    and project to num_classes. KAFT backward operates on hidden layers as
    usual; the JK projection is treated as the output layer. With a single
    layer (no hidden states) the plain last layer is used and ``jk_proj`` is
    left untouched.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # JK projection: hidden_dim -> num_classes
        self.jk_proj = torch.randn(self.hidden_dim, self.d_out, device=self.device) * 0.01
        # Record the projection's Adam slot explicitly rather than assuming it
        # is always the last entry of self._adam.
        self._jk_adam_idx = len(self._adam)
        self._adam.append({'m': torch.zeros_like(self.jk_proj),
                           'v': torch.zeros_like(self.jk_proj)})

    def forward(self):
        """Forward pass with JK max-pooling over the hidden representations.

        Returns:
            ``(Z_out, inter)``. When hidden layers exist, ``Z_out`` is
            ``max_l(H_l) @ jk_proj`` and ``inter['Hs'][-1]`` is replaced by the
            pooled representation (the input of the JK projection). With no
            hidden layers, the last layer's output is returned unchanged.
        """
        X = self.data['X']
        H = X
        H0 = None
        Hs, Zs = [], []
        for l in range(self.num_layers):
            # The last layer's Z is computed here but not used when JK is on;
            # it is kept so Zs stays aligned with self.weights.
            Z = self._graph_conv(H, self.weights[l], l)
            Zs.append(Z)
            if l < self.num_layers - 1:
                H = F.relu(Z)
                Hs.append(H)
                if l == 0:
                    H0 = H
        if not Hs:
            return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0}
        # JK max-pool over hidden layers
        jk_repr = torch.stack(Hs, dim=0).max(dim=0)[0]  # (N, d)
        Z_out = jk_repr @ self.jk_proj
        # Hs[-1] carries the JK representation so _update_weights can use it
        # as the input to the output (projection) layer.
        Hs_for_backward = list(Hs)
        Hs_for_backward[-1] = jk_repr
        return Z_out, {'Hs': Hs_for_backward, 'Zs': Zs, 'H0': H0}

    def _update_weights(self, inter, E0, deltas):
        """Apply the KAFT update to all layers plus the JK projection.

        Args:
            inter: Intermediates returned by :meth:`forward`.
            E0: Output error signal (same shape as the model output).
            deltas: Per-hidden-layer KAFT error signals, indexed by layer.
        """
        X = self.data['X']
        Hs = inter['Hs']
        H0 = inter['H0']
        grads = []
        for l in range(self.num_layers):
            if l == self.num_layers - 1:
                # The original output layer is replaced by the JK projection,
                # but its gradient is still computed to keep indexing aligned.
                H_prev = Hs[-1] if Hs else X
                g = H_prev.t() @ self._graph_conv_T(E0, l)
            else:
                if l == 0:
                    H_in = X
                else:
                    H_prev = Hs[l - 1]
                    if self.residual_alpha > 0 and H0 is not None:
                        H_in = ((1 - self.residual_alpha) * H_prev
                                + self.residual_alpha * H0)
                    else:
                        H_in = H_prev
                g = H_in.t() @ self._graph_conv_T(deltas[l], l)
            grads.append(g)

        # jk_proj is only part of the model when there are hidden layers; with
        # num_layers == 1, X.T @ E0 would not even match jk_proj's shape.
        jk_grad = Hs[-1].t() @ E0 if Hs else None  # Hs[-1] is jk_repr here

        if self._use_adam:
            self._adam_t += 1
            for i in range(self.num_layers):
                self.weights[i] = self.weights[i] - self._adam_step(i, grads[i])
            if jk_grad is not None:
                s = self._adam[self._jk_adam_idx]
                b1, b2, eps = self._adam_beta1, self._adam_beta2, self._adam_eps
                t = self._adam_t
                s['m'] = b1 * s['m'] + (1 - b1) * jk_grad
                s['v'] = b2 * s['v'] + (1 - b2) * jk_grad ** 2
                m_hat = s['m'] / (1 - b1 ** t)
                v_hat = s['v'] / (1 - b2 ** t)
                self.jk_proj = self.jk_proj - self.lr * (
                    m_hat / (v_hat.sqrt() + eps) + self.wd * self.jk_proj)
        else:
            for i in range(self.num_layers):
                self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i])
            if jk_grad is not None:
                self.jk_proj = self.jk_proj - self.lr * (jk_grad + self.wd * self.jk_proj)


# ═══════════════════════════════════════════════════════════════════════════
# Training
# ═══════════════════════════════════════════════════════════════════════════

def train_one(cls, common, extra, seed):
    """Train one model and return its test accuracy at the best validation score.

    Validation and test accuracy are evaluated every 5 epochs; the returned test
    accuracy is the one observed when validation accuracy last strictly improved.

    Args:
        cls: Trainer class to instantiate.
        common: Keyword args shared by all methods (data, sizes, lr, ...).
        extra: Method-specific keyword args.
        seed: RNG seed for torch / numpy / CUDA.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    t = cls(**common, **extra)
    try:
        if hasattr(t, 'align_mode'):
            t.align_mode = 'chain_norm'
        bv, bt = 0, 0
        for ep in range(EPOCHS):
            t.train_step()
            if ep % 5 == 0:
                v = t.evaluate('val_mask')
                te = t.evaluate('test_mask')
                if v > bv:
                    bv, bt = v, te
        return bt
    finally:
        # Release GPU memory even when training raises.
        del t
        torch.cuda.empty_cache()


def _load_json_or_empty(path):
    """Return the JSON object stored at *path*, or ``{}`` if missing/unreadable."""
    try:
        with open(path) as f:
            return json.load(f)
    except (OSError, ValueError):  # json.JSONDecodeError is a ValueError
        return {}


def _save_json(path, obj):
    """Write *obj* to *path* atomically so an interrupt cannot truncate it."""
    tmp_path = path + '.tmp'
    with open(tmp_path, 'w') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def _lookup_cached(mname, ds_name, sk, resgcn_cache, de_cache, pn_cache):
    """Return an accuracy for (method, dataset, seed) from older runs, or None."""
    if mname == 'BP':
        return resgcn_cache.get(f"{ds_name}_BP", {}).get(sk)
    if mname == 'ResGCN':
        return resgcn_cache.get(f"{ds_name}_ResGCN", {}).get(sk)
    if mname == 'KAFT':
        # Older experiment files used the legacy name "GRAFT" for KAFT.
        return resgcn_cache.get(f"{ds_name}_GRAFT", {}).get(sk)
    if mname == 'DropEdge':
        de_key = f"{ds_name}_gcn_L6"
        if de_key in de_cache and sk in de_cache[de_key]:
            return de_cache[de_key][sk].get('de05')
    elif mname == 'PairNorm':
        pn_key = f"{ds_name}_gcn_L6_PN"
        if pn_key in pn_cache and sk in pn_cache[pn_key]:
            return pn_cache[pn_key][sk]
    return None


def _seed_accs(per_seed_data, key):
    """Return accuracies (in %) for every seed in SEEDS, or None if any is missing."""
    runs = per_seed_data.get(key, {})
    if not all(str(seed) in runs for seed in SEEDS):
        return None
    return np.array([runs[str(seed)] for seed in SEEDS]) * 100


def _sig_stars(p_val):
    """Map a p-value to the usual significance marker."""
    if p_val < 0.001:
        return '***'
    if p_val < 0.01:
        return '**'
    if p_val < 0.05:
        return '*'
    return 'ns'


def _paired_comparison(accs, ref_accs):
    """Paired t-test of *accs* vs *ref_accs*; return (table text, result dict)."""
    _, p_val = scipy_stats.ttest_rel(accs, ref_accs)
    delta = accs.mean() - ref_accs.mean()
    text = f"Δ{delta:+.1f} p={p_val:.4f}{_sig_stars(p_val)}"
    return text, {
        'delta': float(delta), 'p_value': float(p_val), 'significant': bool(p_val < 0.05)
    }


def main():
    """Run (or reuse) all method/dataset/seed combinations, then report statistics."""
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    # Reuse existing per-seed data from other experiments:
    # BP, ResGCN, KAFT from resgcn_20seeds; DropEdge from dropedge_20seeds;
    # PairNorm from pairnorm_extended.
    resgcn_cache = _load_json_or_empty('results/resgcn_20seeds/per_seed_data.json')
    de_cache = _load_json_or_empty('results/dropedge_20seeds/per_seed_data.json')
    pn_cache = _load_json_or_empty('results/pairnorm_extended/per_seed_data.json')

    METHODS = {
        'BP': (BPTrainer, {}),
        'ResGCN': (ResGCNTrainer, {}),
        'DropEdge': (BPDropEdgeTrainer, {'drop_rate': 0.5}),
        'PairNorm': (BPPairNormTrainer, {'pn_scale': 1.0}),
        'JKNet': (JKNetTrainer, {}),
        'KAFT': (KAFTTrainer, grape_extra),
        'KAFT+ResGCN': (GRAFTResGCN, grape_extra),
        'KAFT+DropEdge': (GRAFTDropEdge, {**grape_extra, 'drop_rate': 0.5}),
        'KAFT+PairNorm': (GRAFTPairNorm, {**grape_extra, 'pn_scale': 1.0}),
        'KAFT+JKNet': (GRAFTJKNet, grape_extra),
    }
    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'DBLP': lambda: load_dblp(),
    }

    for ds_name, loader in datasets_cfg.items():
        data = loader()
        common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                      num_layers=6, residual_alpha=0.0, backbone='gcn')
        for mname, (cls, extra) in METHODS.items():
            key = f"{ds_name}_{mname}"
            seed_results = per_seed_data.setdefault(key, {})
            print(f"\n=== {key} (20 seeds) ===", flush=True)
            for seed in SEEDS:
                sk = str(seed)
                if sk in seed_results:
                    print(f" seed {seed}: cached", flush=True)
                    continue
                cached = _lookup_cached(mname, ds_name, sk, resgcn_cache, de_cache, pn_cache)
                if cached is not None:
                    seed_results[sk] = cached
                    print(f" seed {seed}: from cache ({cached*100:.1f}%)", flush=True)
                else:
                    try:
                        acc = train_one(cls, common, extra, seed)
                    except Exception as e:
                        # Do NOT record a fake 0.0: it would corrupt the means
                        # and t-tests and block a retry on the next run.
                        print(f" seed {seed}: FAILED - {e}", flush=True)
                        continue
                    seed_results[sk] = acc
                    print(f" seed {seed}: {acc*100:.1f}%", flush=True)
                # Save after each seed
                _save_json(per_seed_file, per_seed_data)
        del data
        torch.cuda.empty_cache()

    # ═══════════════════════════════════════════════════════════════════════
    # Analysis
    # ═══════════════════════════════════════════════════════════════════════
    results = {}
    kaft_means = {}
    fwd_map = {
        'KAFT+ResGCN': 'ResGCN', 'KAFT+DropEdge': 'DropEdge',
        'KAFT+PairNorm': 'PairNorm', 'KAFT+JKNet': 'JKNet',
    }
    method_order = ['BP', 'ResGCN', 'DropEdge', 'PairNorm', 'JKNet', 'KAFT',
                    'KAFT+ResGCN', 'KAFT+DropEdge', 'KAFT+PairNorm', 'KAFT+JKNet']
    print("\n" + "=" * 120)
    print("FULL RESULTS TABLE")
    print("=" * 120)
    for ds in ['Cora', 'CiteSeer', 'DBLP']:
        print(f"\n--- {ds} GCN L=6 lr=0.01 ---")
        print(f"{'Method':<18} {'Mean±Std':>12} {'vs KAFT':>18} {'vs FwdTrick':>18}")
        print("-" * 70)
        # KAFT results are stored under "{ds}_KAFT"; "{ds}_GRAFT" is the legacy name.
        kaft_accs = _seed_accs(per_seed_data, f"{ds}_KAFT")
        if kaft_accs is None:
            kaft_accs = _seed_accs(per_seed_data, f"{ds}_GRAFT")
        if kaft_accs is not None:
            kaft_means[ds] = float(kaft_accs.mean())
        for mname in method_order:
            key = f"{ds}_{mname}"
            accs = _seed_accs(per_seed_data, key)
            if accs is None:
                print(f" {mname:<16} MISSING ({len(per_seed_data.get(key, {}))} seeds)")
                continue
            mean, std = accs.mean(), accs.std()
            results[key] = {'mean': float(mean), 'std': float(std), 'accs': accs.tolist()}

            # Paired t-test vs KAFT ('vs_GRAFT' key kept for output compatibility)
            if mname == 'KAFT':
                vs_graft = "—"
            elif kaft_accs is None:
                vs_graft = "N/A"
            else:
                vs_graft, results[key]['vs_GRAFT'] = _paired_comparison(accs, kaft_accs)

            # Paired t-test vs forward trick only
            fwd = fwd_map.get(mname)
            if fwd is None:
                vs_fwd = ""
            else:
                fwd_accs = _seed_accs(per_seed_data, f"{ds}_{fwd}")
                if fwd_accs is None:
                    vs_fwd = "N/A"
                else:
                    vs_fwd, results[key][f'vs_{fwd}'] = _paired_comparison(accs, fwd_accs)
            print(f" {mname:<16} {mean:>5.1f}±{std:<5.1f} {vs_graft:>18} {vs_fwd:>18}")

    # Summary: which combos are additive?
    print("\n" + "=" * 80)
    print("COMBO ADDITIVITY SUMMARY")
    print("=" * 80)
    for ds in ['Cora', 'CiteSeer', 'DBLP']:
        print(f"\n{ds}:")
        gr_m = kaft_means.get(ds)
        if gr_m is None:
            # Without the KAFT reference, "additive" would be meaningless.
            print(" KAFT results missing — additivity not assessed")
            continue
        for combo, fwd in fwd_map.items():
            ck = f"{ds}_{combo}"
            fk = f"{ds}_{fwd}"
            if ck in results and fk in results:
                c_m = results[ck]['mean']
                f_m = results[fk]['mean']
                better_than_both = c_m > gr_m and c_m > f_m
                marker = "✓ ADDITIVE" if better_than_both else "✗ not additive"
                print(f" {combo}: {c_m:.1f} | KAFT={gr_m:.1f} | {fwd}={f_m:.1f} → {marker}")

    # Save
    _save_json(os.path.join(OUT_DIR, 'results.json'), results)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()