summary refs log tree commit diff
path: root/experiments/run_depth_extras.py
blob: 66a7d457a6075c7be3d33765f6ef6eb3e655960c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
#!/usr/bin/env python3
"""H11: Fill depth sweep at L=14 and L=18 to densify Fig 4(a).
3 methods (BP / DFA-GNN / GRAFT) × 4 datasets × 2 depths × 20 seeds = 480 runs.
"""

import torch
import numpy as np
import json
import os
from src.data import load_dataset
from src.trainers import BPTrainer, DFAGNNTrainer, GraphGrAPETrainer
from run_dblp_depth import load_dblp

# ---------------------------------------------------------------------------
# Sweep configuration
# ---------------------------------------------------------------------------
device = 'cuda:0'                          # passed to load_dataset() in main()
SEEDS = [*range(20)]                       # one run per seed, 0..19
EPOCHS = 200                               # train_step() calls per run
DEPTHS = [14, 18]                          # values used for num_layers
OUT_DIR = 'results/depth_extras_20seeds'   # where per_seed_data.json is written

# Extra constructor keyword arguments, merged with the shared ones per method.
grape_extra = {
    'diffusion_alpha': 0.5,
    'diffusion_iters': 10,
    'lr_feedback': 0.5,
    'num_probes': 64,
    'topo_mode': 'fixed_A',
}
dfagnn_extra = {
    'diffusion_alpha': 0.5,
    'diffusion_iters': 10,
    'max_topo_power': 3,
}

# Display name -> (trainer class, method-specific constructor kwargs).
METHODS = {
    'BP': (BPTrainer, {}),
    'DFA-GNN': (DFAGNNTrainer, dfagnn_extra),
    'GRAFT': (GraphGrAPETrainer, grape_extra),
}


def train_one(cls, common, extra, seed):
    """Train a single trainer instance and report its model-selected test score.

    Args:
        cls: Trainer class; instantiated as ``cls(**common, **extra)``.
        common: Keyword arguments shared by every method.
        extra: Method-specific keyword arguments.
        seed: Integer used to seed torch (CPU and all CUDA devices) and numpy.

    Returns:
        The ``evaluate('test_mask')`` value taken at the evaluation step whose
        ``evaluate('val_mask')`` value was the highest seen so far (strictly
        greater than all earlier ones). Evaluation happens on every 5th epoch,
        starting with epoch 0. If no validation value ever exceeds 0, the
        result is 0.
    """
    # Seed every RNG source, in the same order for every run.
    for seed_fn in (torch.manual_seed, np.random.seed, torch.cuda.manual_seed_all):
        seed_fn(seed)

    trainer = cls(**common, **extra)
    # Only trainers that expose an ``align_mode`` attribute get it overridden.
    if hasattr(trainer, 'align_mode'):
        trainer.align_mode = 'chain_norm'

    best_val = best_test = 0
    for epoch in range(EPOCHS):
        trainer.train_step()
        if epoch % 5:
            continue  # evaluate only on epochs 0, 5, 10, ...
        val_score = trainer.evaluate('val_mask')
        test_score = trainer.evaluate('test_mask')
        if val_score > best_val:
            best_val, best_test = val_score, test_score

    # Drop the trainer and release cached GPU memory before the next run.
    del trainer
    torch.cuda.empty_cache()
    return best_test


def _write_json_atomic(path, payload):
    """Write *payload* as JSON to *path* without ever leaving a partial file there."""
    tmp_path = path + '.tmp'
    try:
        with open(tmp_path, 'w') as f:
            json.dump(payload, f, indent=2)
        # os.replace swaps the finished file in; the old checkpoint stays
        # intact until the new one is complete.
        os.replace(tmp_path, path)
    except BaseException:
        # Covers serialization errors and Ctrl-C alike: discard the partial
        # temp file and leave the previous checkpoint untouched.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


def main():
    """Run the depth sweep and checkpoint per-seed results after every run.

    For each dataset, each depth in ``DEPTHS`` and each entry in ``METHODS``,
    every seed in ``SEEDS`` is trained via ``train_one``. Results are stored in
    ``<OUT_DIR>/per_seed_data.json`` as ``{"<dataset>_L<depth>_<method>":
    {"<seed>": score}}``. If that file already exists it is loaded first and
    seeds already present are skipped, so an interrupted sweep can be resumed
    by re-running the script.

    The checkpoint is written atomically (temp file + ``os.replace``), so a
    crash or interrupt during a save cannot corrupt previously stored results.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    # Loaders are wrapped in lambdas so each dataset is only loaded when its
    # turn comes, and can be freed before the next one.
    datasets_cfg = {
        'Cora': lambda: load_dataset('Cora', device=device),
        'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
        'PubMed': lambda: load_dataset('PubMed', device=device),
        'DBLP': lambda: load_dblp(),
    }

    for ds_name, loader in datasets_cfg.items():
        data = loader()
        for L in DEPTHS:
            common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                          num_layers=L, residual_alpha=0.0, backbone='gcn')
            for mname, (cls, extra) in METHODS.items():
                key = f"{ds_name}_L{L}_{mname}"
                seed_results = per_seed_data.setdefault(key, {})
                print(f"\n=== {key} ({len(SEEDS)} seeds) ===", flush=True)
                for seed in SEEDS:
                    sk = str(seed)
                    if sk in seed_results:
                        continue  # already done in a previous (resumed) run
                    try:
                        acc = train_one(cls, common, extra, seed)
                        seed_results[sk] = acc
                        print(f"  seed {seed}: {acc*100:.1f}%", flush=True)
                    except Exception as e:
                        # Best-effort: one failing run must not abort the sweep.
                        # NOTE(review): the failure is stored as 0.0, which is
                        # indistinguishable from a real score of 0 and is not
                        # retried on resume; delete the entry from the JSON
                        # file to force a re-run.
                        print(f"  seed {seed}: FAILED - {e}", flush=True)
                        seed_results[sk] = 0.0
                    _write_json_atomic(per_seed_file, per_seed_data)
        del data; torch.cuda.empty_cache()

    print(f"\nDone. Saved to {per_seed_file}")


# Run the sweep only when executed as a script, not when imported.
if __name__ == '__main__':
    main()