#!/usr/bin/env python3
"""
Cora perturbation experiment: directly test causal factors.

Three perturbation types:
1. Edge rewiring: destroy community structure
2. Label shuffling: reduce homophily
3. Feature masking: reduce feature quality
"""
import json
import os

import numpy as np
import torch
import torch.nn.functional as F

from src.data import load_dataset, build_normalized_adj, build_row_normalized_adj
from src.trainers import BPTrainer, DFATrainer, KAFTTrainer

# Same device as before when CUDA is present; fall back to CPU otherwise so
# the script does not fail on machines without a GPU.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
SEEDS = [0, 1, 2, 3, 4]
EPOCHS = 200
OUT_DIR = 'results/cora_perturbation'


def _clone_data(data):
    """Return a copy of ``data`` in which every tensor value is cloned.

    Non-tensor values are carried over by reference. The perturbation
    functions mutate the tensors of the returned dict, never the input.
    """
    return {
        key: value.clone() if isinstance(value, torch.Tensor) else value
        for key, value in data.items()
    }


def perturb_edges(data, rewire_frac, seed=0):
    """Randomly rewire a fraction of edges (destroys community structure).

    A fraction ``rewire_frac`` of the stored entries of the sparse matrix
    ``data['A_hat']`` get their second (column) index replaced by a node id
    drawn uniformly from ``[0, num_nodes)``. Entry values are left untouched.

    Args:
        data: Dataset dict; must provide ``'A_hat'`` (sparse COO tensor) and
            ``'num_nodes'``.
        rewire_frac: Fraction of stored entries to rewire.
        seed: Seed for the CPU generator that picks entries and new targets.

    Returns:
        A copy of ``data``. When at least one entry is rewired, ``'A_hat'``,
        ``'A_row'`` and ``'A_row_T'`` all refer to the same rewired matrix.
    """
    d = _clone_data(data)
    rng = torch.Generator().manual_seed(seed)

    # .indices()/.values() are only defined on coalesced sparse tensors;
    # coalesce() returns the tensor itself when it is already coalesced.
    A = d['A_hat'].coalesce()
    idx = A.indices()
    vals = A.values()
    N = d['num_nodes']

    n_edges = idx.shape[1]
    n_rewire = int(rewire_frac * n_edges)
    if n_rewire > 0:
        # Draw on the CPU generator (keeps results seed-reproducible), then
        # move the index tensors next to the adjacency indices.
        chosen = torch.randperm(n_edges, generator=rng)[:n_rewire].to(idx.device)
        new_targets = torch.randint(0, N, (n_rewire,), generator=rng).to(idx.device)

        idx_new = idx.clone()
        idx_new[1, chosen] = new_targets

        # coalesce() merges entries that now share the same (row, col) pair.
        A_new = torch.sparse_coo_tensor(idx_new, vals, (N, N)).coalesce()
        d['A_hat'] = A_new
        # Simplified: the same matrix stands in for the row-normalised
        # adjacency and for its transpose.
        # NOTE(review): only one endpoint of each selected entry is changed,
        # so A_new is generally not symmetric and A_row_T is not a true
        # transpose — confirm the trainers tolerate this.
        d['A_row'] = A_new
        d['A_row_T'] = A_new
    return d


def perturb_labels(data, shuffle_frac, seed=0):
    """Shuffle a fraction of labels (reduces homophily).

    A random subset of ``int(shuffle_frac * N)`` positions of ``data['y']``
    is selected and the labels at those positions are permuted among
    themselves, so the overall label multiset is unchanged.

    Args:
        data: Dataset dict; must provide the label tensor ``'y'``.
        shuffle_frac: Fraction of label positions taking part in the shuffle.
        seed: Seed for the CPU generator driving both permutations.

    Returns:
        A copy of ``data`` whose ``'y'`` holds the shuffled labels.
    """
    d = _clone_data(data)
    rng = torch.Generator().manual_seed(seed)

    y = d['y']  # already a private clone, safe to modify in place
    N = len(y)
    n_shuffle = int(shuffle_frac * N)

    # Order of the two randperm calls matters for seed reproducibility.
    chosen = torch.randperm(N, generator=rng)[:n_shuffle].to(y.device)
    order = torch.randperm(n_shuffle, generator=rng).to(y.device)

    # y[chosen] is a copy (advanced indexing), so the permuted labels are
    # fully materialised before being written back.
    y[chosen] = y[chosen][order]
    d['y'] = y
    return d


