summaryrefslogtreecommitdiff
path: root/experiments/run_cora_perturb.py
blob: 1dabc95aa47342d3f8bbebef176a00791d5cde72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
#!/usr/bin/env python3
"""
Cora perturbation experiment: directly test candidate causal factors by
comparing BP and GraphGrAPE test accuracy as each perturbation grows stronger.
Three perturbation types:
1. Edge rewiring: destroy community structure
2. Label shuffling: reduce homophily
3. Feature masking: reduce feature quality
"""

import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from src.data import load_dataset, build_normalized_adj, build_row_normalized_adj
from src.trainers import BPTrainer, DFATrainer, GraphGrAPETrainer

# Device passed to load_dataset(); hard-coded to the first CUDA GPU.
device = 'cuda:0'
# Seeds used both for model initialization (train_one) and for each
# perturbation draw in main().
SEEDS = [0, 1, 2, 3, 4]
# Number of train_step() calls per run in train_one().
EPOCHS = 200
# Directory where main() writes results.json.
OUT_DIR = 'results/cora_perturbation'


def perturb_edges(data, rewire_frac, seed=0):
    """Randomly rewire a fraction of adjacency entries (destroys community structure).

    A random subset of the stored (row, col) entries of ``data['A_hat']`` gets
    its column index replaced by a uniformly random node; the stored values
    are carried over unchanged. The input dict is not modified: every tensor
    value is cloned into a new dict, which is returned.

    Args:
        data: Dataset dict. Reads ``'A_hat'`` (a 2-D sparse COO tensor) and
            ``'num_nodes'``. In the returned copy, ``'A_hat'``, ``'A_row'``
            and ``'A_row_T'`` are replaced when at least one entry is rewired.
        rewire_frac: Fraction of stored entries to rewire, in ``[0, 1]``.
        seed: Seed for the CPU generator that picks entries and new targets.

    Returns:
        A new dict holding the perturbed adjacency.

    Raises:
        ValueError: If ``rewire_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= rewire_frac <= 1.0:
        raise ValueError(f"rewire_frac must be in [0, 1], got {rewire_frac}")
    d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
    rng = torch.Generator().manual_seed(seed)
    # indices()/values() require a coalesced sparse tensor.
    A = d['A_hat'].coalesce()
    idx = A.indices()
    vals = A.values()
    N = d['num_nodes']

    n_rewire = int(rewire_frac * idx.shape[1])
    if n_rewire > 0:
        # Draw on the CPU generator for reproducibility, then move to A's device.
        perm = torch.randperm(idx.shape[1], generator=rng)[:n_rewire].to(idx.device)
        new_targets = torch.randint(0, N, (n_rewire,), generator=rng).to(idx.device)
        idx_new = idx.clone()
        idx_new[1, perm] = new_targets

        # NOTE: values are not renormalized for the new degrees, and coalesce()
        # sums the values of entries that collide after rewiring.
        A_new = torch.sparse_coo_tensor(idx_new, vals, (N, N)).coalesce()
        d['A_hat'] = A_new
        # Simplification: the rewired A_hat stands in for the row-normalized
        # operator instead of rebuilding it.
        d['A_row'] = A_new
        # Only column indices move, so A_new is generally no longer symmetric;
        # the transposed operator must be the actual transpose.
        d['A_row_T'] = A_new.t().coalesce()
    return d


def perturb_labels(data, shuffle_frac, seed=0):
    """Shuffle the labels of a random subset of nodes (reduces homophily).

    ``int(shuffle_frac * N)`` nodes are chosen at random and their labels are
    permuted among themselves, so the overall label counts are unchanged. Some
    chosen nodes may keep their own label (fixed points of the permutation).
    The input dict is not modified: tensor values are cloned into a new dict.

    Args:
        data: Dataset dict; reads and replaces ``'y'`` (a 1-D label tensor).
        shuffle_frac: Fraction of nodes whose labels are shuffled, in ``[0, 1]``.
        seed: Seed for the CPU generator that selects and permutes nodes.

    Returns:
        A new dict with the shuffled labels under ``'y'``.

    Raises:
        ValueError: If ``shuffle_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= shuffle_frac <= 1.0:
        raise ValueError(f"shuffle_frac must be in [0, 1], got {shuffle_frac}")
    d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
    rng = torch.Generator().manual_seed(seed)
    y = d['y']  # already a fresh clone from the dict comprehension above
    N = len(y)
    n_shuffle = int(shuffle_frac * N)
    # Same draw order as before: node subset first, then their permutation.
    perm = torch.randperm(N, generator=rng)[:n_shuffle].to(y.device)
    order = torch.randperm(n_shuffle, generator=rng).to(y.device)
    # The right-hand side is a copy (advanced indexing), so in-place assignment is safe.
    y[perm] = y[perm][order]
    d['y'] = y
    return d


def perturb_features(data, mask_frac, seed=0):
    """Zero out a random fraction of feature columns (reduces feature quality).

    ``int(mask_frac * F)`` of the F feature dimensions of ``data['X']`` are
    chosen at random and set to 0 for every node. The input dict is not
    modified: tensor values are cloned into a new dict.

    Args:
        data: Dataset dict; reads and replaces ``'X'`` (indexed as
            ``X[:, dims]``, so at least 2-D).
        mask_frac: Fraction of feature dimensions to zero, in ``[0, 1]``.
        seed: Seed for the CPU generator that picks the masked dimensions.

    Returns:
        A new dict with the masked features under ``'X'``.

    Raises:
        ValueError: If ``mask_frac`` is outside ``[0, 1]`` (a negative value
            would otherwise silently mask almost every dimension via ``[:-k]``).
    """
    if not 0.0 <= mask_frac <= 1.0:
        raise ValueError(f"mask_frac must be in [0, 1], got {mask_frac}")
    d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
    rng = torch.Generator().manual_seed(seed)
    X = d['X']  # already a fresh clone from the dict comprehension above
    F_dim = X.shape[1]
    n_mask = int(mask_frac * F_dim)
    mask_dims = torch.randperm(F_dim, generator=rng)[:n_mask].to(X.device)
    X[:, mask_dims] = 0
    d['X'] = X
    return d


