diff options
Diffstat (limited to 'experiments/run_cora_perturb.py')
| -rw-r--r-- | experiments/run_cora_perturb.py | 129 |
1 files changed, 129 insertions, 0 deletions
diff --git a/experiments/run_cora_perturb.py b/experiments/run_cora_perturb.py new file mode 100644 index 0000000..1dabc95 --- /dev/null +++ b/experiments/run_cora_perturb.py @@ -0,0 +1,129 @@ +#!/usr/bin/env python3 +""" +Cora perturbation experiment: directly test causal factors. +Three perturbation types: +1. Edge rewiring: destroy community structure +2. Label shuffling: reduce homophily +3. Feature masking: reduce feature quality +""" + +import torch +import torch.nn.functional as F +import numpy as np +import json +import os +from src.data import load_dataset, build_normalized_adj, build_row_normalized_adj +from src.trainers import BPTrainer, DFATrainer, GraphGrAPETrainer + +device = 'cuda:0' +SEEDS = [0, 1, 2, 3, 4] +EPOCHS = 200 +OUT_DIR = 'results/cora_perturbation' + + +def perturb_edges(data, rewire_frac, seed=0): + """Randomly rewire a fraction of edges (destroys community structure).""" + d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + rng = torch.Generator().manual_seed(seed) + A = d['A_hat'] + idx = A.indices() + vals = A.values() + N = d['num_nodes'] + + n_rewire = int(rewire_frac * idx.shape[1]) + if n_rewire > 0: + perm = torch.randperm(idx.shape[1], generator=rng)[:n_rewire].to(idx.device) + new_targets = torch.randint(0, N, (n_rewire,), generator=rng).to(idx.device) + idx_new = idx.clone() + idx_new[1, perm] = new_targets + + A_new = torch.sparse_coo_tensor(idx_new, vals, (N, N)).coalesce() + d['A_hat'] = A_new + d['A_row'] = A_new # simplified + d['A_row_T'] = A_new + return d + + +def perturb_labels(data, shuffle_frac, seed=0): + """Shuffle a fraction of labels (reduces homophily).""" + d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + rng = torch.Generator().manual_seed(seed) + y = d['y'].clone() + N = len(y) + n_shuffle = int(shuffle_frac * N) + perm = torch.randperm(N, generator=rng)[:n_shuffle] + shuffled = y[perm][torch.randperm(n_shuffle, generator=rng)] + y[perm] = shuffled + d['y'] = y + return d + + +def perturb_features(data, mask_frac, seed=0): + """Zero out a fraction of feature dimensions (reduces feature quality).""" + d = {k: v.clone() if isinstance(v, torch.Tensor) else v for k, v in data.items()} + rng = torch.Generator().manual_seed(seed) + X = d['X'].clone() + F_dim = X.shape[1] + n_mask = int(mask_frac * F_dim) + mask_dims = torch.randperm(F_dim, generator=rng)[:n_mask] + X[:, mask_dims] = 0 + d['X'] = X + return d + + +def train_one(cls, common, extra, seed): + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + t = cls(**common, **extra) + if hasattr(t, 'align_mode'): + t.align_mode = 'chain_norm' + bv, bt = 0, 0 + for ep in range(EPOCHS): + t.train_step() + if ep % 5 == 0: + v, te = t.evaluate('val_mask'), t.evaluate('test_mask') + if v > bv: bv, bt = v, te + del t; torch.cuda.empty_cache() + return bt + + +def main(): + os.makedirs(OUT_DIR, exist_ok=True) + data_orig = load_dataset('Cora', device=device) + grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10, + lr_feedback=0.5, num_probes=64, topo_mode='fixed_A') + results = {} + L = 6 + + perturbations = [ + ('edge_rewire', [0, 0.1, 0.2, 0.3, 0.5], perturb_edges), + ('label_shuffle', [0, 0.1, 0.2, 0.3, 0.5], perturb_labels), + ('feature_mask', [0, 0.2, 0.4, 0.6, 0.8], perturb_features), + ] + + for ptype, fracs, pfunc in perturbations: + print(f"\n=== {ptype} (Cora, GCN, L={L}) ===", flush=True) + print(f"{'frac':>6} | {'BP':>8} {'DFA':>8} {'GrAPE':>8} | {'Δ(BP)':>7}", flush=True) + + for frac in fracs: + bp_accs, gr_accs = [], [] + for seed in SEEDS: + data = pfunc(data_orig, frac, seed=seed) if frac > 0 else data_orig + common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4, + num_layers=L, residual_alpha=0.0, backbone='gcn') + bp_accs.append(train_one(BPTrainer, common, {}, seed)) + gr_accs.append(train_one(GraphGrAPETrainer, common, grape_extra, seed)) + + bp, gr = np.mean(bp_accs)*100, np.mean(gr_accs)*100 + delta = gr - bp + key = f"{ptype}|frac={frac}" + results[key] = {'bp': float(np.mean(bp_accs)), 'grape': float(np.mean(gr_accs)), + 'delta': float(gr - bp), 'frac': frac, 'ptype': ptype} + print(f"{frac:>6.1f} | {bp:>7.1f} {'—':>8} {gr:>7.1f} | {delta:>+6.1f}", flush=True) + + with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f: + json.dump(results, f, indent=2) + print(f"\nSaved to {OUT_DIR}/results.json") + + +if __name__ == '__main__': + main() |
