diff options
Diffstat (limited to 'experiments/run_shallow_depth.py')
| -rw-r--r-- | experiments/run_shallow_depth.py | 20 |
1 files changed, 10 insertions, 10 deletions
diff --git a/experiments/run_shallow_depth.py b/experiments/run_shallow_depth.py index 68c9138..5f1cc26 100644 --- a/experiments/run_shallow_depth.py +++ b/experiments/run_shallow_depth.py @@ -1,8 +1,8 @@ #!/usr/bin/env python3 """E2: Shallow depth (L=2,3,4) on 4 datasets. Last exploratory avenue after -E1 (deep scaling) and E0-extras (more datasets) both failed to extend GRAFT's -regime. If GRAFT still wins at L=2/3 (standard GNN depth), we can counter -the reviewer attack 'L=5,6 nobody uses'. If GRAFT matches BP only at L=5,6, +E1 (deep scaling) and E0-extras (more datasets) both failed to extend KAFT's +regime. If KAFT still wins at L=2/3 (standard GNN depth), we can counter +the reviewer attack 'L=5,6 nobody uses'. If KAFT matches BP only at L=5,6, paper stays at current scope and we ship.""" import torch @@ -11,7 +11,7 @@ import json import os from scipy import stats as scipy_stats from src.data import load_dataset -from src.trainers import BPTrainer, GraphGrAPETrainer +from src.trainers import BPTrainer, KAFTTrainer from run_deep_baselines import ResGCNTrainer from run_combo_20seeds import GRAFTResGCN from run_dblp_depth import load_dblp @@ -27,8 +27,8 @@ grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10, METHODS = { 'BP': (BPTrainer, {}), - 'GRAFT': (GraphGrAPETrainer, grape_extra), - 'GRAFT+ResGCN': (GRAFTResGCN, grape_extra), + 'KAFT': (KAFTTrainer, grape_extra), + 'KAFT+ResGCN': (GRAFTResGCN, grape_extra), } @@ -108,10 +108,10 @@ def main(): t, p = scipy_stats.ttest_rel(gr_accs, bp_accs) delta = gr_accs.mean() - bp_accs.mean() print(f" {ds} L={L}: BP {bp_accs.mean():5.1f}±{bp_accs.std():4.1f} " - f"GRAFT {gr_accs.mean():5.1f}±{gr_accs.std():4.1f} " - f"GRAFT+ResGCN {stk_accs.mean():5.1f}±{stk_accs.std():4.1f} " - f"Δ(GRAFT-BP)={delta:+.1f}, p={p:.4f}") - for mname, accs in [('BP', bp_accs), ('GRAFT', gr_accs), ('GRAFT+ResGCN', stk_accs)]: + f"KAFT {gr_accs.mean():5.1f}±{gr_accs.std():4.1f} " + f"KAFT+ResGCN {stk_accs.mean():5.1f}±{stk_accs.std():4.1f} " + f"Δ(KAFT-BP)={delta:+.1f}, p={p:.4f}") + for mname, accs in [('BP', bp_accs), ('KAFT', gr_accs), ('KAFT+ResGCN', stk_accs)]: key = f"{ds}_L{L}_{mname}" results[key] = {'mean': float(accs.mean()), 'std': float(accs.std()), 'per_seed': accs.tolist()} |
