#!/usr/bin/env python3
"""Gradient reach with 20 seeds (0-19) for statistical significance.

Extends the earlier 5-seed experiment to 20 seeds. For each
(backbone, depth) configuration this script trains a BP model and a GRAFT
model per seed. It records per-layer BP gradient norms, GRAFT feedback norms
and test accuracies, then reports means and paired-t-test statistics over all
seeds.

Per-seed measurements are checkpointed to ``OUT_DIR/per_seed_data.json``
after every seed. An interrupted run therefore resumes where it left off:
seeds already present in that file are skipped.

NOTE: ``OLD_FILE`` (the 5-seed aggregate results) is NOT read by this
script. Seeds 0-4 are recomputed unless they already appear in the per-seed
checkpoint above.
"""
import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from scipy import stats as scipy_stats

from src.data import load_dataset, spmm
from src.trainers import BPTrainer, GraphGrAPETrainer

device = 'cuda:0'
ALL_SEEDS = list(range(20))
EPOCHS = 100
OUT_DIR = 'results/gradient_reach_20seeds'
# Earlier 5-seed aggregate results. Kept for reference only; nothing in this
# script reads it (see module docstring).
OLD_FILE = 'results/gradient_reach_5seeds/results.json'
# Per-seed checkpoint: both the resume source and the incremental save target.
PER_SEED_FILE = os.path.join(OUT_DIR, 'per_seed_data.json')


def _seed_all(seed):
    """Seed torch (CPU), numpy and all CUDA devices with the same value."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def measure_one(data, L, backbone, seed):
    """Train BP and GRAFT for one seed and measure per-layer gradient reach.

    Args:
        data: dataset dict from ``load_dataset``. This function reads its
            ``'A_hat'``, ``'train_mask'`` and ``'y'`` entries and passes the
            whole dict to both trainers.
        L: number of layers for both models.
        backbone: backbone name forwarded to both trainers (e.g. ``'gcn'``).
        seed: RNG seed, re-applied before constructing each trainer.

    Returns:
        ``(bp_norms, graft_norms, bp_acc, gr_acc)``:

        - ``bp_norms``: L floats, the norm of each layer weight's BP gradient
          after training.
        - ``graft_norms``: L - 1 floats, the norm of the ReLU-gated GRAFT
          feedback signal for layers ``0 .. L-2``.
        - ``bp_acc``, ``gr_acc``: values of ``evaluate('test_mask')``.
          These are presumably fractions in [0, 1], since ``main`` scales
          them by 100 for display (confirm against the trainers).
    """
    A = data['A_hat']
    common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
                  num_layers=L, residual_alpha=0.0, backbone=backbone)
    grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
                       lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')

    # Re-seed before each constructor so both models start from the same
    # seeded RNG state.
    _seed_all(seed)
    bp = BPTrainer(**common)
    _seed_all(seed)
    gr = GraphGrAPETrainer(**common, **grape_extra)
    gr.align_mode = 'chain_norm'

    # The two models train interleaved and share the global RNG streams.
    # Reordering this loop would change every seed's results and make new
    # seeds inconsistent with previously checkpointed ones.
    for _ in range(EPOCHS):
        bp.train_step()
        gr.train_step()

    # BP gradients: one full backward pass of the training loss.
    bp.optimizer.zero_grad()
    Z_bp, _ = bp.forward()
    mask = data['train_mask']
    loss = F.cross_entropy(Z_bp[mask], data['y'][mask])
    loss.backward()
    bp_norms = [bp.weights[l].grad.norm().item() for l in range(L)]

    # GRAFT feedback norms: propagate the output error E_bar through A
    # (power capped by gr.max_topo_power), project it with the layer's
    # feedback matrix Rs[l], and gate it by the layer's ReLU activity.
    Z_gr, inter = gr.forward()
    E0, E_bar = gr._output_error(Z_gr)
    graft_norms = []
    for l in range(L - 1):
        power = min(L - l, gr.max_topo_power)
        topo_E = E_bar
        for _ in range(power):
            topo_E = spmm(A, topo_E)
        fb = topo_E @ gr.Rs[l]
        relu_gate = (inter['Zs'][l].detach() > 0).float()
        graft_norms.append((relu_gate * fb).norm().item())

    bp_acc = bp.evaluate('test_mask')
    gr_acc = gr.evaluate('test_mask')
    # Free GPU memory before the next seed/config.
    del bp, gr
    torch.cuda.empty_cache()
    return bp_norms, graft_norms, bp_acc, gr_acc


def _atomic_json_dump(obj, path):
    """Write *obj* as indented JSON to *path* atomically.

    The data goes to a sibling temporary file first, which is then moved over
    *path* with ``os.replace``. An interruption mid-write therefore never
    leaves a truncated file. This matters for the per-seed checkpoint: a
    truncated file would make the next resume crash in ``json.load``.
    """
    tmp_path = f"{path}.tmp"
    with open(tmp_path, 'w') as f:
        json.dump(obj, f, indent=2)
    os.replace(tmp_path, path)


def _load_per_seed_data(path):
    """Return the checkpointed per-seed dict stored at *path*, or ``{}``."""
    if not os.path.exists(path):
        return {}
    with open(path) as f:
        per_seed_data = json.load(f)
    print(f"Loaded existing per-seed data from {path}")
    return per_seed_data


def _significance_stars(p_val):
    """Map a p-value to '***' / '**' / '*' / 'ns' (NaN maps to 'ns')."""
    if p_val < 0.001:
        return '***'
    if p_val < 0.01:
        return '**'
    if p_val < 0.05:
        return '*'
    return 'ns'


def _summarize(sd, seeds):
    """Aggregate one configuration's per-seed entries *sd* over *seeds*.

    *sd* maps ``str(seed)`` to the dict recorded by ``main`` (``bp_norms``,
    ``graft_norms``, ``bp_acc``, ``gr_acc``). Accuracies are scaled by 100.
    Standard deviations use numpy's default population std (ddof=0).
    """
    bp_accs = np.array([sd[str(s)]['bp_acc'] for s in seeds]) * 100
    gr_accs = np.array([sd[str(s)]['gr_acc'] for s in seeds]) * 100
    deltas = gr_accs - bp_accs
    # Paired test: the BP and GRAFT runs of each seed are matched.
    t_stat, p_val = scipy_stats.ttest_rel(gr_accs, bp_accs)
    avg_bp_norms = np.mean([sd[str(s)]['bp_norms'] for s in seeds], axis=0)
    avg_gr_norms = np.mean([sd[str(s)]['graft_norms'] for s in seeds], axis=0)
    return {
        'bp_acc_mean': float(bp_accs.mean()),
        'bp_acc_std': float(bp_accs.std()),
        'gr_acc_mean': float(gr_accs.mean()),
        'gr_acc_std': float(gr_accs.std()),
        'delta_mean': float(deltas.mean()),
        'delta_std': float(deltas.std()),
        't_stat': float(t_stat),
        'p_value': float(p_val),
        'n_seeds': len(seeds),
        'avg_bp_norms': avg_bp_norms.tolist(),
        'avg_gr_norms': avg_gr_norms.tolist(),
        'bp_accs': bp_accs.tolist(),
        'gr_accs': gr_accs.tolist(),
    }


def _print_summary(key, r):
    """Print the human-readable summary of one configuration's results."""
    sig = _significance_stars(r['p_value'])
    print(f"\n {key}:")
    print(f" BP: {r['bp_acc_mean']:.1f} ± {r['bp_acc_std']:.1f}%")
    print(f" GRAFT: {r['gr_acc_mean']:.1f} ± {r['gr_acc_std']:.1f}%")
    print(f" Δ: {r['delta_mean']:+.1f} ± {r['delta_std']:.1f}% t={r['t_stat']:.2f} p={r['p_value']:.4f} {sig}")
    print(f" BP norm L0: {r['avg_bp_norms'][0]:.6f}")
    print(f" GRAFT norm L0: {r['avg_gr_norms'][0]:.4f}")


