summaryrefslogtreecommitdiff
path: root/experiments/run_cora_perturb.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
committerYurenHao0426 <blackhao0426@gmail.com>2026-05-04 23:05:16 -0500
commitbd9333eda60a9029a198acaeacb1eca4312bd1e8 (patch)
tree7544c347b7ac4e8629fa1cc0fcf341d48cb69e2e /experiments/run_cora_perturb.py
Initial release: GRAFT (KAFT) — NeurIPS 2026 submission code
Topology-factorized Jacobian-aligned feedback for deep GNNs. Includes: - src/: GraphGrAPETrainer (KAFT) + BP / DFA / DFA-GNN / VanillaGrAPE baselines + multi-probe alignment estimator + dataset / sparse-mm utilities. - experiments/: 19 runners reproducing every figure / table in the paper. - figures/: 4 generators + the 4 PDFs cited in the report. - paper/: NeurIPS .tex and consolidated experiments_master notes. Smoke test: 50-epoch Cora GCN L=4 gives BP 77.3% / KAFT 79.0%.
Diffstat (limited to 'experiments/run_cora_perturb.py')
-rw-r--r--experiments/run_cora_perturb.py129
1 file changed, 129 insertions, 0 deletions
diff --git a/experiments/run_cora_perturb.py b/experiments/run_cora_perturb.py
new file mode 100644
index 0000000..1dabc95
--- /dev/null
+++ b/experiments/run_cora_perturb.py
@@ -0,0 +1,129 @@
+#!/usr/bin/env python3
+"""
+Cora perturbation experiment: directly test causal factors.
+Three perturbation types:
+1. Edge rewiring: destroy community structure
+2. Label shuffling: reduce homophily
+3. Feature masking: reduce feature quality
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import json
+import os
+from src.data import load_dataset, build_normalized_adj, build_row_normalized_adj
+from src.trainers import BPTrainer, DFATrainer, GraphGrAPETrainer
+
device = 'cuda:0'                      # all trainers run on the first GPU
SEEDS = [0, 1, 2, 3, 4]                # seeds averaged per configuration
EPOCHS = 200                           # training epochs per run
OUT_DIR = 'results/cora_perturbation'  # destination for results.json
+
+
def perturb_edges(data, rewire_frac, seed=0):
    """Randomly rewire a fraction of edges (destroys community structure).

    Returns a shallow-ish copy of *data* (tensors cloned, everything else
    shared) whose adjacency has `rewire_frac` of its entries pointed at
    uniformly random target nodes. The input dict is never mutated.
    """
    out = {key: val.clone() if isinstance(val, torch.Tensor) else val
           for key, val in data.items()}
    gen = torch.Generator().manual_seed(seed)

    adj = out['A_hat']
    indices = adj.indices()
    values = adj.values()
    num_nodes = out['num_nodes']
    num_edges = indices.shape[1]

    n_rewire = int(rewire_frac * num_edges)
    if n_rewire > 0:
        # Pick which edges to rewire, then pick fresh random targets.
        # (RNG call order matters for reproducibility: randperm, then randint.)
        chosen = torch.randperm(num_edges, generator=gen)[:n_rewire].to(indices.device)
        targets = torch.randint(0, num_nodes, (n_rewire,), generator=gen).to(indices.device)

        rewired = indices.clone()
        rewired[1, chosen] = targets

        # coalesce() merges any duplicate (row, col) pairs created by rewiring.
        adj_new = torch.sparse_coo_tensor(rewired, values, (num_nodes, num_nodes)).coalesce()
        out['A_hat'] = adj_new
        out['A_row'] = adj_new  # simplified
        out['A_row_T'] = adj_new
    return out
+
+
def perturb_labels(data, shuffle_frac, seed=0):
    """Shuffle a fraction of labels (reduces homophily).

    Permutes the labels of a random `shuffle_frac` subset of nodes among
    themselves, so the overall label distribution is unchanged. The input
    dict is never mutated.
    """
    out = {key: val.clone() if isinstance(val, torch.Tensor) else val
           for key, val in data.items()}
    gen = torch.Generator().manual_seed(seed)

    labels = out['y'].clone()
    num_nodes = len(labels)
    n_shuffle = int(shuffle_frac * num_nodes)

    # Select the affected nodes, then permute their labels in place.
    # (RNG call order matters: randperm(N) first, then randperm(n_shuffle).)
    chosen = torch.randperm(num_nodes, generator=gen)[:n_shuffle]
    labels[chosen] = labels[chosen][torch.randperm(n_shuffle, generator=gen)]

    out['y'] = labels
    return out
+
+
def perturb_features(data, mask_frac, seed=0):
    """Zero out a fraction of feature dimensions (reduces feature quality).

    Picks `mask_frac` of the feature columns at random and zeros them for
    every node. The input dict is never mutated.
    """
    out = {key: val.clone() if isinstance(val, torch.Tensor) else val
           for key, val in data.items()}
    gen = torch.Generator().manual_seed(seed)

    feats = out['X'].clone()
    n_dims = feats.shape[1]
    n_mask = int(mask_frac * n_dims)
    masked = torch.randperm(n_dims, generator=gen)[:n_mask]
    feats[:, masked] = 0

    out['X'] = feats
    return out
+
+
def train_one(cls, common, extra, seed):
    """Train one model for EPOCHS epochs; return test acc at the best val acc.

    `cls` is a trainer class instantiated as cls(**common, **extra); it must
    expose train_step() and evaluate(mask_name). Evaluation happens every 5
    epochs; the test accuracy recorded is the one at the highest val accuracy.
    """
    # Seed every RNG source so a run is reproducible per seed.
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    trainer = cls(**common, **extra)
    if hasattr(trainer, 'align_mode'):
        # KAFT-style trainers expose align_mode; force the chain_norm variant.
        trainer.align_mode = 'chain_norm'

    best_val, best_test = 0, 0
    for epoch in range(EPOCHS):
        trainer.train_step()
        if epoch % 5 == 0:
            val_acc = trainer.evaluate('val_mask')
            test_acc = trainer.evaluate('test_mask')
            if val_acc > best_val:
                best_val, best_test = val_acc, test_acc

    del trainer
    torch.cuda.empty_cache()
    return best_test
+
+
def main():
    """Run the three perturbation sweeps on Cora and dump accuracies to JSON."""
    os.makedirs(OUT_DIR, exist_ok=True)
    base_data = load_dataset('Cora', device=device)
    grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
                       lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
    results = {}
    L = 6  # GCN depth for every run

    # (name, fraction sweep, perturbation function)
    perturbations = [
        ('edge_rewire', [0, 0.1, 0.2, 0.3, 0.5], perturb_edges),
        ('label_shuffle', [0, 0.1, 0.2, 0.3, 0.5], perturb_labels),
        ('feature_mask', [0, 0.2, 0.4, 0.6, 0.8], perturb_features),
    ]

    for ptype, fracs, pfunc in perturbations:
        print(f"\n=== {ptype} (Cora, GCN, L={L}) ===", flush=True)
        print(f"{'frac':>6} | {'BP':>8} {'DFA':>8} {'GrAPE':>8} | {'Δ(BP)':>7}", flush=True)

        for frac in fracs:
            bp_accs, gr_accs = [], []
            for seed in SEEDS:
                # frac == 0 is the unperturbed baseline; reuse the original data.
                data = pfunc(base_data, frac, seed=seed) if frac > 0 else base_data
                common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                              num_layers=L, residual_alpha=0.0, backbone='gcn')
                bp_accs.append(train_one(BPTrainer, common, {}, seed))
                gr_accs.append(train_one(GraphGrAPETrainer, common, grape_extra, seed))

            bp_mean = np.mean(bp_accs)
            gr_mean = np.mean(gr_accs)
            bp, gr = bp_mean * 100, gr_mean * 100
            delta = gr - bp
            results[f"{ptype}|frac={frac}"] = {
                'bp': float(bp_mean),
                'grape': float(gr_mean),
                'delta': float(gr - bp),
                'frac': frac,
                'ptype': ptype,
            }
            # NOTE(review): the DFA column is a placeholder — DFATrainer is
            # imported but never run here.
            print(f"{frac:>6.1f} | {bp:>7.1f} {'—':>8} {gr:>7.1f} | {delta:>+6.1f}", flush=True)

    with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()