path: root/experiments/run_pepita_baseline.py
author    YurenHao0426 <blackhao0426@gmail.com>    2026-05-04 23:05:16 -0500
committer YurenHao0426 <blackhao0426@gmail.com>    2026-05-04 23:05:16 -0500
commit    bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree      7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /experiments/run_pepita_baseline.py
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.
Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'experiments/run_pepita_baseline.py')
-rw-r--r--  experiments/run_pepita_baseline.py  188
1 file changed, 188 insertions, 0 deletions
diff --git a/experiments/run_pepita_baseline.py b/experiments/run_pepita_baseline.py
new file mode 100644
index 0000000..2327217
--- /dev/null
+++ b/experiments/run_pepita_baseline.py
@@ -0,0 +1,188 @@
+#!/usr/bin/env python3
+"""H2: PEPITA (Dellaferrera & Kreiman 2022) adapted to GCN L=6.
+
+Algorithm (per batch / full graph):
+ 1. Forward pass 1 (clean): X -> H_0, ..., H_{L-2}, Z_out
+ 2. Compute E0 = softmax(Z_out) - y_onehot (masked to train nodes, unscaled)
+ 3. Project error to input: X_mod = X - E0 @ F (F is fixed random: C × d_in)
+ 4. Forward pass 2 (modulated): X_mod -> H_0^m, ..., H_{L-2}^m
+ 5. Weight updates:
+ Output layer W_{L-1}: standard gradient via E0 (only place BP-like)
+ Hidden layer W_l (l < L-1): gradient ~ (agg_input_l)^T @ relu_gate * (H_l^clean - H_l^mod)
+
+Runs on 4 datasets × 20 seeds at GCN L=6.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import json
+import os
+from src.data import load_dataset
+from src.trainers import _FeedbackTrainerBase
+from run_dblp_depth import load_dblp
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+SEEDS = list(range(20))
+EPOCHS = 200
+OUT_DIR = 'results/pepita_baseline_20seeds'
+
+
+class PEPITATrainer(_FeedbackTrainerBase):
+ """PEPITA backward rule for GCN."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ diffusion_alpha=0.5, diffusion_iters=10,
+ num_layers=2, residual_alpha=0.0, backbone='gcn',
+ pepita_fb_scale=0.05, **_kw):
+ super().__init__(data, hidden_dim, lr, weight_decay,
+ diffusion_alpha, diffusion_iters,
+ num_layers, residual_alpha, backbone,
+ _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
+ # Fixed random feedback: C × d_in, projects output error back to input
+ self.F_fb = torch.randn(self.d_out, self.d_in, device=self.device) * pepita_fb_scale
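+        # F_fb is drawn once at init and never updated; only its scale
+        # (pepita_fb_scale) is a tunable hyperparameter.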
+
+ def _pepita_output_error_unscaled(self, Z_out):
+ """Raw error (not divided by n_labeled) for perturbation purposes."""
+ mask = self.data['train_mask']
+ y = self.data['y']
+ probs = F.softmax(Z_out.detach(), dim=1)
+ y_oh = F.one_hot(y, self.d_out).float()
+ E = torch.zeros_like(probs)
+ E[mask] = probs[mask] - y_oh[mask]
+ return E
+
+ def train_step(self):
+ # Pass 1: clean forward
+ Z_out_clean, inter_clean = self.forward()
+ # Perturbation error (unscaled)
+ E_unscaled = self._pepita_output_error_unscaled(Z_out_clean)
+ # Gradient error (scaled by n_labeled) for output layer
+ E0_scaled, _ = self._output_error(Z_out_clean)
+
+ # Modulate input
+ X_orig = self.data['X']
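+        # (N, C) error @ (C, d_in) feedback -> (N, d_in) per-node perturbation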
+ X_mod = X_orig - E_unscaled @ self.F_fb
+
+ # Pass 2: modulated forward
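+        # data['X'] is swapped in place; try/finally restores the original
+        # features even if the modulated forward pass raises.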
+ self.data['X'] = X_mod
+ try:
+ Z_out_mod, inter_mod = self.forward()
+ finally:
+ self.data['X'] = X_orig
+
+ # Per-layer gradients
+ grads = []
+ for l in range(self.num_layers):
+ if l == self.num_layers - 1:
+ # Output layer: standard gradient via scaled E0
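+                # _graph_conv_T is a base-class helper, assumed here to apply
+                # the transposed graph propagation of layer l so the update
+                # mirrors the forward aggregation.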
+ H_prev = inter_clean['Hs'][-1] if inter_clean['Hs'] else X_orig
+ g = H_prev.t() @ self._graph_conv_T(E0_scaled, l)
+ else:
+ # Hidden layer: activity difference, relu-gated
+ if l == 0:
+ H_prev = X_orig
+ else:
+ H_prev = inter_clean['Hs'][l - 1]
+
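+                # Gate with the clean pass's ReLU pattern: units inactive in
+                # pass 1 receive no update.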
+ relu_gate = (inter_clean['Zs'][l].detach() > 0).float()
+ # activity difference (post-ReLU)
+ delta_post = inter_clean['Hs'][l] - inter_mod['Hs'][l]
+ # scale by n_labeled like BP does
+ n_labeled = self.data['train_mask'].sum().float().clamp(min=1.0)
+ delta = relu_gate * delta_post / n_labeled
+
+ g = H_prev.t() @ self._graph_conv_T(delta, l)
+ grads.append(g)
+
+ # Apply Adam
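+        # NOTE: weight decay appears explicitly only in the plain-SGD branch
+        # below; in the Adam branch it is assumed to be handled by _adam_step.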
+ if self._use_adam:
+ self._adam_t += 1
+ for i in range(self.num_layers):
+ self.weights[i] = self.weights[i] - self._adam_step(i, grads[i])
+ else:
+ for i in range(self.num_layers):
+ self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i])
+
+ with torch.no_grad():
+ mask = self.data['train_mask']
+ loss = F.cross_entropy(Z_out_clean[mask], self.data['y'][mask]).item()
+ acc = (Z_out_clean[mask].argmax(1) == self.data['y'][mask]).float().mean().item()
+ return loss, acc, {}
+
+
+def train_one(cls, common, extra, seed):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ t = cls(**common, **extra)
+ bv, bt = 0, 0
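+    # bv: best val acc so far; bt: test acc at that checkpoint.
+    # Evaluation runs every 5 epochs; the returned test acc is the one
+    # observed at the best-val checkpoint.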
+ for ep in range(EPOCHS):
+ t.train_step()
+ if ep % 5 == 0:
+ v = t.evaluate('val_mask')
+ te = t.evaluate('test_mask')
+ if v > bv: bv, bt = v, te
+ del t; torch.cuda.empty_cache()
+ return bt
+
+
+def main():
+ os.makedirs(OUT_DIR, exist_ok=True)
+ per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
+ if os.path.exists(per_seed_file):
+ with open(per_seed_file) as f:
+ per_seed_data = json.load(f)
+ else:
+ per_seed_data = {}
+
+ datasets_cfg = {
+ 'Cora': lambda: load_dataset('Cora', device=device),
+ 'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
+ 'PubMed': lambda: load_dataset('PubMed', device=device),
+ 'DBLP': lambda: load_dblp(),
+ }
+
+ for ds_name, loader in datasets_cfg.items():
+ data = loader()
+ common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+ num_layers=6, residual_alpha=0.0, backbone='gcn')
+
+ key = f"{ds_name}_PEPITA"
+ if key not in per_seed_data:
+ per_seed_data[key] = {}
+
+ print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True)
+ for seed in SEEDS:
+ sk = str(seed)
+ if sk in per_seed_data[key]:
+ print(f" seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True)
+ continue
+ try:
+ acc = train_one(PEPITATrainer, common, {}, seed)
+ per_seed_data[key][sk] = acc
+ print(f" seed {seed}: {acc*100:.1f}%", flush=True)
+ except Exception as e:
+ print(f" seed {seed}: FAILED - {e}", flush=True)
+ per_seed_data[key][sk] = 0.0
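+                # Failed seeds are recorded as 0.0 and therefore drag down
+                # the summary mean reported below.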
+
+ with open(per_seed_file, 'w') as f:
+ json.dump(per_seed_data, f, indent=2)
+
+ del data; torch.cuda.empty_cache()
+
+ # Summary
+ print(f"\n{'=' * 70}\nPEPITA summary (20 seeds, GCN L=6)\n{'=' * 70}")
+ results = {}
+ for ds in datasets_cfg:
+ key = f"{ds}_PEPITA"
+ vals = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100
+ results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
+ 'per_seed': vals.tolist()}
+ print(f" {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}")
+
+ with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved to {OUT_DIR}/results.json")
+
+
+if __name__ == '__main__':
+ main()