def train_one(cls, common, extra, seed, *, epochs=None, eval_every=5):
    """Train one model and return its test score at the best validation check.

    Args:
        cls: Trainer class, instantiated as ``cls(**common, **extra)`` and
            driven through ``train_step()`` and ``evaluate(mask_name)``.
        common: Keyword arguments shared by all trainers.
        extra: Trainer-specific keyword arguments.
        seed: Seed applied to torch (CPU and all CUDA devices) and NumPy
            before the trainer is constructed.
        epochs: Number of ``train_step()`` calls; defaults to module ``EPOCHS``.
        eval_every: Evaluate every this many epochs. The final epoch is always
            evaluated too, so late improvements are not ignored.

    Returns:
        The ``'test_mask'`` score recorded at the evaluation with the highest
        ``'val_mask'`` score (0 if no validation score ever exceeded 0).

    Raises:
        ValueError: If ``eval_every`` is less than 1.
    """
    if epochs is None:
        epochs = EPOCHS
    if eval_every < 1:
        raise ValueError(f"eval_every must be >= 1, got {eval_every}")
    torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
    t = cls(**common, **extra)
    try:
        if hasattr(t, 'align_mode'):
            t.align_mode = 'chain_norm'
        bv, bt = 0, 0
        for ep in range(epochs):
            t.train_step()
            if ep % eval_every == 0 or ep == epochs - 1:
                v, te = t.evaluate('val_mask'), t.evaluate('test_mask')
                if v > bv:
                    bv, bt = v, te
    finally:
        # Release the trainer (and cached GPU memory) even if training fails.
        del t
        torch.cuda.empty_cache()
    return bt


def main():
    """Run every perturbation sweep on Cora and save aggregated accuracies to JSON.

    For each perturbation type and strength ``frac``, a BP model and a
    GraphGrAPE model are trained once per seed in ``SEEDS``. The same seed
    drives the perturbation draw, and the perturbed data is shared by both
    trainers for that seed. ``frac == 0`` reuses the unperturbed dataset.

    Each entry of ``OUT_DIR/results.json`` is keyed ``"<ptype>|frac=<frac>"``.
    It holds the mean accuracies 'bp' and 'grape' (as returned by
    ``train_one``), 'delta' (GrAPE minus BP, in percentage points), 'frac',
    'ptype', and the per-seed accuracies with their standard deviations. The
    file is rewritten after every completed fraction, so a crash mid-sweep
    keeps the finished configurations.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = os.path.join(OUT_DIR, 'results.json')
    data_orig = load_dataset('Cora', device=device)
    grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
                       lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
    results = {}
    L = 6

    perturbations = [
        ('edge_rewire', [0, 0.1, 0.2, 0.3, 0.5], perturb_edges),
        ('label_shuffle', [0, 0.1, 0.2, 0.3, 0.5], perturb_labels),
        ('feature_mask', [0, 0.2, 0.4, 0.6, 0.8], perturb_features),
    ]

    def _save():
        # Rewrite the whole (small) file so it is always complete, valid JSON.
        with open(out_path, 'w') as f:
            json.dump(results, f, indent=2)

    for ptype, fracs, pfunc in perturbations:
        print(f"\n=== {ptype} (Cora, GCN, L={L}) ===", flush=True)
        # The DFA column is a placeholder: DFA is not trained in this sweep.
        print(f"{'frac':>6} | {'BP':>8} {'DFA':>8} {'GrAPE':>8} | {'Δ(BP)':>7}", flush=True)

        for frac in fracs:
            bp_accs, gr_accs = [], []
            for seed in SEEDS:
                data = pfunc(data_orig, frac, seed=seed) if frac > 0 else data_orig
                common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                              num_layers=L, residual_alpha=0.0, backbone='gcn')
                bp_accs.append(train_one(BPTrainer, common, {}, seed))
                gr_accs.append(train_one(GraphGrAPETrainer, common, grape_extra, seed))

            bp, gr = np.mean(bp_accs) * 100, np.mean(gr_accs) * 100
            delta = gr - bp
            results[f"{ptype}|frac={frac}"] = {
                'bp': float(np.mean(bp_accs)),
                'grape': float(np.mean(gr_accs)),
                # Percentage points (unlike 'bp'/'grape'); this unit is kept for
                # compatibility with previously written results.
                'delta': float(delta),
                'frac': frac,
                'ptype': ptype,
                'bp_std': float(np.std(bp_accs)),
                'grape_std': float(np.std(gr_accs)),
                'bp_per_seed': [float(a) for a in bp_accs],
                'grape_per_seed': [float(a) for a in gr_accs],
            }
            print(f"{frac:>6.1f} | {bp:>7.1f} {'—':>8} {gr:>7.1f} | {delta:>+6.1f}", flush=True)
            _save()

    _save()
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    # Script entry point: run all perturbation sweeps and write results.json.
    main()