summaryrefslogtreecommitdiff
path: root/src/trainers.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/trainers.py')
-rw-r--r--src/trainers.py10
1 file changed, 5 insertions, 5 deletions
diff --git a/src/trainers.py b/src/trainers.py
index 651dffc..7589fc9 100644
--- a/src/trainers.py
+++ b/src/trainers.py
@@ -1,5 +1,5 @@
"""
-Training methods for Graph-GrAPE experiments.
+Training methods for KAFT experiments.
Generalized to L-layer residual GCN.
Methods compared:
@@ -7,7 +7,7 @@ Methods compared:
DFA — Fixed random R, P=I
DFA-GNN — Fixed random R, P=Â^{L-l}
VanillaGrAPE — Aligned R (per layer), P=I
- GraphGrAPE — Aligned R (per layer) + topology P=Â^{L-l}
+ KAFT — Aligned R (per layer) + topology P=Â^{L-l}
"""
import torch
@@ -179,7 +179,7 @@ class BPTrainer:
# ---------------------------------------------------------------------------
class _FeedbackTrainerBase:
- """Shared logic for DFA / GrAPE variants, generalized to L layers."""
+ """Shared logic for DFA / KAFT variants, generalized to L layers."""
def __init__(self, data, hidden_dim, lr, weight_decay,
diffusion_alpha, diffusion_iters,
@@ -608,10 +608,10 @@ class VanillaGrAPETrainer(_FeedbackTrainerBase):
# ---------------------------------------------------------------------------
-# Graph-GrAPE Trainer
+# KAFT Trainer
# ---------------------------------------------------------------------------
-class GraphGrAPETrainer(_FeedbackTrainerBase):
+class KAFTTrainer(_FeedbackTrainerBase):
"""Aligned R per layer + topology P = Â^{min(L-l, max_power)}."""
def __init__(self, data, hidden_dim, lr, weight_decay,