1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
|
#!/usr/bin/env python3
"""
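Cora perturbation experiment: test which data properties drive the accuracy
gap between backprop (BP) and GrAPE on a 6-layer GCN.
Three perturbation types, each applied at several strengths:
1. Edge rewiring: redirect a random fraction of edges (destroys community structure)
2. Label shuffling: permute the labels of a random subset of nodes (reduces homophily)
3. Feature masking: zero out a random subset of feature dimensions (reduces feature quality)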
"""
import torch
import torch.nn.functional as F
import numpy as np
import json
import os
from src.data import load_dataset, build_normalized_adj, build_row_normalized_adj
from src.trainers import BPTrainer, DFATrainer, KAFTTrainer
# Use the first GPU when one is available; fall back to CPU so the script and
# the perturbation helpers stay usable on machines without CUDA.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SEEDS = [0, 1, 2, 3, 4]  # one training run per seed; results are averaged over seeds
EPOCHS = 200  # training epochs per run
OUT_DIR = 'results/cora_perturbation'
def perturb_edges(data, rewire_frac, seed=0):
    """Randomly rewire a fraction of edges (destroys community structure).

    A random ``rewire_frac`` fraction of the stored entries of ``data['A_hat']``
    keep their row (source) index but receive a uniformly random column
    (target) index drawn from ``[0, num_nodes)``.

    Args:
        data: Dataset dict; reads ``'A_hat'`` (sparse COO matrix) and
            ``'num_nodes'``. Every tensor value is cloned, so ``data`` itself
            is not modified.
        rewire_frac: Fraction of stored ``A_hat`` entries to rewire.
        seed: Seed for the CPU generator that picks entries and new targets.

    Returns:
        A new dict. When at least one entry is rewired, ``'A_hat'`` is the
        rewired (coalesced) matrix, ``'A_row'`` is a row-normalised matrix over
        the rewired sparsity pattern and ``'A_row_T'`` is its transpose. When
        ``int(rewire_frac * nnz) == 0`` the adjacency matrices are returned
        unchanged.
    """
    d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()}
    N = d['num_nodes']
    # indices()/values() may only be called on a coalesced sparse tensor.
    A = d['A_hat'].coalesce()
    idx = A.indices()
    vals = A.values()
    n_rewire = int(rewire_frac * idx.shape[1])
    if n_rewire <= 0:
        # Nothing was rewired: keep the original A_hat / A_row / A_row_T
        # instead of overwriting the row-normalised matrices with A_hat.
        return d
    rng = torch.Generator().manual_seed(seed)
    # Same draw order as before (entries first, then targets) so a given seed
    # rewires the same entries.
    perm = torch.randperm(idx.shape[1], generator=rng)[:n_rewire].to(idx.device)
    new_targets = torch.randint(0, N, (n_rewire,), generator=rng).to(idx.device)
    idx_new = idx.clone()
    idx_new[1, perm] = new_targets
    # NOTE(review): only the target endpoint moves, so the result is generally
    # not symmetric; A_hat's original normalised weights are reused rather than
    # recomputed, and coalesce() sums the weights of entries that collide on the
    # same (row, col) pair.
    A_new = torch.sparse_coo_tensor(idx_new, vals, (N, N)).coalesce()
    d['A_hat'] = A_new
    # Row-normalise the rewired pattern: each stored entry in row i gets weight
    # 1 / (number of stored entries in row i), so every non-empty row sums to 1.
    # NOTE(review): assumes A_row is meant to be the D^-1-normalised adjacency
    # over A_hat's sparsity pattern -- confirm against build_row_normalized_adj.
    pattern = A_new.indices()
    ones = torch.ones_like(A_new.values())
    row_count = torch.zeros(N, dtype=ones.dtype, device=ones.device).index_add_(0, pattern[0], ones)
    A_row = torch.sparse_coo_tensor(pattern, ones / row_count[pattern[0]], (N, N)).coalesce()
    d['A_row'] = A_row
    # Rewiring breaks symmetry, so the transpose must be taken explicitly.
    d['A_row_T'] = A_row.t().coalesce()
    return d
def perturb_labels(data, shuffle_frac, seed=0):
    """Permute the labels of a random subset of nodes (reduces homophily).

    A subset of ``int(shuffle_frac * len(y))`` nodes is drawn uniformly from
    all nodes. No train/val/test mask is consulted, so evaluation labels are
    permuted as well. The labels inside that subset are then randomly
    permuted among themselves, which preserves the overall label counts.

    Args:
        data: Dataset dict; reads ``'y'``. Every tensor value is cloned, so the
            input dict and its tensors are left untouched.
        shuffle_frac: Fraction of nodes whose labels take part in the shuffle.
        seed: Seed for the CPU generator used for both random draws.

    Returns:
        A new dict identical to ``data`` except for the shuffled ``'y'``.
    """
    out = {}
    for key, value in data.items():
        out[key] = value.clone() if isinstance(value, torch.Tensor) else value
    gen = torch.Generator()
    gen.manual_seed(seed)
    labels = out['y'].clone()
    num_nodes = len(labels)
    k = int(shuffle_frac * num_nodes)
    # Draw order matters for reproducibility: the node subset first, then the
    # permutation applied within it.
    chosen = torch.randperm(num_nodes, generator=gen)[:k]
    order = torch.randperm(k, generator=gen)
    labels[chosen] = labels[chosen][order]
    out['y'] = labels
    return out
def perturb_features(data, mask_frac, seed=0):
    """Zero out a random subset of feature columns (reduces feature quality).

    Args:
        data: Dataset dict; reads ``'X'``, indexed as ``X[:, dims]`` (feature
            dimensions on axis 1). Every tensor value is cloned, so ``data``
            is not modified.
        mask_frac: Fraction of feature dimensions to zero; the count is
            ``int(mask_frac * X.shape[1])``.
        seed: Seed for the CPU generator that picks the dimensions.

    Returns:
        A new dict identical to ``data`` except that the chosen columns of
        ``'X'`` are zero in every row.
    """
    out = {name: (obj.clone() if isinstance(obj, torch.Tensor) else obj)
           for name, obj in data.items()}
    gen = torch.Generator()
    gen.manual_seed(seed)
    feats = out['X'].clone()
    n_dims = feats.shape[1]
    zeroed = torch.randperm(n_dims, generator=gen)[:int(mask_frac * n_dims)]
    feats[:, zeroed] = 0
    out['X'] = feats
    return out
def train_one(cls, common, extra, seed, epochs=None, eval_every=5):
    """Train one model and return its test accuracy at the best validation checkpoint.

    Args:
        cls: Trainer class, constructed as ``cls(**common, **extra)``. It must
            provide ``train_step()`` and ``evaluate(mask_name)``.
        common: Keyword arguments shared by all trainers.
        extra: Trainer-specific keyword arguments.
        seed: Seed applied to the torch, numpy and CUDA RNGs before construction.
        epochs: Number of training steps; ``None`` uses the module-level EPOCHS.
        eval_every: Evaluate every ``eval_every`` epochs. The final epoch is
            always evaluated so the last updates are not ignored.

    Returns:
        The test score recorded at the epoch with the strictly highest
        validation score (ties keep the earliest), or 0 if the validation
        score never exceeds 0.
    """
    if epochs is None:
        epochs = EPOCHS
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    trainer = cls(**common, **extra)
    best_val, best_test = 0, 0
    try:
        if hasattr(trainer, 'align_mode'):
            trainer.align_mode = 'chain_norm'
        for ep in range(epochs):
            trainer.train_step()
            if ep % eval_every == 0 or ep == epochs - 1:
                val_acc = trainer.evaluate('val_mask')
                test_acc = trainer.evaluate('test_mask')
                if val_acc > best_val:
                    best_val, best_test = val_acc, test_acc
    finally:
        # Release the model and cached GPU memory even if training fails, so a
        # long sweep does not accumulate allocations.
        del trainer
        torch.cuda.empty_cache()
    return best_test
def main():
    """Run the Cora perturbation sweep and save the results as JSON.

    For every perturbation type and strength, trains BP (``BPTrainer``) and
    GrAPE (``KAFTTrainer``) on a 6-layer GCN once per seed in ``SEEDS`` and
    records the test accuracy returned by ``train_one``. Strength 0 uses the
    unperturbed dataset. ``OUT_DIR/results.json`` is rewritten after every
    strength, so a crash part-way through the sweep keeps the finished entries.

    Each JSON entry stores:
    - the seed-mean accuracies ``'bp'``/``'grape'`` as fractions;
    - ``'delta'`` = GrAPE - BP in percentage points (same as earlier outputs);
    - the per-seed accuracies and their standard deviations.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    out_path = os.path.join(OUT_DIR, 'results.json')
    data_orig = load_dataset('Cora', device=device)
    grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
                       lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
    results = {}
    L = 6
    perturbations = [
        ('edge_rewire', [0, 0.1, 0.2, 0.3, 0.5], perturb_edges),
        ('label_shuffle', [0, 0.1, 0.2, 0.3, 0.5], perturb_labels),
        ('feature_mask', [0, 0.2, 0.4, 0.6, 0.8], perturb_features),
    ]
    for ptype, fracs, pfunc in perturbations:
        print(f"\n=== {ptype} (Cora, GCN, L={L}) ===", flush=True)
        # The DFA column is a placeholder ('—'): DFA is not trained in this sweep.
        print(f"{'frac':>6} | {'BP':>8} {'DFA':>8} {'GrAPE':>8} | {'Δ(BP)':>7}", flush=True)
        for frac in fracs:
            bp_accs, gr_accs = [], []
            for seed in SEEDS:
                data = pfunc(data_orig, frac, seed=seed) if frac > 0 else data_orig
                common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                              num_layers=L, residual_alpha=0.0, backbone='gcn')
                bp_accs.append(train_one(BPTrainer, common, {}, seed))
                gr_accs.append(train_one(KAFTTrainer, common, grape_extra, seed))
            bp_mean, gr_mean = float(np.mean(bp_accs)), float(np.mean(gr_accs))
            bp, gr = bp_mean * 100, gr_mean * 100
            delta = gr - bp
            results[f"{ptype}|frac={frac}"] = {
                'bp': bp_mean,
                'grape': gr_mean,
                'delta': delta,  # percentage points, unlike 'bp'/'grape' (fractions)
                'frac': frac,
                'ptype': ptype,
                'bp_std': float(np.std(bp_accs)),
                'grape_std': float(np.std(gr_accs)),
                'bp_seeds': [float(a) for a in bp_accs],
                'grape_seeds': [float(a) for a in gr_accs],
            }
            print(f"{frac:>6.1f} | {bp:>7.1f} {'—':>8} {gr:>7.1f} | {delta:>+6.1f}", flush=True)
            # Checkpoint after every strength so a crash does not lose finished runs.
            with open(out_path, 'w', encoding='utf-8') as f:
                json.dump(results, f, indent=2)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()
|