summaryrefslogtreecommitdiff
path: root/experiments/run_ablation_20seeds.py
blob: e90917ddd1b2b9efc7b1f472cf16677477aa1320 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#!/usr/bin/env python3
"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → KAFT."""

import torch
import numpy as np
import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset
from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, KAFTTrainer

# Use the first GPU when one is visible; otherwise fall back to CPU so the
# script can still be smoke-tested on a machine without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
# Seeds 0..19; every (dataset, method) pair is trained once per seed.
SEEDS = list(range(20))
# Number of train_step() calls per run.
EPOCHS = 200
# Directory for the per-seed cache and the final summary JSON.
OUT_DIR = 'results/ablation_20seeds'

# Methods in ablation order. main() runs paired t-tests between each pair of
# ADJACENT entries, so the insertion order of this dict is significant.
# Each value is (trainer class, method-specific constructor kwargs); the kwargs
# are merged with the per-dataset common kwargs when the trainer is built.
METHODS = {
    'BP': (BPTrainer, {}),
    'DFA': (DFATrainer, {}),
    'DFA-GNN': (DFAGNNTrainer, {'topo_mode': 'fixed_A'}),
    'KAFT': (KAFTTrainer, {
        'diffusion_alpha': 0.5, 'diffusion_iters': 10,
        'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A'
    }),
}


def train_one(cls, common, extra, seed, epochs=None, eval_every=5):
    """Train one model from scratch and return its validation-selected test accuracy.

    torch (CPU and CUDA) and NumPy RNGs are seeded with ``seed`` before the
    trainer is constructed, so repeated calls with the same arguments are
    reproducible to the extent the trainer itself is deterministic.

    Args:
        cls: Trainer class, instantiated as ``cls(**common, **extra)``. It must
            provide ``train_step()`` and ``evaluate(mask_name)``; ``evaluate``
            is assumed to return a scalar accuracy in [0, 1] (main() reports
            it multiplied by 100).
        common: Keyword arguments shared by every method on a dataset.
        extra: Method-specific keyword arguments.
        seed: Integer seed for torch and NumPy.
        epochs: Number of ``train_step()`` calls; ``None`` means ``EPOCHS``.
        eval_every: Evaluate on epochs divisible by this value. The final
            epoch is always evaluated as well, so the trailing training steps
            are never wasted.

    Returns:
        Test accuracy (as a float) at the evaluation with the highest
        validation accuracy; the earliest such evaluation wins ties. Returns
        0.0 when no evaluation was performed (``epochs <= 0``).

    Raises:
        ValueError: If ``eval_every`` is smaller than 1.
    """
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")
    n_epochs = EPOCHS if epochs is None else epochs

    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    trainer = cls(**common, **extra)
    try:
        # Trainers that expose an alignment mode are forced to 'chain_norm'.
        if hasattr(trainer, 'align_mode'):
            trainer.align_mode = 'chain_norm'

        # -inf guarantees the first evaluation is always recorded, so the
        # result is a measured test accuracy rather than a placeholder.
        best_val, best_test = float('-inf'), 0.0
        last_epoch = n_epochs - 1
        for ep in range(n_epochs):
            trainer.train_step()
            if ep % eval_every == 0 or ep == last_epoch:
                # float() keeps the value JSON-serializable even if evaluate
                # returns a numpy scalar or a 0-d tensor.
                val_acc = float(trainer.evaluate('val_mask'))
                test_acc = float(trainer.evaluate('test_mask'))
                if val_acc > best_val:
                    best_val, best_test = val_acc, test_acc
    finally:
        # Drop the model/optimizer before the next run, even on failure.
        del trainer
        torch.cuda.empty_cache()
    return best_test


