summaryrefslogtreecommitdiff
path: root/experiments/run_combo_20seeds.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/run_combo_20seeds.py')
-rw-r--r--experiments/run_combo_20seeds.py454
1 file changed, 454 insertions, 0 deletions
diff --git a/experiments/run_combo_20seeds.py b/experiments/run_combo_20seeds.py
new file mode 100644
index 0000000..1598964
--- /dev/null
+++ b/experiments/run_combo_20seeds.py
@@ -0,0 +1,454 @@
+#!/usr/bin/env python3
+"""Task 2ceadaa7: GRAFT + Forward Tricks combo experiments (20 seeds).
+
+Combos: GRAFT+ResGCN, GRAFT+DropEdge, GRAFT+PairNorm, GRAFT+JKNet
+Each compared to: BP, forward_trick_only, GRAFT_only, combo
+"""
+
+import torch
+import torch.nn.functional as F
+import numpy as np
+import json
+import os
+from scipy import stats as scipy_stats
+from src.data import load_dataset, spmm, build_normalized_adj
+from src.trainers import BPTrainer, GraphGrAPETrainer, _FeedbackTrainerBase
+from run_deep_baselines import ResGCNTrainer, JKNetTrainer
+from run_dropedge import BPDropEdgeTrainer
+from run_pairnorm_baseline import BPPairNormTrainer, pairnorm
+from run_dblp_depth import load_dblp
+
# Experiment configuration.
device = 'cuda:0'                    # target device for all trainers
SEEDS = list(range(20))              # 20 random seeds per (dataset, method) cell
EPOCHS = 200                         # training epochs per run
OUT_DIR = 'results/combo_20seeds'    # where per-seed checkpoints and results go