def perturb_features(data, mask_frac, seed=0):
    """Zero out a fraction of feature dimensions (reduces feature quality).

    ``int(mask_frac * F_dim)`` randomly chosen columns of ``data['X']`` are
    set to zero for every row.

    Args:
        data: Dataset dict; must provide the 2-D feature tensor ``'X'``.
        mask_frac: Fraction of feature columns to zero.
        seed: Seed for the CPU generator that picks the columns.

    Returns:
        A copy of ``data`` whose ``'X'`` has the selected columns zeroed.
    """
    d = _clone_data(data)
    rng = torch.Generator().manual_seed(seed)

    X = d['X']  # already a private clone, safe to modify in place
    F_dim = X.shape[1]
    n_mask = int(mask_frac * F_dim)
    mask_dims = torch.randperm(F_dim, generator=rng)[:n_mask].to(X.device)
    X[:, mask_dims] = 0
    d['X'] = X
    return d


def train_one(cls, common, extra, seed):
    """Train one trainer instance and return its model-selected test score.

    The trainer is built as ``cls(**common, **extra)`` after seeding torch,
    numpy and CUDA. It runs ``EPOCHS`` training steps; every fifth epoch
    (starting at epoch 0) the validation and test masks are evaluated and the
    test score belonging to the strictly best validation score so far is kept.

    Args:
        cls: Trainer class to instantiate.
        common: Keyword arguments shared by all trainer types.
        extra: Trainer-specific keyword arguments.
        seed: Random seed applied before construction.

    Returns:
        The test score at the best validation score, or 0 if the validation
        score never exceeded 0. (Scores are presumably accuracies in [0, 1];
        verify against the trainers' ``evaluate``.)
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    t = cls(**common, **extra)
    if hasattr(t, 'align_mode'):
        t.align_mode = 'chain_norm'

    bv, bt = 0, 0
    for ep in range(EPOCHS):
        t.train_step()
        if ep % 5 == 0:
            v, te = t.evaluate('val_mask'), t.evaluate('test_mask')
            if v > bv:
                bv, bt = v, te

    # Drop the trainer and release cached GPU memory before the next run.
    del t
    torch.cuda.empty_cache()
    return bt


def main():
    """Run every perturbation sweep on Cora and write the summary JSON.

    For each perturbation type and fraction, BPTrainer and KAFTTrainer (the
    "GrAPE" column) are trained once per seed on the perturbed data; mean
    scores are printed as a table row and stored in
    ``OUT_DIR/results.json``.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    data_orig = load_dataset('Cora', device=device)

    grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
                       lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')

    results = {}
    L = 6

    perturbations = [
        ('edge_rewire', [0, 0.1, 0.2, 0.3, 0.5], perturb_edges),
        ('label_shuffle', [0, 0.1, 0.2, 0.3, 0.5], perturb_labels),
        ('feature_mask', [0, 0.2, 0.4, 0.6, 0.8], perturb_features),
    ]

    for ptype, fracs, pfunc in perturbations:
        print(f"\n=== {ptype} (Cora, GCN, L={L}) ===", flush=True)
        print(f"{'frac':>6} | {'BP':>8} {'DFA':>8} {'GrAPE':>8} | {'Δ(BP)':>7}", flush=True)

        for frac in fracs:
            bp_accs, gr_accs = [], []
            for seed in SEEDS:
                # frac == 0 is the unperturbed baseline: use the original
                # data as-is instead of a perturbed copy.
                data = pfunc(data_orig, frac, seed=seed) if frac > 0 else data_orig
                common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                              num_layers=L, residual_alpha=0.0, backbone='gcn')
                bp_accs.append(train_one(BPTrainer, common, {}, seed))
                gr_accs.append(train_one(KAFTTrainer, common, grape_extra, seed))

            bp, gr = np.mean(bp_accs) * 100, np.mean(gr_accs) * 100
            delta = gr - bp
            key = f"{ptype}|frac={frac}"
            # NOTE: 'bp' and 'grape' are stored as raw mean scores, while
            # 'delta' is their difference after scaling by 100.
            results[key] = {'bp': float(np.mean(bp_accs)),
                            'grape': float(np.mean(gr_accs)),
                            'delta': float(delta),
                            'frac': frac,
                            'ptype': ptype}
            # The DFA column is a placeholder; no DFA run is performed here.
            print(f"{frac:>6.1f} | {bp:>7.1f} {'—':>8} {gr:>7.1f} | {delta:>+6.1f}", flush=True)

    with open(os.path.join(OUT_DIR, 'results.json'), 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()