summaryrefslogtreecommitdiff
path: root/experiments/run_dblp_depth.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/run_dblp_depth.py')
-rw-r--r--experiments/run_dblp_depth.py8
1 files changed, 3 insertions, 5 deletions
diff --git a/experiments/run_dblp_depth.py b/experiments/run_dblp_depth.py
index d63b94a..f91440a 100644
--- a/experiments/run_dblp_depth.py
+++ b/experiments/run_dblp_depth.py
@@ -11,8 +11,7 @@ import os
import time
from torch_geometric.datasets import CitationFull
from src.data import build_normalized_adj, build_row_normalized_adj, spmm, precompute_traces
-from src.trainers import BPTrainer, DFATrainer, GraphGrAPETrainer
-from benchmark_efficient import GraphGrAPEEfficient
+from src.trainers import BPTrainer, DFATrainer, KAFTTrainer
device = 'cuda:0'
SEEDS = [0, 1, 2, 3, 4]
@@ -108,7 +107,7 @@ def main():
row = {}
for mname, cls, extra in [('BP', BPTrainer, {}),
('DFA', DFATrainer, dict(diffusion_alpha=0.5, diffusion_iters=10)),
- ('GrAPE', GraphGrAPETrainer, grape_extra)]:
+ ('GrAPE', KAFTTrainer, grape_extra)]:
accs = [train_one(cls, common, extra, s) for s in SEEDS]
row[mname] = {'mean': float(np.mean(accs)), 'std': float(np.std(accs))}
results[key] = row
@@ -123,7 +122,6 @@ def main():
common = dict(data=dblp, hidden_dim=64, lr=0.01, weight_decay=5e-4,
num_layers=L, residual_alpha=0.0, backbone=bb)
bp_ms = time_method(BPTrainer, common, {})
- eff_ms = time_method(GraphGrAPEEfficient, common,
dict(lr_feedback=0.5, num_probes=64, max_topo_power=3,
diff_alpha=0.5, align_every=10))
key = f"DBLP_eff|{bb}|L={L}"
@@ -146,7 +144,7 @@ def main():
key = f"{ds_name}|{bb}|L={L}|lr=0.01"
row = {}
for mname, cls, extra in [('BP', BPTrainer, {}),
- ('GrAPE', GraphGrAPETrainer, grape_extra)]:
+ ('GrAPE', KAFTTrainer, grape_extra)]:
accs = [train_one(cls, common, extra, s) for s in SEEDS]
row[mname] = {'mean': float(np.mean(accs)), 'std': float(np.std(accs))}
results[key] = row