diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:12:29 -0500 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-05-04 23:12:29 -0500 |
| commit | 174083f4e340afc192d57aacc6422c99530e9e59 (patch) | |
| tree | 28932cc5d17ed2869dfe33b470f30cd602555edf /experiments/run_ablation_20seeds.py | |
| parent | ba6ead6d7a41b7ed78bb228181b7262d0c75d2eb (diff) | |
- src/trainers.py: removed class VanillaGrAPETrainer (~30 lines) and
cleaned the module-level methods-compared docstring.
- experiments/run_ablation_20seeds.py: dropped VanillaGrAPE row from the
4-method ablation grid; sweep is now BP → DFA → DFA-GNN → KAFT.
Smoke test: BPTrainer / DFATrainer / DFAGNNTrainer / KAFTTrainer all train
cleanly at GCN L=4, Cora, 50 epochs (test acc 77.3 / 76.6 / 78.4 / 79.0).
Diffstat (limited to 'experiments/run_ablation_20seeds.py')
| -rw-r--r-- | experiments/run_ablation_20seeds.py | 8 |
1 files changed, 2 insertions, 6 deletions
diff --git a/experiments/run_ablation_20seeds.py b/experiments/run_ablation_20seeds.py index d2bd434..e90917d 100644 --- a/experiments/run_ablation_20seeds.py +++ b/experiments/run_ablation_20seeds.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → KAFT.""" +"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → KAFT.""" import torch import numpy as np @@ -7,7 +7,7 @@ import json import os from scipy import stats as scipy_stats from src.data import load_dataset -from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, KAFTTrainer +from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, KAFTTrainer device = 'cuda:0' SEEDS = list(range(20)) @@ -18,10 +18,6 @@ METHODS = { 'BP': (BPTrainer, {}), 'DFA': (DFATrainer, {}), 'DFA-GNN': (DFAGNNTrainer, {'topo_mode': 'fixed_A'}), - 'VanillaGrAPE': (VanillaGrAPETrainer, { - 'diffusion_alpha': 0.5, 'diffusion_iters': 10, - 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A' - }), 'KAFT': (KAFTTrainer, { 'diffusion_alpha': 0.5, 'diffusion_iters': 10, 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A' |
