summaryrefslogtreecommitdiff
path: root/experiments/run_ablation_20seeds.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/run_ablation_20seeds.py')
-rw-r--r--experiments/run_ablation_20seeds.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/experiments/run_ablation_20seeds.py b/experiments/run_ablation_20seeds.py
index 61055ed..d2bd434 100644
--- a/experiments/run_ablation_20seeds.py
+++ b/experiments/run_ablation_20seeds.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → GRAFT."""
+"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → KAFT."""
import torch
import numpy as np
@@ -7,7 +7,7 @@ import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset
-from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, GraphGrAPETrainer
+from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, KAFTTrainer
device = 'cuda:0'
SEEDS = list(range(20))
@@ -22,7 +22,7 @@ METHODS = {
'diffusion_alpha': 0.5, 'diffusion_iters': 10,
'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A'
}),
- 'GRAFT': (GraphGrAPETrainer, {
+ 'KAFT': (KAFTTrainer, {
'diffusion_alpha': 0.5, 'diffusion_iters': 10,
'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A'
}),