diff options
Diffstat (limited to 'experiments/run_ablation_20seeds.py')
| -rw-r--r-- | experiments/run_ablation_20seeds.py | 6 |
1 files changed, 3 insertions, 3 deletions
diff --git a/experiments/run_ablation_20seeds.py b/experiments/run_ablation_20seeds.py index 61055ed..d2bd434 100644 --- a/experiments/run_ablation_20seeds.py +++ b/experiments/run_ablation_20seeds.py @@ -1,5 +1,5 @@ #!/usr/bin/env python3 -"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → GRAFT.""" +"""Ablation study with 20 seeds: BP → DFA → DFA-GNN → VanillaGrAPE → KAFT.""" import torch import numpy as np @@ -7,7 +7,7 @@ import json import os from scipy import stats as scipy_stats from src.data import load_dataset -from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, GraphGrAPETrainer +from src.trainers import BPTrainer, DFATrainer, DFAGNNTrainer, VanillaGrAPETrainer, KAFTTrainer device = 'cuda:0' SEEDS = list(range(20)) @@ -22,7 +22,7 @@ METHODS = { 'diffusion_alpha': 0.5, 'diffusion_iters': 10, 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A' }), - 'GRAFT': (GraphGrAPETrainer, { + 'KAFT': (KAFTTrainer, { 'diffusion_alpha': 0.5, 'diffusion_iters': 10, 'lr_feedback': 0.5, 'num_probes': 64, 'topo_mode': 'fixed_A' }), |
