summaryrefslogtreecommitdiff
path: root/experiments/run_bp_graft_depth.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:10:10 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:10:10 -0500
commitba6ead6d7a41b7ed78bb228181b7262d0c75d2eb (patch)
tree726171fb4b0c536d9287a15daf52929ec65fa3d0 /experiments/run_bp_graft_depth.py
parent37ba0f83e3652a215680fd8515af9c14fc02e21c (diff)
Global rename GRAFT → KAFT (incl. internal class + filenames)
- src/trainers.py: GraphGrAPETrainer → KAFTTrainer; module docstring + comments. VanillaGrAPETrainer kept as-is (it is a separate control method, not KAFT). - experiments/: all 19 runners pick up the new class name; result keys ('Cora_GRAFT' etc) become 'Cora_KAFT'; OUT_DIRs renamed (e.g. bp_graft_depth_20seeds → bp_kaft_depth_20seeds). - figures/: data-lookup keys + display labels both 'KAFT'; output filename graft_depth_sweep.{pdf,png} → kaft_depth_sweep.{pdf,png}. - File rename: experiments/run_bp_graft_depth.py → run_bp_kaft_depth.py; figures/graft_depth_sweep.pdf → kaft_depth_sweep.pdf. - README aligned. Imports verified: from src.trainers import KAFTTrainer succeeds.
Diffstat (limited to 'experiments/run_bp_graft_depth.py')
-rw-r--r--experiments/run_bp_graft_depth.py111
1 files changed, 0 insertions, 111 deletions
diff --git a/experiments/run_bp_graft_depth.py b/experiments/run_bp_graft_depth.py
deleted file mode 100644
index 1e8dd76..0000000
--- a/experiments/run_bp_graft_depth.py
+++ /dev/null
@@ -1,111 +0,0 @@
-#!/usr/bin/env python3
-"""H9: BP + GRAFT depth sweep on Cora/CiteSeer/PubMed.
-
-E1 already did DBLP L={8,12,16,20,24,32}. This fills the gap for Cora/CiteSeer/PubMed
-at L={8,10,12,16,20} so we can plot Figure 4(a)-style depth curves on 4 datasets.
-
-BP + GRAFT only (GRAFT+ResGCN not needed for this figure — that's stacking table).
-"""
-
-import torch
-import numpy as np
-import json
-import os
-from src.data import load_dataset
-from src.trainers import BPTrainer, GraphGrAPETrainer
-
-device = 'cuda:0'
-SEEDS = list(range(20))
-EPOCHS = 200
-DEPTHS = [8, 10, 12, 16, 20]
-OUT_DIR = 'results/bp_graft_depth_20seeds'
-
-grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
- lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
-
-METHODS = {
- 'BP': (BPTrainer, {}),
- 'GRAFT': (GraphGrAPETrainer, grape_extra),
-}
-
-
-def train_one(cls, common, extra, seed):
- torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
- t = cls(**common, **extra)
- if hasattr(t, 'align_mode'):
- t.align_mode = 'chain_norm'
- bv, bt = 0, 0
- for ep in range(EPOCHS):
- t.train_step()
- if ep % 5 == 0:
- v = t.evaluate('val_mask')
- te = t.evaluate('test_mask')
- if v > bv: bv, bt = v, te
- del t; torch.cuda.empty_cache()
- return bt
-
-
-def main():
- os.makedirs(OUT_DIR, exist_ok=True)
- per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
- if os.path.exists(per_seed_file):
- with open(per_seed_file) as f:
- per_seed_data = json.load(f)
- else:
- per_seed_data = {}
-
- datasets_cfg = {
- 'Cora': lambda: load_dataset('Cora', device=device),
- 'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
- 'PubMed': lambda: load_dataset('PubMed', device=device),
- }
-
- for ds_name, loader in datasets_cfg.items():
- data = loader()
- for L in DEPTHS:
- common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
- num_layers=L, residual_alpha=0.0, backbone='gcn')
-
- for mname, (cls, extra) in METHODS.items():
- key = f"{ds_name}_L{L}_{mname}"
- if key not in per_seed_data:
- per_seed_data[key] = {}
-
- print(f"\n=== {key} (20 seeds) ===", flush=True)
- for seed in SEEDS:
- sk = str(seed)
- if sk in per_seed_data[key]:
- print(f" seed {seed}: cached ({per_seed_data[key][sk]*100:.1f}%)", flush=True)
- continue
- try:
- acc = train_one(cls, common, extra, seed)
- per_seed_data[key][sk] = acc
- print(f" seed {seed}: {acc*100:.1f}%", flush=True)
- except Exception as e:
- print(f" seed {seed}: FAILED - {e}", flush=True)
- per_seed_data[key][sk] = 0.0
-
- with open(per_seed_file, 'w') as f:
- json.dump(per_seed_data, f, indent=2)
- del data; torch.cuda.empty_cache()
-
- # Summary
- print(f"\n{'=' * 70}\nBP/GRAFT depth sweep summary\n{'=' * 70}")
- results = {}
- for ds in datasets_cfg:
- print(f"\n{ds}:")
- for L in DEPTHS:
- for m in METHODS:
- key = f"{ds}_L{L}_{m}"
- vals = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100
- results[key] = {'mean': float(vals.mean()), 'std': float(vals.std()),
- 'per_seed': vals.tolist()}
- print(f" L={L:2d} {m:<6} {vals.mean():5.1f} ± {vals.std():4.1f}")
-
- with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
- json.dump(results, f, indent=2)
- print(f"\nSaved to {OUT_DIR}/results.json")
-
-
-if __name__ == '__main__':
- main()