path: root/experiments/run_dblp_depth.py
author    YurenHao0426 <blackhao0426@gmail.com>    2026-05-04 23:05:16 -0500
committer YurenHao0426 <blackhao0426@gmail.com>    2026-05-04 23:05:16 -0500
commit    bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree      7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /experiments/run_dblp_depth.py
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.

Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'experiments/run_dblp_depth.py')
-rw-r--r--    experiments/run_dblp_depth.py    162
1 file changed, 162 insertions, 0 deletions
diff --git a/experiments/run_dblp_depth.py b/experiments/run_dblp_depth.py
new file mode 100644
index 0000000..d63b94a
--- /dev/null
+++ b/experiments/run_dblp_depth.py
@@ -0,0 +1,162 @@
+#!/usr/bin/env python3
+"""
+CitationFull-DBLP experiment, plus gap-fill runs for the depth sweep (L=2-6).
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import json
+import os
+import time
+from torch_geometric.datasets import CitationFull
+from src.data import build_normalized_adj, build_row_normalized_adj, spmm, precompute_traces
+from src.trainers import BPTrainer, DFATrainer, GraphGrAPETrainer
+from benchmark_efficient import GraphGrAPEEfficient
+
+device = 'cuda:0'
+SEEDS = [0, 1, 2, 3, 4]
+EPOCHS = 200
+OUT_DIR = 'results/dblp_depth'
+
+
+def load_dblp():
+ ds = CitationFull(root='./data', name='DBLP')
+ data = ds[0]
+ N, C = data.num_nodes, ds.num_classes
+    # Stratified random split: per class, 5% train / 10% val / remaining 85% test
+ rng = torch.Generator().manual_seed(42)
+ train_mask = torch.zeros(N, dtype=torch.bool)
+ val_mask = torch.zeros(N, dtype=torch.bool)
+ test_mask = torch.zeros(N, dtype=torch.bool)
+ for c in range(C):
+ idx = (data.y == c).nonzero(as_tuple=True)[0]
+ perm = torch.randperm(len(idx), generator=rng)
+ n_tr = max(1, int(0.05 * len(idx))) # 5% train (like Planetoid)
+ n_va = max(1, int(0.1 * len(idx)))
+ train_mask[idx[perm[:n_tr]]] = True
+ val_mask[idx[perm[n_tr:n_tr + n_va]]] = True
+ test_mask[idx[perm[n_tr + n_va:]]] = True
+
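+    # A_hat: symmetrically normalized adjacency; A_row / A_row_T: row-normalized
+    # adjacency and its transpose.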
+ A_hat = build_normalized_adj(data.edge_index, N)
+ A_row, A_row_T = build_row_normalized_adj(data.edge_index, N)
+ traces = {k: torch.tensor(0.0) for k in range(5)} # skip expensive trace computation
+
+ return {
+ 'X': data.x.to(device), 'y': data.y.to(device),
+ 'A_hat': A_hat.to(device), 'A_row': A_row.to(device), 'A_row_T': A_row_T.to(device),
+ 'train_mask': train_mask.to(device), 'val_mask': val_mask.to(device),
+ 'test_mask': test_mask.to(device),
+ 'num_nodes': N, 'num_features': data.x.shape[1], 'num_classes': C,
+ 'traces': {k: v.to(device) for k, v in traces.items()},
+ }
+
+
+def train_one(cls, common, extra, seed):
+ torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
+ t = cls(**common, **extra)
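+    # Trainers that expose align_mode use the multi-probe chain-normalized
+    # alignment estimator; the others skip this branch.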
+ if hasattr(t, 'align_mode'):
+ t.align_mode = 'chain_norm'
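+    # bv = best validation accuracy so far, bt = test accuracy at that checkpoint
+    # (val-based model selection, evaluated every 5 epochs)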
+ bv, bt = 0, 0
+ for ep in range(EPOCHS):
+ t.train_step()
+ if ep % 5 == 0:
+ v, te = t.evaluate('val_mask'), t.evaluate('test_mask')
+ if v > bv: bv, bt = v, te
+ del t; torch.cuda.empty_cache()
+ return bt
+
+
+def time_method(cls, common, extra, n_warmup=10, n_steps=200):
+ torch.manual_seed(0)
+ t = cls(**common, **extra)
+ if hasattr(t, 'align_mode'):
+ t.align_mode = 'chain_norm'
+ for _ in range(n_warmup):
+ t.train_step()
+ torch.cuda.synchronize()
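+    # Synchronize around every timed step so perf_counter measures completed GPU
+    # work; the median step time is reported in milliseconds.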
+ times = []
+ for _ in range(n_steps):
+ torch.cuda.synchronize(); t0 = time.perf_counter()
+ t.train_step()
+ torch.cuda.synchronize(); times.append(time.perf_counter() - t0)
+ del t; torch.cuda.empty_cache()
+ return float(np.median(times) * 1000)
+
+
+def main():
+ os.makedirs(OUT_DIR, exist_ok=True)
+ results = {}
+
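+    # GrAPE/KAFT hyperparameters shared by every sweep below.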
+ grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
+ lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
+
+ # ======== Part 1: DBLP full sweep ========
+ print("=" * 60)
+ print("Part 1: CitationFull-DBLP")
+ print("=" * 60)
+ dblp = load_dblp()
+ print(f"DBLP: N={dblp['num_nodes']}, F={dblp['num_features']}, C={dblp['num_classes']}, "
+ f"train={dblp['train_mask'].sum().item()}", flush=True)
+
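+    # Grid: 4 backbones × depths {5, 6} × 3 learning rates; each cell averages 5 seeds.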
+ for bb in ['gcn', 'sage', 'gin', 'appnp']:
+ for L in [5, 6]:
+ for lr in [0.001, 0.005, 0.01]:
+ common = dict(data=dblp, hidden_dim=64, lr=lr, weight_decay=5e-4,
+ num_layers=L, residual_alpha=0.0, backbone=bb)
+ key = f"DBLP|{bb}|L={L}|lr={lr}"
+ row = {}
+ for mname, cls, extra in [('BP', BPTrainer, {}),
+ ('DFA', DFATrainer, dict(diffusion_alpha=0.5, diffusion_iters=10)),
+ ('GrAPE', GraphGrAPETrainer, grape_extra)]:
+ accs = [train_one(cls, common, extra, s) for s in SEEDS]
+ row[mname] = {'mean': float(np.mean(accs)), 'std': float(np.std(accs))}
+ results[key] = row
+ bp, dfa, gr = row['BP']['mean']*100, row['DFA']['mean']*100, row['GrAPE']['mean']*100
+ print(f" {bb:>6} L={L} lr={lr:.3f} | BP {bp:.1f} DFA {dfa:.1f} GrAPE {gr:.1f} | "
+ f"Δ(BP) {gr-bp:+.1f} Δ(DFA) {gr-dfa:+.1f}", flush=True)
+
+ # DBLP efficiency
+ print("\nDBLP Efficiency:")
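+    # Per-step wall-clock (median ms) for BP vs. the efficient GrAPE trainer;
+    # speedup = BP_ms / GrAPE_Eff_ms.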
+ for bb in ['gcn', 'sage', 'gin', 'appnp']:
+ for L in [5, 6]:
+ common = dict(data=dblp, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+ num_layers=L, residual_alpha=0.0, backbone=bb)
+ bp_ms = time_method(BPTrainer, common, {})
+ eff_ms = time_method(GraphGrAPEEfficient, common,
+ dict(lr_feedback=0.5, num_probes=64, max_topo_power=3,
+ diff_alpha=0.5, align_every=10))
+ key = f"DBLP_eff|{bb}|L={L}"
+ results[key] = {'BP_ms': bp_ms, 'GrAPE_Eff_ms': eff_ms, 'speedup': bp_ms / eff_ms}
+ print(f" {bb:>6} L={L} | BP {bp_ms:.2f}ms GrAPE-Eff {eff_ms:.2f}ms | "
+ f"speedup {bp_ms/eff_ms:.2f}x", flush=True)
+
+    # ======== Part 2: Depth sweep L=2-4 gap-fill runs (Planetoid × GCN/SAGE/APPNP) ========
+    print("\n" + "=" * 60)
+    print("Part 2: Depth sweep L=2,3,4 gap-fill runs")
+ print("=" * 60)
+
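+    # Shallow-depth runs on the Planetoid citation graphs to complete the depth-sweep table.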
+ from src.data import load_dataset
+ for ds_name in ['Cora', 'CiteSeer', 'PubMed']:
+ data = load_dataset(ds_name, device=device)
+ for bb in ['gcn', 'sage', 'appnp']:
+ for L in [2, 3, 4]:
+ common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+ num_layers=L, residual_alpha=0.0, backbone=bb)
+ key = f"{ds_name}|{bb}|L={L}|lr=0.01"
+ row = {}
+ for mname, cls, extra in [('BP', BPTrainer, {}),
+ ('GrAPE', GraphGrAPETrainer, grape_extra)]:
+ accs = [train_one(cls, common, extra, s) for s in SEEDS]
+ row[mname] = {'mean': float(np.mean(accs)), 'std': float(np.std(accs))}
+ results[key] = row
+ bp, gr = row['BP']['mean']*100, row['GrAPE']['mean']*100
+ print(f" {ds_name:>10} {bb:>6} L={L} | BP {bp:.1f} GrAPE {gr:.1f} | Δ {gr-bp:+.1f}", flush=True)
+
+ with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved to {OUT_DIR}/results.json")
+
+
+if __name__ == '__main__':
+ main()