diff options
Diffstat (limited to 'src/trainers.py')
| -rw-r--r-- | src/trainers.py | 10 |
1 files changed, 5 insertions, 5 deletions
diff --git a/src/trainers.py b/src/trainers.py index 651dffc..7589fc9 100644 --- a/src/trainers.py +++ b/src/trainers.py @@ -1,5 +1,5 @@ """ -Training methods for Graph-GrAPE experiments. +Training methods for KAFT experiments. Generalized to L-layer residual GCN. Methods compared: @@ -7,7 +7,7 @@ Methods compared: DFA — Fixed random R, P=I DFA-GNN — Fixed random R, P=Â^{L-l} VanillaGrAPE — Aligned R (per layer), P=I - GraphGrAPE — Aligned R (per layer) + topology P=Â^{L-l} + KAFT — Aligned R (per layer) + topology P=Â^{L-l} """ import torch @@ -179,7 +179,7 @@ class BPTrainer: # --------------------------------------------------------------------------- class _FeedbackTrainerBase: - """Shared logic for DFA / GrAPE variants, generalized to L layers.""" + """Shared logic for DFA / KAFT variants, generalized to L layers.""" def __init__(self, data, hidden_dim, lr, weight_decay, diffusion_alpha, diffusion_iters, @@ -608,10 +608,10 @@ class VanillaGrAPETrainer(_FeedbackTrainerBase): # --------------------------------------------------------------------------- -# Graph-GrAPE Trainer +# KAFT Trainer # --------------------------------------------------------------------------- -class GraphGrAPETrainer(_FeedbackTrainerBase): +class KAFTTrainer(_FeedbackTrainerBase): """Aligned R per layer + topology P = Â^{min(L-l, max_power)}.""" def __init__(self, data, hidden_dim, lr, weight_decay, |