# Shared hyperparameters passed to every GRAFT-based trainer.
grape_extra = dict(diffusion_alpha=0.5, diffusion_iters=10,
                   lr_feedback=0.5, num_probes=64, topo_mode='fixed_A')
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# GRAFT + ResGCN combo (fixed version)
+# ═══════════════════════════════════════════════════════════════════════════
class GRAFTResGCN(GraphGrAPETrainer):
    """GRAFT backward + ResGCN forward (skip connections).

    Only the forward pass differs from the base trainer: hidden layers get
    residual (skip) connections whenever the dimensions line up.  The GRAFT
    feedback machinery of the base class is reused unchanged.
    """

    def forward(self):
        hidden = self.data['X']
        first_hidden = None
        hidden_list, preact_list = [], []
        last = self.num_layers - 1

        for layer in range(self.num_layers):
            pre = self._graph_conv(hidden, self.weights[layer], layer)
            preact_list.append(pre)
            if layer == last:
                # Output layer: return logits, no activation/residual.
                return pre, {'Hs': hidden_list, 'Zs': preact_list,
                             'H0': first_hidden}
            activated = F.relu(pre)
            # Skip connection only when shapes match (layer 0 maps
            # input dim -> hidden dim, so it cannot take one).
            if activated.size(1) == hidden.size(1):
                hidden = hidden + activated
            else:
                hidden = activated
            hidden_list.append(hidden)
            if layer == 0:
                first_hidden = hidden
        # Only reachable when num_layers == 0 (kept for parity with the
        # original implementation).
        return pre, {'Hs': hidden_list, 'Zs': preact_list, 'H0': first_hidden}
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# GRAFT + DropEdge combo
+# ═══════════════════════════════════════════════════════════════════════════
class GRAFTDropEdge(GraphGrAPETrainer):
    """GRAFT backward + DropEdge forward (random edge dropping).

    Edges are dropped (and survivors rescaled) only for the forward pass;
    the clean adjacency is restored afterwards so the GRAFT backward pass
    and all evaluations see the original graph.
    """

    def __init__(self, *args, drop_rate=0.5, **kwargs):
        super().__init__(*args, **kwargs)
        self.drop_rate = drop_rate
        # Cache the clean sparse adjacency and its components so a fresh
        # dropped copy can be built cheaply each step.
        adj = self.data['A_hat']
        self._A_hat_orig = adj
        self._edge_index_orig = adj.indices()
        self._edge_values_orig = adj.values()
        self._N = adj.size(0)

    def _drop_edges(self):
        # Identity when not training or when dropping is disabled.
        # NOTE(review): `_training` is assumed to be maintained by the base
        # trainer — confirm against GraphGrAPETrainer.
        if not self._training or self.drop_rate <= 0:
            return self._A_hat_orig
        vals = self._edge_values_orig
        keep = torch.rand(vals.size(0), device=vals.device) > self.drop_rate
        # Zero out dropped edges, rescale survivors (inverted-dropout style).
        scaled = vals * keep.float() / (1 - self.drop_rate)
        return torch.sparse_coo_tensor(
            self._edge_index_orig, scaled, (self._N, self._N)
        ).coalesce()

    def forward(self):
        # DropEdge only in forward pass; GRAFT backward uses original A_hat.
        self.data['A_hat'] = self._drop_edges()
        out = super().forward()
        self.data['A_hat'] = self._A_hat_orig
        return out

    def evaluate(self, mask_name='test_mask'):
        # Always evaluate on the clean graph.
        self.data['A_hat'] = self._A_hat_orig
        return super().evaluate(mask_name)
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# GRAFT + PairNorm combo
+# ═══════════════════════════════════════════════════════════════════════════
class GRAFTPairNorm(GraphGrAPETrainer):
    """GRAFT backward + PairNorm forward (center & scale normalization).

    Every hidden activation is passed through `pairnorm` after the ReLU;
    the output layer is left untouched.  Supports both the 'appnp' and the
    default (GCN-style) backbones.
    """

    def __init__(self, *args, pn_scale=1.0, **kwargs):
        super().__init__(*args, **kwargs)
        self.pn_scale = pn_scale

    def forward(self):
        hidden = self.data['X']
        first_hidden = None
        hidden_list, preact_list = [], []
        last = self.num_layers - 1

        if self.backbone == 'appnp':
            # APPNP: plain linear layers, propagation applied once at the end.
            for layer in range(self.num_layers):
                pre = hidden @ self.weights[layer]
                preact_list.append(pre)
                if layer == last:
                    pre = self._appnp_propagate(pre)
                    preact_list[-1] = pre
                    return pre, {'Hs': hidden_list, 'Zs': preact_list,
                                 'H0': first_hidden}
                hidden = pairnorm(F.relu(pre), self.pn_scale)
                hidden_list.append(hidden)
                if layer == 0:
                    first_hidden = hidden

        # GCN-style backbone: graph convolution at every layer.
        for layer in range(self.num_layers):
            pre = self._graph_conv(hidden, self.weights[layer], layer)
            preact_list.append(pre)
            if layer == last:
                return pre, {'Hs': hidden_list, 'Zs': preact_list,
                             'H0': first_hidden}
            hidden = pairnorm(F.relu(pre), self.pn_scale)
            hidden_list.append(hidden)
            if layer == 0:
                first_hidden = hidden
        # Only reachable when num_layers == 0 (parity with the original).
        return pre, {'Hs': hidden_list, 'Zs': preact_list, 'H0': first_hidden}
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# GRAFT + JKNet combo
+# ═══════════════════════════════════════════════════════════════════════════
class GRAFTJKNet(GraphGrAPETrainer):
    """GRAFT backward + JKNet forward (jumping knowledge max-pool).

    Note: JKNet changes the output architecture. We max-pool hidden layers
    and project to num_classes. GRAFT backward operates on hidden layers
    as usual; the JK projection is treated as the output layer.

    The extra JK projection matrix is trained manually (its Adam state is
    appended to the base trainer's per-weight state list), so
    ``_update_weights`` is overridden to update it alongside the GCN weights.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # JK projection: hidden_dim -> num_classes.
        # NOTE(review): small random init (std 0.01); assumes the base class
        # exposes hidden_dim / d_out / device — confirm in GraphGrAPETrainer.
        self.jk_proj = torch.randn(self.hidden_dim, self.d_out,
                                   device=self.device) * 0.01
        # Add to Adam state: one extra slot at the END of self._adam, which
        # _update_weights below addresses by last index.
        self._adam.append({'m': torch.zeros_like(self.jk_proj),
                           'v': torch.zeros_like(self.jk_proj)})

    def forward(self):
        # Standard GCN forward over the hidden layers; the pre-activation of
        # the (unused) original output layer is still computed so that
        # Zs keeps one entry per layer.
        X = self.data['X']
        H = X
        H0 = None  # first hidden activation (used for residual mixing)
        Hs, Zs = [], []

        for l in range(self.num_layers):
            Z = self._graph_conv(H, self.weights[l], l)
            Zs.append(Z)
            if l < self.num_layers - 1:
                H = F.relu(Z)
                Hs.append(H)
                if l == 0:
                    H0 = H

        # JK max-pool over hidden layers.
        if Hs:
            stacked = torch.stack(Hs, dim=0)  # (L-1, N, d)
            jk_repr = stacked.max(dim=0)[0]   # element-wise max-pool -> (N, d)
            Z_out = jk_repr @ self.jk_proj
            # Override Hs[-1] for backward compatibility with _update_weights
            # which uses Hs[-1] as input to output layer
            Hs_for_backward = list(Hs)
            Hs_for_backward[-1] = jk_repr
            return Z_out, {'Hs': Hs_for_backward, 'Zs': Zs, 'H0': H0}
        else:
            # Degenerate single-layer case: nothing to pool, fall back to the
            # plain convolution output.
            return Z, {'Hs': Hs, 'Zs': Zs, 'H0': H0}

    def _update_weights(self, inter, E0, deltas):
        """Override to handle JK projection separately.

        ``inter`` is the dict returned by forward(), ``E0`` the output error,
        ``deltas`` the per-layer feedback signals from the GRAFT backward.
        """
        # Update hidden layers using GRAFT feedback as usual
        X = self.data['X']
        Hs = inter['Hs']
        H0 = inter['H0']

        grads = []
        for l in range(self.num_layers):
            if l == self.num_layers - 1:
                # Skip the original output layer — JK projection replaces it
                # But still compute gradient for W_L (unused in JK, but keeps indexing)
                # NOTE(review): _graph_conv_T is presumably the adjoint of
                # _graph_conv — confirm in the base trainer.
                H_prev = Hs[-1] if Hs else X
                g = H_prev.t() @ self._graph_conv_T(E0, l)
            else:
                if l == 0:
                    H_in = X
                else:
                    # Optionally mix in the first hidden layer (residual).
                    H_prev = Hs[l - 1]
                    if self.residual_alpha > 0 and H0 is not None:
                        H_in = (1 - self.residual_alpha) * H_prev + self.residual_alpha * H0
                    else:
                        H_in = H_prev
                g = H_in.t() @ self._graph_conv_T(deltas[l], l)
            grads.append(g)

        # Update JK projection: grad = jk_repr.T @ E0
        # (Hs[-1] was replaced by the pooled representation in forward().)
        jk_repr = Hs[-1] if Hs else X
        jk_grad = jk_repr.t() @ E0

        if self._use_adam:
            self._adam_t += 1
            for i in range(self.num_layers):
                self.weights[i] = self.weights[i] - self._adam_step(i, grads[i])
            # Update jk_proj with Adam (use last index) — the slot appended
            # in __init__. Manual bias-corrected Adam with decoupled weight
            # decay, mirroring what _adam_step presumably does.
            jk_idx = len(self._adam) - 1
            s = self._adam[jk_idx]
            b1, b2, eps = self._adam_beta1, self._adam_beta2, self._adam_eps
            t = self._adam_t
            s['m'] = b1 * s['m'] + (1 - b1) * jk_grad
            s['v'] = b2 * s['v'] + (1 - b2) * jk_grad ** 2
            m_hat = s['m'] / (1 - b1 ** t)
            v_hat = s['v'] / (1 - b2 ** t)
            self.jk_proj = self.jk_proj - self.lr * (m_hat / (v_hat.sqrt() + eps) + self.wd * self.jk_proj)
        else:
            # Plain SGD with L2 weight decay.
            for i in range(self.num_layers):
                self.weights[i] = self.weights[i] - self.lr * (grads[i] + self.wd * self.weights[i])
            self.jk_proj = self.jk_proj - self.lr * (jk_grad + self.wd * self.jk_proj)
+
+
+# ═══════════════════════════════════════════════════════════════════════════
+# Training
+# ═══════════════════════════════════════════════════════════════════════════
+
def train_one(cls, common, extra, seed):
    """Train one model for EPOCHS epochs; return test accuracy at best val.

    Args:
        cls: trainer class to instantiate.
        common: keyword args shared by all methods (data, dims, lr, ...).
        extra: method-specific keyword args.
        seed: random seed for torch/numpy/cuda.

    Returns:
        Test accuracy at the epoch with the highest validation accuracy
        (evaluated every 5 epochs and at the final epoch).
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    trainer = cls(**common, **extra)
    if hasattr(trainer, 'align_mode'):
        trainer.align_mode = 'chain_norm'
    best_val, best_test = 0, 0
    for ep in range(EPOCHS):
        trainer.train_step()
        # Evaluate periodically, and ALWAYS at the last epoch — otherwise
        # the final EPOCHS % 5 epochs of training are never considered for
        # model selection (bug in the original: epochs 196-199 were ignored).
        if ep % 5 == 0 or ep == EPOCHS - 1:
            val = trainer.evaluate('val_mask')
            test = trainer.evaluate('test_mask')
            if val > best_val:
                best_val, best_test = val, test
    del trainer
    if torch.cuda.is_available():
        # Guarded so the function also works on CPU-only builds.
        torch.cuda.empty_cache()
    return best_test
