1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
|
#!/usr/bin/env python3
"""E2: Shallow depth (L=2,3,4) on 4 datasets. Last exploratory avenue after
E1 (deep scaling) and E0-extras (more datasets) both failed to extend KAFT's
regime. If KAFT still wins at L=2/3 (standard GNN depth), we can counter
the reviewer attack 'L=5,6 nobody uses'. If KAFT matches BP only at L=5,6,
paper stays at current scope and we ship."""
import torch
import numpy as np
import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset
from src.trainers import BPTrainer, KAFTTrainer
from run_deep_baselines import ResGCNTrainer
from run_combo_20seeds import GRAFTResGCN
from run_dblp_depth import load_dblp
# Experiment grid: every (dataset, depth, method) cell is trained once per seed.
device = "cuda:0"
SEEDS = [*range(20)]
EPOCHS = 200
DEPTHS = [2, 3, 4]
OUT_DIR = "results/shallow_depth_20seeds"

# Keyword arguments shared by both KAFT variants (passed as **extra to the
# trainer constructor alongside the common arguments built in main()).
grape_extra = {
    "diffusion_alpha": 0.5,
    "diffusion_iters": 10,
    "lr_feedback": 0.5,
    "num_probes": 64,
    "topo_mode": "fixed_A",
}

# Method name -> (trainer class, method-specific constructor kwargs).
# The names are used verbatim in the cache / results keys.
METHODS = {
    "BP": (BPTrainer, {}),
    "KAFT": (KAFTTrainer, grape_extra),
    "KAFT+ResGCN": (GRAFTResGCN, grape_extra),
}
def train_one(cls, common, extra, seed, epochs=None, eval_every=5):
    """Train one model for a single seed and return its best-validation test score.

    Seeds numpy and torch (CPU and all CUDA devices), builds
    ``cls(**common, **extra)``, calls ``train_step()`` once per epoch and,
    every ``eval_every`` epochs (starting at epoch 0), evaluates on
    ``'val_mask'`` and ``'test_mask'``. The test score recorded at the
    strictly-highest validation score is returned (ties keep the earliest
    checkpoint).

    Args:
        cls: Trainer class; must accept ``**common, **extra`` and expose
            ``train_step()`` and ``evaluate(mask_name)``.
        common: Constructor kwargs shared by all methods.
        extra: Method-specific constructor kwargs.
        seed: RNG seed for numpy and torch.
        epochs: Number of training steps; ``None`` means the module-level
            ``EPOCHS``.
        eval_every: Evaluation period in epochs; must be >= 1.

    Returns:
        Whatever ``evaluate('test_mask')`` returned at the selected
        checkpoint (presumably an accuracy in [0, 1], since callers multiply
        it by 100), or 0 if no evaluation happened (``epochs == 0``).

    Raises:
        ValueError: if ``eval_every < 1``.
        Any exception raised by the trainer propagates to the caller.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")
    if epochs is None:
        epochs = EPOCHS

    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    trainer = cls(**common, **extra)
    try:
        if hasattr(trainer, 'align_mode'):
            trainer.align_mode = 'chain_norm'
        # Start below any possible score so the first evaluation is always
        # selected; starting at 0 would report 0 test accuracy whenever the
        # validation score never rises above 0.
        best_val, best_test = float('-inf'), 0
        for ep in range(epochs):
            trainer.train_step()
            if ep % eval_every == 0:
                val = trainer.evaluate('val_mask')
                test = trainer.evaluate('test_mask')
                if val > best_val:
                    best_val, best_test = val, test
    finally:
        # Drop our reference and release cached GPU memory even if training
        # raised, so a failed seed does not skip the cleanup step.
        del trainer
        torch.cuda.empty_cache()
    return best_test
def _dump_json_atomic(obj, path):
    """Write *obj* as indented JSON to *path* without ever leaving a partial file.

    Data goes to ``<path>.tmp`` first and is then moved over *path* with
    ``os.replace``. A run killed mid-write therefore keeps the previous file
    intact instead of a truncated one that would break the next ``json.load``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def _seed_accs(per_seed_data, key, seeds=None):
    """Return ``{seed: cached_value * 100}`` for the seeds cached under *key*.

    Seeds of *seeds* (default: ``SEEDS``) with no cached entry, e.g. because
    training failed, are omitted rather than filled with a fake value.
    """
    if seeds is None:
        seeds = SEEDS
    runs = per_seed_data.get(key, {})
    return {s: runs[str(s)] * 100 for s in seeds if str(s) in runs}