def main():
    """Run (or resume) every configuration over ALL_SEEDS, then summarize.

    Writes the per-seed checkpoint after every seed and the aggregate
    summary to ``OUT_DIR/results.json`` at the end.
    """
    os.makedirs(OUT_DIR, exist_ok=True)
    data = load_dataset('Cora', device=device)

    per_seed_data = _load_per_seed_data(PER_SEED_FILE)

    configs = [
        ('gcn', 6), ('gcn', 10),
        ('appnp', 6), ('appnp', 10),
    ]
    n_seeds = len(ALL_SEEDS)

    for backbone, L in configs:
        key = f"{backbone}_L{L}"
        print(f"\n=== {backbone.upper()} L={L} ({n_seeds} seeds) ===", flush=True)
        seed_results = per_seed_data.setdefault(key, {})
        for seed in ALL_SEEDS:
            # JSON object keys are strings, so seeds are stored as str.
            seed_key = str(seed)
            if seed_key in seed_results:
                print(f" seed {seed}: already done, skipping", flush=True)
                continue
            bn, gn, ba, ga = measure_one(data, L, backbone, seed)
            seed_results[seed_key] = {
                'bp_norms': bn, 'graft_norms': gn, 'bp_acc': ba, 'gr_acc': ga
            }
            print(f" seed {seed}: BP {ba*100:.1f}% GRAFT {ga*100:.1f}%", flush=True)
            # Checkpoint after every seed so an interrupted run can resume.
            _atomic_json_dump(per_seed_data, PER_SEED_FILE)

    # Aggregate results
    results = {}
    for backbone, L in configs:
        key = f"{backbone}_L{L}"
        results[key] = _summarize(per_seed_data[key], ALL_SEEDS)
        _print_summary(key, results[key])

    _atomic_json_dump(results, os.path.join(OUT_DIR, 'results.json'))
    print(f"\nSaved to {OUT_DIR}/results.json")


if __name__ == '__main__':
    main()