diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:05:16 -0500 |
| commit | bd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch) | |
| tree | 7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /experiments/run_shallow_depth.py | |
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes:
- src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines
+ multi-probe alignment estimator + dataset / sparse-mm utilities.
- experiments/: 19 runners reproducing every figure / table in the paper.
- figures/: 4 generators + the 4 PDFs cited in the report.
- paper/: NeurIPS .tex and consolidated experiments_master notes.
Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'experiments/run_shallow_depth.py')
| -rw-r--r-- | experiments/run_shallow_depth.py | 125 |
1 files changed, 125 insertions, 0 deletions
diff --git a/experiments/run_shallow_depth.py b/experiments/run_shallow_depth.py new file mode 100644 index 0000000..68c9138 --- /dev/null +++ b/experiments/run_shallow_depth.py @@ -0,0 +1,125 @@ +#!/usr/bin/env python3 +"""E2: Shallow depth (L=2,3,4) on 4 datasets. Last exploratory avenue after +E1 (deep scaling) and E0-extras (more datasets) both failed to extend GRAFT's +regime. If GRAFT still wins at L=2/3 (standard GNN depth), we can counter +the reviewer attack 'L=5,6 nobody uses'. If GRAFT matches BP only at L=5,6, +paper stays at current scope and we ship.""" + +import torch +import numpy as np +import json +import os +from scipy import stats as scipy_stats +from src.data import load_dataset +from src.trainers import BPTrainer, GraphGrAPETrainer +from run_deep_baselines import ResGCNTrainer +from run_combo_20seeds import GRAFTResGCN +from run_dblp_depth import load_dblp + +device = 'cuda:0' +SEEDS = list(range(20)) +EPOCHS = 200 +DEPTHS = [2, 3, 4] +OUT_DIR = 'results/shallow_depth_20seeds' + +grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10, + lr_feedback=0.5, num_probes=64, topo_mode='fixed_A') + +METHODS = { + 'BP': (BPTrainer, {}), + 'GRAFT': (GraphGrAPETrainer, grape_extra), + 'GRAFT+ResGCN': (GRAFTResGCN, grape_extra), +} + + +def train_one(cls, common, extra, seed): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + t = cls(**common, **extra) + if hasattr(t, 'align_mode'): + t.align_mode = 'chain_norm' + bv, bt = 0, 0 + for ep in range(EPOCHS): + t.train_step() + if ep % 5 == 0: + v = t.evaluate('val_mask') + te = t.evaluate('test_mask') + if v > bv: bv, bt = v, te + del t; torch.cuda.empty_cache() + return bt + + +def main(): + os.makedirs(OUT_DIR, exist_ok=True) + per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json') + if os.path.exists(per_seed_file): + with open(per_seed_file) as f: + per_seed_data = json.load(f) + else: + per_seed_data = {} + + datasets_cfg = { + 'Cora': lambda: load_dataset('Cora', device=device), + 'CiteSeer': lambda: load_dataset('CiteSeer', device=device), + 'PubMed': lambda: load_dataset('PubMed', device=device), + 'DBLP': lambda: load_dblp(), + } + + for ds_name, loader in datasets_cfg.items(): + data = loader() + for L in DEPTHS: + print(f"\n{'=' * 60}\n{ds_name} L={L}\n{'=' * 60}", flush=True) + common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4, + num_layers=L, residual_alpha=0.0, backbone='gcn') + + for mname, (cls, extra) in METHODS.items(): + key = f"{ds_name}_L{L}_{mname}" + if key not in per_seed_data: + per_seed_data[key] = {} + + print(f"\n--- {key} ---", flush=True) + for seed in SEEDS: + sk = str(seed) + if sk in per_seed_data[key]: + print(f" seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True) + continue + try: + acc = train_one(cls, common, extra, seed) + per_seed_data[key][sk] = acc + print(f" seed {seed}: {acc*100:.1f}%", flush=True) + except Exception as e: + print(f" seed {seed}: FAILED - {e}", flush=True) + per_seed_data[key][sk] = 0.0 + + with open(per_seed_file, 'w') as f: + json.dump(per_seed_data, f, indent=2) + del data; torch.cuda.empty_cache() + + # Summary + print(f"\n{'=' * 70}\nShallow depth summary (20 seeds)\n{'=' * 70}") + results = {} + for ds in datasets_cfg: + for L in DEPTHS: + bp_key = f"{ds}_L{L}_BP" + gr_key = f"{ds}_L{L}_GRAFT" + stk_key = f"{ds}_L{L}_GRAFT+ResGCN" + bp_accs = np.array([per_seed_data[bp_key][str(s)] for s in SEEDS]) * 100 + gr_accs = np.array([per_seed_data[gr_key][str(s)] for s in SEEDS]) * 100 + stk_accs = np.array([per_seed_data[stk_key][str(s)] for s in SEEDS]) * 100 + t, p = scipy_stats.ttest_rel(gr_accs, bp_accs) + delta = gr_accs.mean() - bp_accs.mean() + print(f" {ds} L={L}: BP {bp_accs.mean():5.1f}±{bp_accs.std():4.1f} " + f"GRAFT {gr_accs.mean():5.1f}±{gr_accs.std():4.1f} " + f"GRAFT+ResGCN {stk_accs.mean():5.1f}±{stk_accs.std():4.1f} " + f"Δ(GRAFT-BP)={delta:+.1f}, p={p:.4f}") + for mname, accs in [('BP', bp_accs), ('GRAFT', gr_accs), ('GRAFT+ResGCN', stk_accs)]: + key = f"{ds}_L{L}_{mname}" + results[key] = {'mean': float(accs.mean()), 'std': float(accs.std()), + 'per_seed': accs.tolist()} + + with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved to {OUT_DIR}/results.json") + + +if __name__ == '__main__': + main() |
