#!/usr/bin/env python3
"""H2: PEPITA (Dellaferrera & Kreiman 2022) adapted to GCN L=6.

Algorithm (per batch / full graph):
  1. Forward pass 1 (clean): X -> H_0, ..., H_{L-2}, Z_out
  2. Compute E0 = softmax(Z_out) - y_onehot (masked to train nodes, unscaled)
  3. Project error to input: X_mod = X - E0 @ F (F is fixed random: C × d_in)
  4. Forward pass 2 (modulated): X_mod -> H_0^m, ..., H_{L-2}^m
  5. Weight updates:
       Output layer W_{L-1}: standard gradient via E0 / n_labeled
           (the only BP-like step)
       Hidden layer W_l (l < L-1):
           gradient ~ (agg_input_l)^T @ (relu_gate * (H_l^clean - H_l^mod)) / n_labeled

Runs on 4 datasets × 20 seeds at GCN L=6. Per-seed test accuracies are cached
in OUT_DIR/per_seed_data.json so an interrupted run can be resumed. Seeds that
raise are NOT cached and are retried on the next run.
"""
import torch
import torch.nn.functional as F
import numpy as np
import json
import os
import traceback

from src.data import load_dataset, spmm
from src.trainers import _FeedbackTrainerBase, label_spreading
from run_dblp_depth import load_dblp

device = 'cuda:0'
SEEDS = list(range(20))
EPOCHS = 200
EVAL_EVERY = 5  # evaluate val/test accuracy every this many epochs
OUT_DIR = 'results/pepita_baseline_20seeds'