def _fmt(accs):
    """Format a 1-D array of scores as ``mean±std``, or ``n/a`` when empty."""
    if accs.size == 0:
        return '  n/a'
    return f"{accs.mean():5.1f}±{accs.std():4.1f}"


def _summarize(per_seed_data, dataset_names):
    """Print the per-(dataset, depth) summary and return the aggregated results.

    The KAFT-vs-BP paired t-test uses only seeds present for both methods,
    so a failed seed on one side cannot misalign the pairs.
    """
    print(f"\n{'=' * 70}\nShallow depth summary ({len(SEEDS)} seeds)\n{'=' * 70}")
    results = {}
    for ds in dataset_names:
        for L in DEPTHS:
            by_method = {m: _seed_accs(per_seed_data, f"{ds}_L{L}_{m}")
                         for m in METHODS}
            arrays = {m: np.array(list(v.values()), dtype=float)
                      for m, v in by_method.items()}

            bp, kaft = by_method['BP'], by_method['KAFT']
            paired = [s for s in SEEDS if s in bp and s in kaft]
            if len(paired) >= 2:
                kaft_p = np.array([kaft[s] for s in paired])
                bp_p = np.array([bp[s] for s in paired])
                _, p = scipy_stats.ttest_rel(kaft_p, bp_p)
                delta = kaft_p.mean() - bp_p.mean()
                stat = f"Δ(KAFT-BP)={delta:+.1f}, p={p:.4f}"
            else:
                stat = "Δ(KAFT-BP)=n/a (fewer than 2 paired seeds)"

            print(f" {ds} L={L}: BP {_fmt(arrays['BP'])} "
                  f"KAFT {_fmt(arrays['KAFT'])} "
                  f"KAFT+ResGCN {_fmt(arrays['KAFT+ResGCN'])} "
                  f"{stat}")
            missing = {m: len(SEEDS) - len(v) for m, v in by_method.items()
                       if len(v) < len(SEEDS)}
            if missing:
                print(f"   (missing seeds per method: {missing})")

            for mname, accs in arrays.items():
                if accs.size == 0:
                    continue
                results[f"{ds}_L{L}_{mname}"] = {
                    'mean': float(accs.mean()),
                    'std': float(accs.std()),
                    'per_seed': accs.tolist(),
                    # Seed ids aligned with 'per_seed' (some may be missing).
                    'seeds': list(by_method[mname]),
                }
    return results


def main():
    """Run (or resume) the shallow-depth sweep, then print and save a summary.

    For every dataset, depth in ``DEPTHS``, method in ``METHODS`` and seed
    in ``SEEDS``, trains one model with ``train_one``. Each result is cached
    in ``OUT_DIR/per_seed_data.json`` so an interrupted run resumes where it
    stopped. Failed seeds are reported but not cached, so they are retried
    on the next run and never enter the statistics. Aggregated statistics
    are written to ``OUT_DIR/results.json``.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    # NOTE(review): caches written by earlier versions of this script stored
    # failed seeds as 0.0; those entries look like real results and should be
    # removed by hand before resuming.
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'PubMed': lambda: load_dataset('PubMed', device=device),
        'DBLP': load_dblp,
    }
    for ds_name, loader in datasets_cfg.items():
        data = loader()
        for L in DEPTHS:
            print(f"\n{'=' * 60}\n{ds_name} L={L}\n{'=' * 60}", flush=True)
            common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                          num_layers=L, residual_alpha=0.0, backbone='gcn')
            for mname, (cls, extra) in METHODS.items():
                key = f"{ds_name}_L{L}_{mname}"
                runs = per_seed_data.setdefault(key, {})
                print(f"\n--- {key} ---", flush=True)
                for seed in SEEDS:
                    sk = str(seed)
                    if sk in runs:
                        print(f" seed {seed}: cached ({runs[sk]*100:.1f}%)", flush=True)
                        continue
                    try:
                        acc = train_one(cls, common, extra, seed)
                    except Exception as e:
                        # Deliberately not cached: storing a placeholder such
                        # as 0.0 would be skipped as "cached" on rerun and
                        # silently drag down the mean and the t-test.
                        print(f" seed {seed}: FAILED - {e}", flush=True)
                        continue
                    runs[sk] = acc
                    print(f" seed {seed}: {acc*100:.1f}%", flush=True)
                    _dump_json_atomic(per_seed_data, per_seed_file)
        del data
        torch.cuda.empty_cache()

    results = _summarize(per_seed_data, list(datasets_cfg))
    _dump_json_atomic(results, os.path.join(OUT_DIR, 'results.json'))
    print(f"\nSaved to {OUT_DIR}/results.json")
if __name__ == "__main__":
    # Run the sweep only when executed as a script, not on import.
    main()
|