author    YurenHao0426 <blackhao0426@gmail.com>  2026-05-04 23:05:16 -0500
committer YurenHao0426 <blackhao0426@gmail.com>  2026-05-04 23:05:16 -0500
commit    bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree      7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /experiments/run_cafo_baseline.py
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs.

Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines
  + multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.

Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'experiments/run_cafo_baseline.py')
-rw-r--r--  experiments/run_cafo_baseline.py  198
1 file changed, 198 insertions, 0 deletions
diff --git a/experiments/run_cafo_baseline.py b/experiments/run_cafo_baseline.py
new file mode 100644
index 0000000..3d8c2d7
--- /dev/null
+++ b/experiments/run_cafo_baseline.py
@@ -0,0 +1,198 @@
+#!/usr/bin/env python3
+"""H3: CaFo+CE (Cascaded Forward Learning with Top-Down Feedback, Park et al. 2023).
+
+Greedy layer-wise training for GCN L=6:
+ - Each hidden layer l has an auxiliary classifier W_aux_l: hidden → num_classes
+ - Forward through all layers with .detach() between layers (blocks upstream gradient)
+ - Per-layer CE loss on labeled nodes via auxiliary classifier
+ - Output layer uses standard cross-entropy
+ - No global backprop — each W_l only sees its local loss
+
+Evaluates CaFo+CE on Cora / CiteSeer / PubMed / DBLP across 20 seeds, GCN L=6.
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import json
+import os
+from src.data import load_dataset, spmm
+from run_dblp_depth import load_dblp
+
+device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+SEEDS = list(range(20))
+EPOCHS = 200
+OUT_DIR = 'results/cafo_baseline_20seeds'
+
+
+class CaFoTrainer:
+ """CaFo+CE: greedy layer-wise training with per-layer CE loss."""
+
+ def __init__(self, data, hidden_dim, lr, weight_decay,
+ num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw):
+ dev = data['X'].device
+ self.data = data
+ self.device = dev
+ self.lr = lr
+ self.wd = weight_decay
+ self.num_layers = num_layers
+ self.residual_alpha = residual_alpha
+ self.backbone = backbone
+ self._training = True
+
+ d_in = data['num_features']
+ d_out = data['num_classes']
+ self.d_out = d_out
+
+ dims = [d_in] + [hidden_dim] * (num_layers - 1) + [d_out]
+ # Main layer weights — autograd Parameters
+ self.weights = []
+ for i in range(num_layers):
+ w = torch.empty(dims[i], dims[i + 1], device=dev)
+ torch.nn.init.xavier_uniform_(w)
+ w.requires_grad_(True)
+ self.weights.append(w)
+
+ # Auxiliary classifier per hidden layer: hidden_dim -> d_out
+ self.W_aux = []
+ for i in range(num_layers - 1):
+ w_aux = torch.empty(hidden_dim, d_out, device=dev)
+ torch.nn.init.xavier_uniform_(w_aux)
+ w_aux.requires_grad_(True)
+ self.W_aux.append(w_aux)
+
+ params = self.weights + self.W_aux
+ self.optim = torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
+
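+    # Note: A_hat is assumed to be the renormalized adjacency
+    # D^{-1/2} (A + I) D^{-1/2} prepared by load_dataset, and spmm a
+    # sparse @ dense matmul helper from src.data.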
+ def _gcn_conv(self, H, W):
+ """GCN conv: A_hat @ (H @ W)."""
+ return spmm(self.data['A_hat'], H @ W)
+
+ def train_step(self):
+ X = self.data['X']
+ y = self.data['y']
+ mask = self.data['train_mask']
+
+ self.optim.zero_grad()
+
+ H = X
+ total_loss = 0.0
+ for l in range(self.num_layers):
+ if l > 0:
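+                # Detaching makes H a graph leaf, so the per-layer backward()
+                # calls below populate grads only for weights[l] and W_aux[l].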
+ H = H.detach() # block grad flow upstream
+
+ Z = self._gcn_conv(H, self.weights[l])
+
+ if l < self.num_layers - 1:
+ H_new = F.relu(Z)
+ # Auxiliary classifier (projects hidden to classes)
+ Z_aux = H_new @ self.W_aux[l]
+ loss_l = F.cross_entropy(Z_aux[mask], y[mask])
+ loss_l.backward()
+ total_loss += loss_l.item()
+ H = H_new
+ else:
+ # Output layer: standard CE
+ loss_final = F.cross_entropy(Z[mask], y[mask])
+ loss_final.backward()
+ total_loss += loss_final.item()
+
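+        # One shared Adam step is safe: each parameter's grad was produced
+        # solely by its own local loss above, so the update stays layer-local.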
+ self.optim.step()
+
+ with torch.no_grad():
+ Z_out = self._forward_full_detached()
+ acc = (Z_out[mask].argmax(1) == y[mask]).float().mean().item()
+ return total_loss, acc, {}
+
+    def _forward_full_detached(self):
+        """Evaluation forward pass; weights detached, callers run it under no_grad."""
+ X = self.data['X']
+ H = X
+ for l in range(self.num_layers):
+ Z = self._gcn_conv(H, self.weights[l].detach())
+ if l < self.num_layers - 1:
+ H = F.relu(Z)
+ return Z
+
+ @torch.no_grad()
+ def evaluate(self, mask_name='test_mask'):
+ self._training = False
+ Z = self._forward_full_detached()
+ self._training = True
+ mask = self.data[mask_name]
+ return (Z[mask].argmax(1) == self.data['y'][mask]).float().mean().item()
+
+
+def train_one(seed, data, num_layers=6):
+    torch.manual_seed(seed)
+    np.random.seed(seed)
+    torch.cuda.manual_seed_all(seed)
+ t = CaFoTrainer(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+ num_layers=num_layers, residual_alpha=0.0, backbone='gcn')
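+    # Model selection: probe val/test every 5 epochs and report the test
+    # accuracy recorded at the best validation accuracy (bv, bt below).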
+ bv, bt = 0, 0
+ for ep in range(EPOCHS):
+ t.train_step()
+ if ep % 5 == 0:
+ v = t.evaluate('val_mask')
+ te = t.evaluate('test_mask')
+            if v > bv:
+                bv, bt = v, te
+ del t; torch.cuda.empty_cache()
+ return bt
+
+
+def main():
+ os.makedirs(OUT_DIR, exist_ok=True)
+ per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
+ if os.path.exists(per_seed_file):
+ with open(per_seed_file) as f:
+ per_seed_data = json.load(f)
+ else:
+ per_seed_data = {}
+
+ datasets_cfg = {
+ 'Cora': lambda: load_dataset('Cora', device=device),
+ 'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
+ 'PubMed': lambda: load_dataset('PubMed', device=device),
+ 'DBLP': lambda: load_dblp(),
+ }
+
+ for ds_name, loader in datasets_cfg.items():
+ data = loader()
+ key = f"{ds_name}_CaFo+CE"
+ if key not in per_seed_data:
+ per_seed_data[key] = {}
+
+ print(f"\n=== {key} (20 seeds, GCN L=6) ===", flush=True)
+ for seed in SEEDS:
+ sk = str(seed)
+ if sk in per_seed_data[key]:
+ print(f" seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True)
+ continue
+ try:
+ acc = train_one(seed, data)
+ per_seed_data[key][sk] = acc
+ print(f" seed {seed}: {acc*100:.1f}%", flush=True)
+ except Exception as e:
+ print(f" seed {seed}: FAILED - {e}", flush=True)
+                per_seed_data[key][sk] = 0.0  # failed seeds recorded as 0.0, which drags the summary mean down
+
+ with open(per_seed_file, 'w') as f:
+ json.dump(per_seed_data, f, indent=2)
+
+ del data; torch.cuda.empty_cache()
+
+ # Summary
+ print(f"\n{'=' * 70}\nCaFo+CE summary (20 seeds, GCN L=6)\n{'=' * 70}")
+ results = {}
+ for ds in datasets_cfg:
+ key = f"{ds}_CaFo+CE"
+ vals = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100
+ results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
+ 'per_seed': vals.tolist()}
+ print(f" {ds:<12} {vals.mean():5.1f} ± {vals.std():4.1f}")
+
+ with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved to {OUT_DIR}/results.json")
+
+
+if __name__ == '__main__':
+ main()