From bd9333eda60a9029a198acaeacb1eca4312bd1e8 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 4 May 2026 23:05:16 -0500 Subject: =?UTF-8?q?Initial=20release:=20GRAFT=20(KAFT)=20=E2=80=94=20NeurI?= =?UTF-8?q?PS=202026=20submission=20code?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%. --- experiments/run_pepita_baseline.py | 188 +++++++++++++++++++++++++++++++++++++ 1 file changed, 188 insertions(+) create mode 100644 experiments/run_pepita_baseline.py (limited to 'experiments/run_pepita_baseline.py') diff --git a/experiments/run_pepita_baseline.py b/experiments/run_pepita_baseline.py new file mode 100644 index 0000000..2327217 --- /dev/null +++ b/experiments/run_pepita_baseline.py @@ -0,0 +1,188 @@ +#!/usr/bin/env python3 +"""H2: PEPITA (Dellaferrera & Kreiman 2022) adapted to GCN L=6. + +Algorithm (per batch / full graph): + 1. Forward pass 1 (clean): X -> H_0, ..., H_{L-2}, Z_out + 2. Compute E0 = softmax(Z_out) - y_onehot (masked to train nodes, unscaled) + 3. Project error to input: X_mod = X - E0 @ F (F is fixed random: C × d_in) + 4. Forward pass 2 (modulated): X_mod -> H_0^m, ..., H_{L-2}^m + 5. Weight updates: + Output layer W_{L-1}: standard gradient via E0 (only place BP-like) + Hidden layer W_l (l < L-1): gradient ~ (agg_input_l)^T @ relu_gate * (H_l^clean - H_l^mod) + +Runs on 4 datasets × 20 seeds at GCN L=6. +""" + +import torch +import torch.nn.functional as F +import numpy as np +import json +import os +from src.data import load_dataset, spmm +from src.trainers import _FeedbackTrainerBase, label_spreading +from run_dblp_depth import load_dblp + +device = 'cuda:0' +SEEDS = list(range(20)) +EPOCHS = 200 +OUT_DIR = 'results/pepita_baseline_20seeds' + + +class PEPITATrainer(_FeedbackTrainerBase): + """PEPITA backward rule for GCN.""" + + def __init__(self, data, hidden_dim, lr, weight_decay, + diffusion_alpha=0.5, diffusion_iters=10, + num_layers=2, residual_alpha=0.0, backbone='gcn', + pepita_fb_scale=0.05, **_kw): + super().__init__(data, hidden_dim, lr, weight_decay, + diffusion_alpha, diffusion_iters, + num_layers, residual_alpha, backbone, + _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0)) + # Fixed random feedback: C × d_in, projects output error back to input + self.F_fb = torch.randn(self.d_out, self.d_in, device=self.device) * pepita_fb_scale + + def _pepita_output_error_unscaled(self, Z_out): + """Raw error (not divided by n_labeled) for perturbation purposes.""" + mask = self.data['train_mask'] + y = self.data['y'] + probs = F.softmax(Z_out.detach(), dim=1) + y_oh = F.one_hot(y, self.d_out).float() + E = torch.zeros_like(probs) + E[mask] = probs[mask] - y_oh[mask] + return E + + def train_step(self): + # Pass 1: clean forward + Z_out_clean, inter_clean = self.forward() + # Perturbation error (unscaled) + E_unscaled = self._pepita_output_error_unscaled(Z_out_clean) + # Gradient error (scaled by n_labeled) for output layer + E0_scaled, _ = self._output_error(Z_out_clean) + + # Modulate input + X_orig = self.data['X'] + X_mod = X_orig - E_unscaled @ self.F_fb + + # Pass 2: modulated forward + self.data['X'] = X_mod + try: + Z_out_mod, inter_mod = self.forward() + finally: + self.data['X'] = X_orig + + # Per-layer gradients + grads = [] + for l in range(self.num_layers): + if l == self.num_layers - 1: + # Output layer: standard gradient via scaled E0 + H_prev = inter_clean['Hs'][-1] if inter_clean['Hs'] else X_orig + g = H_prev.t() @ self._graph_conv_T(E0_scaled, l) + else: + # Hidden layer: activity difference, relu-gated + if l == 0: + H_prev = X_orig + else: + H_prev = inter_clean['Hs'][l - 1] + + relu_gate = (inter_clean['Zs'][l].detach() > 0).float() + # activity difference (post-ReLU) + delta_post = inter_clean['Hs'][l] - inter_mod['Hs'][l] + # scale by n_labeled like BP does + n_labeled = self.data['train_mask'].sum().float().clamp(min=1.0) + delta = relu_gate * delta_post / n_labeled + + g = H_prev.t() @ self._graph_conv_T(delta, l) + grads.append(g) + + # Apply Adam + if self._use_adam: + self._adam_t += 1 + for i in range(self.num_layers): + self.weights[i] = self.weights[i] - self._adam_step(i, grads[i]) + else: + for i in range(self.num_layers): + self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i]) + + with torch.no_grad(): + mask = self.data['train_mask'] + loss = F.cross_entropy(Z_out_clean[mask], self.data['y'][mask]).item() + acc = (Z_out_clean[mask].argmax(1) == self.data['y'][mask]).float().mean().item() + return loss, acc, {} + + +def train_one(cls, common, extra, seed): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + t = cls(**common, **extra) + bv, bt = 0, 0 + for ep in range(EPOCHS): + t.train_step() + if ep % 5 == 0: + v = t.evaluate('val_mask') + te = t.evaluate('test_mask') + if v > bv: bv, bt = v, te + del t; torch.cuda.empty_cache() + return bt + + +def main(): + os.makedirs(OUT_DIR, exist_ok=True) + per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json') + if os.path.exists(per_seed_file): + with open(per_seed_file) as f: + per_seed_data = json.load(f) + else: + per_seed_data = {} + + datasets_cfg = { + 'Cora': lambda: load_dataset('Cora', device=device), + 'CiteSeer': lambda: load_dataset('CiteSeer', device=device), + 'PubMed': lambda: load_dataset('PubMed', device=device), + 'DBLP': lambda: load_dblp(), + } + + for ds_name, loader in datasets_cfg.items(): + data = loader() + common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4, + num_layers=6, residual_alpha=0.0, backbone='gcn') + + key = f"{ds_name}_PEPITA" + if key not in per_seed_data: + per_seed_data[key] = {} + + print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True) + for seed in SEEDS: + sk = str(seed) + if sk in per_seed_data[key]: + print(f" seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True) + continue + try: + acc = train_one(PEPITATrainer, common, {}, seed) + per_seed_data[key][sk] = acc + print(f" seed {seed}: {acc*100:.1f}%", flush=True) + except Exception as e: + print(f" seed {seed}: FAILED - {e}", flush=True) + per_seed_data[key][sk] = 0.0 + + with open(per_seed_file, 'w') as f: + json.dump(per_seed_data, f, indent=2) + + del data; torch.cuda.empty_cache() + + # Summary + print(f"\n{'=' * 70}\nPEPITA summary (20 seeds, GCN L=6)\n{'=' * 70}") + results = {} + for ds in datasets_cfg: + key = f"{ds}_PEPITA" + vals = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100 + results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()), + 'per_seed': vals.tolist()} + print(f" {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}") + + with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved to {OUT_DIR}/results.json") + + +if __name__ == '__main__': + main() -- cgit v1.2.3