From 174083f4e340afc192d57aacc6422c99530e9e59 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 4 May 2026 23:12:29 -0500 Subject: =?UTF-8?q?Drop=20VanillaGrAPE=20=E2=80=94=20control=20method=20no?= =?UTF-8?q?t=20exposed=20in=20release?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - src/trainers.py: removed class VanillaGrAPETrainer (~30 lines) and cleaned the module-level methods-compared docstring. - experiments/run_ablation_20seeds.py: dropped VanillaGrAPE row from the 4-method ablation grid; sweep is now BP → DFA → DFA-GNN → KAFT. Smoke test: BPTrainer / DFATrainer / DFAGNNTrainer / KAFTTrainer all train cleanly at GCN L=4, Cora, 50 epochs (test acc 77.3 / 76.6 / 78.4 / 79.0). --- experiments/run_ablation_20seeds.py | 8 ++----- src/trainers.py | 42 ++++--------------------------------- 2 files changed, 6 insertions(+), 44 deletions(-) diff --git a/experiments/run_ablation_20seeds.py b/experiments/run_ablation_20seeds.py index d2bd434..e90917d 100644 --- a/experiments/run_ablation_20seeds.py +++ b/experiments/run_ablation_20seeds.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → KAFT.""" +"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → KAFT.""" import torch import numpy as np @@ -7,7 +7,7 @@ import json import os from scipy import stats as scipy_stats from src.data import load_dataset -from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, KAFTTrainer +from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, KAFTTrainer device = 'cuda:0' SEEDS = list(range(20)) @@ -18,10 +18,6 @@ METHODS = { 'BP': (BPTrainer, {}), 'DFA': (DFATrainer, {}), 'DFA-GNN': (DFAGNNTrainer, {'topo_mode': 'fixed_A'}), - 'VanillaGrAPE': (VanillaGrAPETrainer, { - 'diffusion_alpha': 0.5, 'diffusion_iters': 10, - 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A' - }), 'KAFT': (KAFTTrainer, { 'diffusion_alpha': 0.5, 'diffusion_iters': 10, 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A' diff --git a/src/trainers.py b/src/trainers.py index 7589fc9..9da6a6a 100644 --- a/src/trainers.py +++ b/src/trainers.py @@ -3,11 +3,10 @@ Training methods for KAFT experiments. Generalized to L-layer residual GCN. Methods compared: - BP — Standard backprop GCN - DFA — Fixed random R, P=I - DFA-GNN — Fixed random R, P=Â^{L-l} - VanillaGrAPE — Aligned R (per layer), P=I - KAFT — Aligned R (per layer) + topology P=Â^{L-l} + BP — Standard backprop GCN + DFA — Fixed random R, P=I + DFA-GNN — Fixed random R, P=Â^{L-l} + KAFT — Learned R (per layer) + topology P=Â^{min(L-1-l, K)} """ import torch @@ -574,39 +573,6 @@ class DFAGNNTrainer(_FeedbackTrainerBase): return out @ self.R_fixed -# --------------------------------------------------------------------------- -# Vanilla GrAPE Trainer -# --------------------------------------------------------------------------- - -class VanillaGrAPETrainer(_FeedbackTrainerBase): - """Aligned R per layer, no topology (P=I).""" - - def __init__(self, data, hidden_dim, lr, weight_decay, - lr_feedback=0.5, num_probes=64, - diffusion_alpha=0.5, diffusion_iters=10, - num_layers=2, residual_alpha=0.0, backbone='gcn', **_kw): - super().__init__(data, hidden_dim, lr, weight_decay, - diffusion_alpha, diffusion_iters, - num_layers, residual_alpha, backbone, - _kw.get('use_batchnorm', False), _kw.get('dropout', 0.0)) - self.lr_fb = lr_feedback - self.num_probes = num_probes - # One R per hidden layer - self.Rs = [torch.randn(self.d_out, hidden_dim, device=self.device) * 0.01 - for _ in range(num_layers - 1)] - - def _alignment_step(self, inter): - metrics = {} - for l in range(self.num_layers - 1): - cos = _align_R_layer(self, l) - metrics[f'cos_feat_L{l}'] = cos - metrics['cos_feat'] = sum(metrics.values()) / len(metrics) - return metrics - - def _compute_hidden_feedback(self, l, inter, E_bar): - return E_bar @ self.Rs[l] - - # --------------------------------------------------------------------------- # KAFT Trainer # --------------------------------------------------------------------------- -- cgit v1.2.3