+
+
+def main():
+ os.makedirs(OUT_DIR, exist_ok=True)
+ per_seed_file = os.path.join(OUT_DIR, 'per_seed_data.json')
+ if os.path.exists(per_seed_file):
+ with open(per_seed_file) as f:
+ per_seed_data = json.load(f)
+ else:
+ per_seed_data = {}
+
+ # Reuse existing per-seed data from other experiments
+ # BP, ResGCN, GRAFT from resgcn_20seeds
+ try:
+ with open('results/resgcn_20seeds/per_seed_data.json') as f:
+ resgcn_cache = json.load(f)
+ except:
+ resgcn_cache = {}
+
+ # DropEdge from dropedge_20seeds
+ try:
+ with open('results/dropedge_20seeds/per_seed_data.json') as f:
+ de_cache = json.load(f)
+ except:
+ de_cache = {}
+
+ # PairNorm from pairnorm_extended
+ try:
+ with open('results/pairnorm_extended/per_seed_data.json') as f:
+ pn_cache = json.load(f)
+ except:
+ pn_cache = {}
+
+ METHODS = {
+ 'BP': (BPTrainer, {}),
+ 'ResGCN': (ResGCNTrainer, {}),
+ 'DropEdge': (BPDropEdgeTrainer, {'drop_rate': 0.5}),
+ 'PairNorm': (BPPairNormTrainer, {'pn_scale': 1.0}),
+ 'JKNet': (JKNetTrainer, {}),
+ 'GRAFT': (GraphGrAPETrainer, grape_extra),
+ 'GRAFT+ResGCN': (GRAFTResGCN, grape_extra),
+ 'GRAFT+DropEdge': (GRAFTDropEdge, {**grape_extra, 'drop_rate': 0.5}),
+ 'GRAFT+PairNorm': (GRAFTPairNorm, {**grape_extra, 'pn_scale': 1.0}),
+ 'GRAFT+JKNet': (GRAFTJKNet, grape_extra),
+ }
+
+ datasets_cfg = {
+ 'Cora': lambda: load_dataset('Cora', device=device),
+ 'CiteSeer': lambda: load_dataset('CiteSeer', device=device),
+ 'DBLP': lambda: load_dblp(),
+ }
+
+ for ds_name, loader in datasets_cfg.items():
+ data = loader()
+ common = dict(data=data, hidden_dim=64, lr=0.01, weight_decay=5e-4,
+ num_layers=6, residual_alpha=0.0, backbone='gcn')
+
+ for mname, (cls, extra) in METHODS.items():
+ key = f"{ds_name}_{mname}"
+
+ if key not in per_seed_data:
+ per_seed_data[key] = {}
+
+ print(f"\n=== {key} (20 seeds) ===", flush=True)
+
+ for seed in SEEDS:
+ sk = str(seed)
+ if sk in per_seed_data[key]:
+ print(f" seed {seed}: cached", flush=True)
+ continue
+
+ # Try to pull from existing caches
+ cached = None
+ if mname == 'BP' and f"{ds_name}_BP" in resgcn_cache:
+ cached = resgcn_cache[f"{ds_name}_BP"].get(sk)
+ elif mname == 'ResGCN' and f"{ds_name}_ResGCN" in resgcn_cache:
+ cached = resgcn_cache[f"{ds_name}_ResGCN"].get(sk)
+ elif mname == 'GRAFT' and f"{ds_name}_GRAFT" in resgcn_cache:
+ cached = resgcn_cache[f"{ds_name}_GRAFT"].get(sk)
+ elif mname == 'DropEdge':
+ de_key = f"{ds_name}_gcn_L6"
+ if de_key in de_cache and sk in de_cache[de_key]:
+ cached = de_cache[de_key][sk].get('de05')
+ elif mname == 'PairNorm':
+ pn_key = f"{ds_name}_gcn_L6_PN"
+ if pn_key in pn_cache and sk in pn_cache[pn_key]:
+ cached = pn_cache[pn_key][sk]
+
+ if cached is not None:
+ per_seed_data[key][sk] = cached
+ print(f" seed {seed}: from cache ({cached*100:.1f}%)", flush=True)
+ else:
+ try:
+ acc = train_one(cls, common, extra, seed)
+ per_seed_data[key][sk] = acc
+ print(f" seed {seed}: {acc*100:.1f}%", flush=True)
+ except Exception as e:
+ print(f" seed {seed}: FAILED - {e}", flush=True)
+ per_seed_data[key][sk] = 0.0
+
+ # Save after each seed
+ with open(per_seed_file, 'w') as f:
+ json.dump(per_seed_data, f, indent=2)
+
+ del data; torch.cuda.empty_cache()
+
+ # ═══════════════════════════════════════════════════════════════════════
+ # Analysis
+ # ═══════════════════════════════════════════════════════════════════════
+ results = {}
+ print("\n" + "=" * 120)
+ print("FULL RESULTS TABLE")
+ print("=" * 120)
+
+ for ds in ['Cora', 'CiteSeer', 'DBLP']:
+ print(f"\n--- {ds} GCN L=6 lr=0.01 ---")
+ print(f"{'Method':<18} {'Mean±Std':>12} {'vs GRAFT':>18} {'vs FwdTrick':>18}")
+ print("-" * 70)
+
+ # Get GRAFT accs for comparison
+ gr_accs = np.array([per_seed_data[f"{ds}_GRAFT"][str(s)] for s in SEEDS]) * 100
+
+ for mname in ['BP', 'ResGCN', 'DropEdge', 'PairNorm', 'JKNet',
+ 'GRAFT', 'GRAFT+ResGCN', 'GRAFT+DropEdge', 'GRAFT+PairNorm', 'GRAFT+JKNet']:
+ key = f"{ds}_{mname}"
+ if key not in per_seed_data or len(per_seed_data[key]) < 20:
+ print(f" {mname:<16} MISSING ({len(per_seed_data.get(key, {}))} seeds)")
+ continue
+
+ accs = np.array([per_seed_data[key][str(s)] for s in SEEDS]) * 100
+ m, s = accs.mean(), accs.std()
+
+ results[key] = {'mean': float(m), 'std': float(s), 'accs': accs.tolist()}
+
+ # Paired t-test vs GRAFT
+ if mname != 'GRAFT':
+ t_stat, p_val = scipy_stats.ttest_rel(accs, gr_accs)
+ delta = m - gr_accs.mean()
+ sig = '***' if p_val < 0.001 else ('**' if p_val < 0.01 else ('*' if p_val < 0.05 else 'ns'))
+ vs_graft = f"Δ{delta:+.1f} p={p_val:.4f}{sig}"
+ results[key]['vs_GRAFT'] = {
+ 'delta': float(delta), 'p_value': float(p_val),
+ 'significant': bool(p_val < 0.05)
+ }
+ else:
+ vs_graft = "—"
+
+ # Paired t-test vs forward trick only
+ fwd_map = {
+ 'GRAFT+ResGCN': 'ResGCN', 'GRAFT+DropEdge': 'DropEdge',
+ 'GRAFT+PairNorm': 'PairNorm', 'GRAFT+JKNet': 'JKNet'
+ }
+ if mname in fwd_map:
+ fwd_key = f"{ds}_{fwd_map[mname]}"
+ if fwd_key in per_seed_data and len(per_seed_data[fwd_key]) >= 20:
+ fwd_accs = np.array([per_seed_data[fwd_key][str(s)] for s in SEEDS]) * 100
+ t2, p2 = scipy_stats.ttest_rel(accs, fwd_accs)
+ d2 = m - fwd_accs.mean()
+ s2 = '***' if p2 < 0.001 else ('**' if p2 < 0.01 else ('*' if p2 < 0.05 else 'ns'))
+ vs_fwd = f"Δ{d2:+.1f} p={p2:.4f}{s2}"
+ results[key][f'vs_{fwd_map[mname]}'] = {
+ 'delta': float(d2), 'p_value': float(p2),
+ 'significant': bool(p2 < 0.05)
+ }
+ else:
+ vs_fwd = "N/A"
+ else:
+ vs_fwd = ""
+
+ print(f" {mname:<16} {m:>5.1f}±{s:<5.1f} {vs_graft:>18} {vs_fwd:>18}")
+
+ # Summary: which combos are additive?
+ print("\n" + "=" * 80)
+ print("COMBO ADDITIVITY SUMMARY")
+ print("=" * 80)
+ for ds in ['Cora', 'CiteSeer', 'DBLP']:
+ print(f"\n{ds}:")
+ gr_m = results.get(f"{ds}_GRAFT", {}).get('mean', 0)
+ for combo, fwd in [('GRAFT+ResGCN', 'ResGCN'), ('GRAFT+DropEdge', 'DropEdge'),
+ ('GRAFT+PairNorm', 'PairNorm'), ('GRAFT+JKNet', 'JKNet')]:
+ ck = f"{ds}_{combo}"
+ fk = f"{ds}_{fwd}"
+ if ck in results and fk in results:
+ c_m = results[ck]['mean']
+ f_m = results[fk]['mean']
+ vs_gr = results[ck].get('vs_GRAFT', {})
+ vs_fw = results[ck].get(f'vs_{fwd}', {})
+ better_than_both = c_m > gr_m and c_m > f_m
+ marker = "✓ ADDITIVE" if better_than_both else "✗ not additive"
+ print(f" {combo}: {c_m:.1f} | GRAFT={gr_m:.1f} | {fwd}={f_m:.1f} → {marker}")
+
+ # Save
+ with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
+ json.dump(results, f, indent=2)
+ print(f"\nSaved to {OUT_DIR}/results.json")
+
+
# Script entry point.
if __name__ == '__main__':
    main()