summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author    YurenHao0426 <blackhao0426@gmail.com>  2026-05-04 23:12:29 -0500
committer YurenHao0426 <blackhao0426@gmail.com>  2026-05-04 23:12:29 -0500
commit 174083f4e340afc192d57aacc6422c99530e9e59 (patch)
tree   28932cc5d17ed2869dfe33b470f30cd602555edf
parent ba6ead6d7a41b7ed78bb228181b7262d0c75d2eb (diff)
Drop VanillaGrAPE — control method not exposed in release (HEAD, main)
- src/trainers.py: removed class VanillaGrAPETrainer (~30 lines) and cleaned the module-level methods-compared docstring.
- experiments/run_ablation_20seeds.py: dropped VanillaGrAPE row from the 4-method ablation grid; sweep is now BP → DFA → DFA-GNN → KAFT.

Smoke test: BPTrainer / DFATrainer / DFAGNNTrainer / KAFTTrainer all train cleanly at GCN L=4, Cora, 50 epochs (test acc 77.3 / 76.6 / 78.4 / 79.0).
-rw-r--r--  experiments/run_ablation_20seeds.py   8
-rw-r--r--  src/trainers.py                      42
2 files changed, 6 insertions, 44 deletions
diff --git a/experiments/run_ablation_20seeds.py b/experiments/run_ablation_20seeds.py
index d2bd434..e90917d 100644
--- a/experiments/run_ablation_20seeds.py
+++ b/experiments/run_ablation_20seeds.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → KAFT."""
+"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → KAFT."""
import torch
import numpy as np
@@ -7,7 +7,7 @@ import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset
-from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, KAFTTrainer
+from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, KAFTTrainer
device = 'cuda:0'
SEEDS = list(range(20))
@@ -18,10 +18,6 @@ METHODS = {
'BP': (BPTrainer, {}),
'DFA': (DFATrainer, {}),
'DFA-GNN': (DFAGNNTrainer, {'topo_mode': 'fixed_A'}),
- 'VanillaGrAPE': (VanillaGrAPETrainer, {
- 'diffusion_alpha': 0.5, 'diffusion_iters': 10,
- 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A'
- }),
'KAFT': (KAFTTrainer, {
'diffusion_alpha': 0.5, 'diffusion_iters': 10,
'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A'
diff --git a/src/trainers.py b/src/trainers.py
index 7589fc9..9da6a6a 100644
--- a/src/trainers.py
+++ b/src/trainers.py
@@ -3,11 +3,10 @@ Training methods for KAFT experiments.
Generalized to L-layer residual GCN.
Methods compared:
- BP — Standard backprop GCN
- DFA — Fixed random R, P=I
- DFA-GNN — Fixed random R, P=Â^{L-l}
- VanillaGrAPE — Aligned R (per layer), P=I
- KAFT — Aligned R (per layer) + topology P=Â^{L-l}
+ BP — Standard backprop GCN
+ DFA — Fixed random R, P=I
+ DFA-GNN — Fixed random R, P=Â^{L-l}
+ KAFT — Learned R (per layer) + topology P=Â^{min(L-1-l, K)}
"""
import torch
@@ -575,39 +574,6 @@ class DFAGNNTrainer(_FeedbackTrainerBase):
# ---------------------------------------------------------------------------
-# Vanilla GrAPE Trainer
-# ---------------------------------------------------------------------------
-
-class VanillaGrAPETrainer(_FeedbackTrainerBase):
- """Aligned R per layer, no topology (P=I)."""
-
- def __init__(self, data, hidden_dim, lr, weight_decay,
- lr_feedback=0.5, num_probes=64,
- diffusion_alpha=0.5, diffusion_iters=10,
- num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw):
- super().__init__(data, hidden_dim, lr, weight_decay,
- diffusion_alpha, diffusion_iters,
- num_layers, residual_alpha, backbone,
- _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0))
- self.lr_fb = lr_feedback
- self.num_probes = num_probes
- # One R per hidden layer
- self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01
- for _ in range(num_layers - 1)]
-
- def _alignment_step(self, inter):
- metrics = {}
- for l in range(self.num_layers - 1):
- cos = _align_R_layer(self, l)
- metrics[f'cos_feat_L{l}'] = cos
- metrics['cos_feat'] = sum(metrics.values()) / len(metrics)
- return metrics
-
- def _compute_hidden_feedback(self, l, inter, E_bar):
- return E_bar @ self.Rs[l]
-
-
-# ---------------------------------------------------------------------------
# KAFT Trainer
# ---------------------------------------------------------------------------