def _dump_json_atomic(obj, path):
    """Write ``obj`` as indented JSON to ``path`` without leaving a partial file.

    The data is written to ``path + '.tmp'`` first and then moved into place
    with ``os.replace`` (an atomic rename on POSIX filesystems). An interrupted
    write therefore leaves the previous file intact instead of a truncated one
    that would make the next ``json.load`` fail.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def _significance_stars(p_val):
    """Map a p-value to star notation: '***' (<.001), '**' (<.01), '*' (<.05) or 'ns'."""
    if p_val < 0.001:
        return '***'
    if p_val < 0.01:
        return '**'
    if p_val < 0.05:
        return '*'
    # Also reached for NaN p-values, since every comparison with NaN is False.
    return 'ns'


def main():
    """Run (or resume) the ablation over all datasets, methods and seeds.

    For every dataset x method x seed, trains one model via ``train_one`` and
    caches its test accuracy in ``OUT_DIR/per_seed_data.json`` (structure:
    ``{"<dataset>_<method>": {"<seed>": accuracy}}``). The cache is saved after
    every seed, so an interrupted run resumes where it stopped.

    NOTE(review): cache entries are keyed only by dataset, method and seed. If
    EPOCHS, hyperparameters or the evaluation protocol change, delete the cache,
    or stale accuracies will be mixed with new ones.

    Afterwards it prints mean/std accuracy (in %) per configuration and a paired
    t-test for each pair of adjacent methods in ``METHODS`` order, then writes
    everything to ``OUT_DIR/results.json``.
    """
    datasets = ['Cora', 'CiteSeer', 'PubMed']

    os.makedirs(OUT_DIR, exist_ok=True)
    per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
    if os.path.exists(per_seed_file):
        with open(per_seed_file) as f:
            per_seed_data = json.load(f)
    else:
        per_seed_data = {}

    results = {}

    for ds_name in datasets:
        data = load_dataset(ds_name, device=device)
        common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                      num_layers=6, residual_alpha=0.0, backbone='gcn')

        for mname, (cls, extra) in METHODS.items():
            key = f"{ds_name}_{mname}"
            print(f"\n=== {key} ({len(SEEDS)} seeds) ===", flush=True)
            seed_accs = per_seed_data.setdefault(key, {})

            for seed in SEEDS:
                sk = str(seed)  # JSON object keys are strings
                if sk in seed_accs:
                    print(f"  seed {seed}: cached", flush=True)
                    continue
                acc = train_one(cls, common, extra, seed)
                seed_accs[sk] = acc
                print(f"  seed {seed}: {acc*100:.1f}%", flush=True)
                # Persist after every seed so a crash loses at most one run.
                _dump_json_atomic(per_seed_data, per_seed_file)

            accs = np.array([seed_accs[str(s)] for s in SEEDS]) * 100
            # NOTE: np.std defaults to the population std (ddof=0).
            results[key] = {
                'mean': float(accs.mean()), 'std': float(accs.std()),
                'accs': accs.tolist(),
            }
            print(f"  {mname}: {accs.mean():.1f} ± {accs.std():.1f}%")

        # `common` also references the dataset, so both names must be dropped
        # before empty_cache() can release it. Otherwise the next
        # load_dataset() call would keep two datasets alive at once.
        del data, common
        torch.cuda.empty_cache()

    # Paired t-tests between adjacent methods: runs share seeds, so samples pair up.
    print("\n=== Paired t-tests (adjacent methods) ===")
    method_names = list(METHODS)
    for ds in datasets:
        print(f"\n{ds}:")
        for m1, m2 in zip(method_names, method_names[1:]):
            a1 = np.array(results[f"{ds}_{m1}"]['accs'])
            a2 = np.array(results[f"{ds}_{m2}"]['accs'])
            # Argument order (a2, a1): a positive t-statistic means m2 beats m1.
            t_stat, p_val = scipy_stats.ttest_rel(a2, a1)
            sig = _significance_stars(p_val)
            delta = a2.mean() - a1.mean()
            results[f"{ds}_{m1}_vs_{m2}"] = {
                'delta': float(delta), 't_stat': float(t_stat), 'p_value': float(p_val)
            }
            print(f"  {m1} → {m2}: Δ{delta:+.1f}% p={p_val:.4f} {sig}")

    _dump_json_atomic(results, os.path.join(OUT_DIR, 'results.json'))
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()