class PEPITATrainer(_FeedbackTrainerBase):
    """PEPITA backward rule for GCN.

    Only ``train_step`` is specific to PEPITA. The forward pass, graph
    convolution transpose (``_graph_conv_T``), scaled output error
    (``_output_error``), Adam step and evaluation come from
    ``_FeedbackTrainerBase``.
    """

    def __init__(self, data, hidden_dim, lr, weight_decay, diffusion_alpha=0.5,
                 diffusion_iters=10, num_layers=2, residual_alpha=0.0,
                 backbone='gcn', pepita_fb_scale=0.05, **_kw):
        super().__init__(data, hidden_dim, lr, weight_decay, diffusion_alpha,
                         diffusion_iters, num_layers, residual_alpha, backbone,
                         _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
        # Fixed random feedback: C × d_in, projects output error back to input.
        # Never updated during training.
        self.F_fb = torch.randn(self.d_out, self.d_in, device=self.device) * pepita_fb_scale

    def _pepita_output_error_unscaled(self, Z_out):
        """Raw error (not divided by n_labeled) for perturbation purposes.

        Returns softmax(Z_out) - one_hot(y) on training nodes and zeros on all
        other nodes. Z_out is detached, so no autograd graph is built.
        """
        mask = self.data['train_mask']
        y = self.data['y']
        probs = F.softmax(Z_out.detach(), dim=1)
        y_oh = F.one_hot(y, self.d_out).float()
        E = torch.zeros_like(probs)
        E[mask] = probs[mask] - y_oh[mask]
        return E

    def train_step(self):
        """Run one full-graph PEPITA update.

        Returns:
            (loss, acc, info): cross-entropy and accuracy on the training
            nodes, measured on the clean forward pass *before* this step's
            weight update, plus an empty info dict.
        """
        train_mask = self.data['train_mask']
        y = self.data['y']
        last = self.num_layers - 1

        # Pass 1: clean forward.
        Z_out_clean, inter_clean = self.forward()
        # The unscaled error drives the input perturbation...
        E_unscaled = self._pepita_output_error_unscaled(Z_out_clean)
        # ...while the n_labeled-scaled error drives the output-layer gradient.
        E0_scaled, _ = self._output_error(Z_out_clean)

        # Modulate the input with the projected error.
        X_orig = self.data['X']
        X_mod = X_orig - E_unscaled @ self.F_fb

        # Pass 2: modulated forward. forward() takes no input argument, so the
        # modulated features are swapped into self.data['X'] and the clean
        # features are always restored, even if forward() raises.
        self.data['X'] = X_mod
        try:
            _, inter_mod = self.forward()
        finally:
            self.data['X'] = X_orig

        # Hidden-layer deltas are scaled by n_labeled, like BP's output error.
        n_labeled = train_mask.sum().float().clamp(min=1.0)

        grads = []
        for l in range(self.num_layers):
            if l == last:
                # Output layer: standard gradient via scaled E0.
                H_prev = inter_clean['Hs'][-1] if inter_clean['Hs'] else X_orig
                g = H_prev.t() @ self._graph_conv_T(E0_scaled, l)
            else:
                # Hidden layer: activity difference (post-ReLU), gated by the
                # clean pre-activation's ReLU mask.
                H_prev = X_orig if l == 0 else inter_clean['Hs'][l - 1]
                relu_gate = (inter_clean['Zs'][l].detach() > 0).float()
                delta_post = inter_clean['Hs'][l] - inter_mod['Hs'][l]
                delta = relu_gate * delta_post / n_labeled
                g = H_prev.t() @ self._graph_conv_T(delta, l)
            grads.append(g)

        # Apply the update (Adam, or plain SGD with L2 weight decay).
        if self._use_adam:
            self._adam_t += 1
            for i, g in enumerate(grads):
                self.weights[i] = self.weights[i] - self._adam_step(i, g)
        else:
            for i, g in enumerate(grads):
                self.weights[i] = self.weights[i] - self.lr * (g + self.wd * self.weights[i])

        with torch.no_grad():
            logits = Z_out_clean[train_mask]
            labels = y[train_mask]
            loss = F.cross_entropy(logits, labels).item()
            acc = (logits.argmax(1) == labels).float().mean().item()
        return loss, acc, {}


def train_one(cls, common, extra, seed):
    """Train one model for EPOCHS epochs and return its model-selected test accuracy.

    Validation and test accuracy are evaluated every EVAL_EVERY epochs. The
    returned value is the test accuracy at the first evaluation that reached
    the highest validation accuracy (0 if validation accuracy never exceeds 0).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    t = cls(**common, **extra)
    try:
        best_val, best_test = 0, 0
        for ep in range(EPOCHS):
            t.train_step()
            if ep % EVAL_EVERY == 0:
                v = t.evaluate('val_mask')
                te = t.evaluate('test_mask')
                if v > best_val:
                    best_val, best_test = v, te
        return best_test
    finally:
        # Release the trainer's GPU memory even when training raised
        # (e.g. CUDA OOM), so the next seed starts from a clean allocator.
        del t
        torch.cuda.empty_cache()


def _load_json(path):
    """Return the JSON content of *path*, or an empty dict if it does not exist."""
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        return json.load(f)


def _save_json_atomic(path, obj):
    """Write *obj* as JSON to *path* via a temp file and ``os.replace``.

    An interrupted write therefore never leaves a truncated cache behind.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def main():
    """Run PEPITA on every dataset/seed not yet cached, then write a summary."""
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    per_seed_data = _load_json(per_seed_file)

    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'PubMed': lambda: load_dataset('PubMed', device=device),
        'DBLP': lambda: load_dblp(),
    }

    for ds_name, loader in datasets_cfg.items():
        key = f"{ds_name}_PEPITA"
        seed_results = per_seed_data.setdefault(key, {})
        print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True)

        # Loaded lazily, so a dataset whose seeds are all cached is never loaded.
        data = common = None
        for seed in SEEDS:
            sk = str(seed)
            if sk in seed_results:
                print(f"  seed {seed}: cached ({seed_results[sk]*100:.1f}%)", flush=True)
                continue
            if data is None:
                data = loader()
                common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                              num_layers=6, residual_alpha=0.0, backbone='gcn')
            try:
                acc = train_one(PEPITATrainer, common, {}, seed)
            except Exception as e:
                # Do not record a failure as 0% accuracy: it would be treated as
                # cached forever and silently bias the summary. Leave the seed
                # missing so the next run retries it.
                print(f"  seed {seed}: FAILED - {e}", flush=True)
                traceback.print_exc()
                continue
            seed_results[sk] = acc
            print(f"  seed {seed}: {acc*100:.1f}%", flush=True)
            _save_json_atomic(per_seed_file, per_seed_data)

        # Drop *both* references to the dataset (``common`` holds one too)
        # before the next loader() call, so two datasets are never resident.
        del data, common
        torch.cuda.empty_cache()

    # Summary over the seeds that actually completed.
    print(f"\n{'=' * 70}\nPEPITA summary (20 seeds, GCN L=6)\n{'=' * 70}")
    results = {}
    for ds in datasets_cfg:
        key = f"{ds}_PEPITA"
        seed_results = per_seed_data.get(key, {})
        done = [s for s in SEEDS if str(s) in seed_results]
        if not done:
            print(f"  {ds:<12} no completed seeds")
            continue
        vals = np.array([seed_results[str(s)] for s in done]) * 100
        results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
                        'per_seed': vals.tolist(), 'seeds': done}
        missing = len(SEEDS) - len(done)
        note = f"  ({missing} seed(s) missing)" if missing else ""
        print(f"  {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}{note}")

    _save_json_atomic(os.path.join(OUT_DIR, 'results.json'), results)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()