summaryrefslogtreecommitdiff
path: root/experiments/run_shallow_depth.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/run_shallow_depth.py')
-rw-r--r--experiments/run_shallow_depth.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/experiments/run_shallow_depth.py b/experiments/run_shallow_depth.py
index 68c9138..5f1cc26 100644
--- a/experiments/run_shallow_depth.py
+++ b/experiments/run_shallow_depth.py
@@ -1,8 +1,8 @@
#!/usr/bin/env python3
"""E2: Shallow depth (L=2,3,4) on 4 datasets. Last exploratory avenue after
-E1 (deep scaling) and E0-extras (more datasets) both failed to extend GRAFT's
-regime. If GRAFT still wins at L=2/3 (standard GNN depth), we can counter
-the reviewer attack 'L=5,6 nobody uses'. If GRAFT matches BP only at L=5,6,
+E1 (deep scaling) and E0-extras (more datasets) both failed to extend KAFT's
+regime. If KAFT still wins at L=2/3 (standard GNN depth), we can counter
+the reviewer attack 'L=5,6 nobody uses'. If KAFT matches BP only at L=5,6,
paper stays at current scope and we ship."""
import torch
@@ -11,7 +11,7 @@ import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset
-from src.trainers import BPTrainer, GraphGrAPETrainer
+from src.trainers import BPTrainer, KAFTTrainer
from run_deep_baselines import ResGCNTrainer
from run_combo_20seeds import GRAFTResGCN
from run_dblp_depth import load_dblp
@@ -27,8 +27,8 @@ grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
METHODS = {
'BP': (BPTrainer, {}),
- 'GRAFT': (GraphGrAPETrainer, grape_extra),
- 'GRAFT+ResGCN': (GRAFTResGCN, grape_extra),
+ 'KAFT': (KAFTTrainer, grape_extra),
+ 'KAFT+ResGCN': (GRAFTResGCN, grape_extra),
}
@@ -108,10 +108,10 @@ def main():
t, p = scipy_stats.ttest_rel(gr_accs, bp_accs)
delta = gr_accs.mean() - bp_accs.mean()
print(f" {ds} L={L}: BP {bp_accs.mean():5.1f}±{bp_accs.std():4.1f} "
- f"GRAFT {gr_accs.mean():5.1f}±{gr_accs.std():4.1f} "
- f"GRAFT+ResGCN {stk_accs.mean():5.1f}±{stk_accs.std():4.1f} "
- f"Δ(GRAFT-BP)={delta:+.1f}, p={p:.4f}")
- for mname, accs in [('BP', bp_accs), ('GRAFT', gr_accs), ('GRAFT+ResGCN', stk_accs)]:
+ f"KAFT {gr_accs.mean():5.1f}±{gr_accs.std():4.1f} "
+ f"KAFT+ResGCN {stk_accs.mean():5.1f}±{stk_accs.std():4.1f} "
+ f"Δ(KAFT-BP)={delta:+.1f}, p={p:.4f}")
+ for mname, accs in [('BP', bp_accs), ('KAFT', gr_accs), ('KAFT+ResGCN', stk_accs)]:
key = f"{ds}_L{L}_{mname}"
results[key] = {'mean': float(accs.mean()), 'std': float(accs.std()),
'per_seed': accs.tolist()}