summaryrefslogtreecommitdiff
path: root/experiments/run_resgcn_20seeds.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/run_resgcn_20seeds.py')
-rw-r--r--experiments/run_resgcn_20seeds.py20
1 files changed, 10 insertions, 10 deletions
diff --git a/experiments/run_resgcn_20seeds.py b/experiments/run_resgcn_20seeds.py
index 995a568..b56bb04 100644
--- a/experiments/run_resgcn_20seeds.py
+++ b/experiments/run_resgcn_20seeds.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python3
-"""Task 7016bd94 Part 1: ResGCN vs GRAFT, 20 seeds, paired t-tests."""
+"""Task 7016bd94 Part 1: ResGCN vs KAFT, 20 seeds, paired t-tests."""
import torch
import numpy as np
@@ -7,7 +7,7 @@ import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset
-from src.trainers import BPTrainer, GraphGrAPETrainer
+from src.trainers import BPTrainer, KAFTTrainer
from run_deep_baselines import ResGCNTrainer
from run_dblp_depth import load_dblp
@@ -48,7 +48,7 @@ def main():
METHODS = {
'BP': (BPTrainer, {}),
'ResGCN': (ResGCNTrainer, {}),
- 'GRAFT': (GraphGrAPETrainer, grape_extra),
+ 'KAFT': (KAFTTrainer, grape_extra),
}
datasets_cfg = {
@@ -92,9 +92,9 @@ def main():
del data; torch.cuda.empty_cache()
- # Paired t-tests: GRAFT vs ResGCN
+ # Paired t-tests: KAFT vs ResGCN
print("\n" + "=" * 70)
- print("Paired t-tests: GRAFT vs ResGCN (20 seeds)")
+ print("Paired t-tests: KAFT vs ResGCN (20 seeds)")
print("-" * 70)
for ds in ['Cora', 'CiteSeer', 'DBLP']:
@@ -102,7 +102,7 @@ def main():
res_accs = np.array(results[f"{ds}_ResGCN"]['accs'])
gr_accs = np.array(results[f"{ds}_GRAFT"]['accs'])
- # GRAFT vs ResGCN
+ # KAFT vs ResGCN
t_stat, p_val = scipy_stats.ttest_rel(gr_accs, res_accs)
delta = gr_accs.mean() - res_accs.mean()
sig = '***' if p_val < 0.001 else ('**' if p_val < 0.01 else ('*' if p_val < 0.05 else 'ns'))
@@ -112,7 +112,7 @@ def main():
'p_value': float(p_val), 'significant': bool(p_val < 0.05),
}
- # GRAFT vs BP
+ # KAFT vs BP
t2, p2 = scipy_stats.ttest_rel(gr_accs, bp_accs)
d2 = gr_accs.mean() - bp_accs.mean()
sig2 = '***' if p2 < 0.001 else ('**' if p2 < 0.01 else ('*' if p2 < 0.05 else 'ns'))
@@ -134,9 +134,9 @@ def main():
print(f"\n{ds}:")
print(f" BP: {bp_accs.mean():.1f} ± {bp_accs.std():.1f}")
print(f" ResGCN: {res_accs.mean():.1f} ± {res_accs.std():.1f}")
- print(f" GRAFT: {gr_accs.mean():.1f} ± {gr_accs.std():.1f}")
- print(f" GRAFT vs ResGCN: Δ{delta:+.1f}% p={p_val:.6f} {sig}")
- print(f" GRAFT vs BP: Δ{d2:+.1f}% p={p2:.6f} {sig2}")
+ print(f" KAFT: {gr_accs.mean():.1f} ± {gr_accs.std():.1f}")
+ print(f" KAFT vs ResGCN: Δ{delta:+.1f}% p={p_val:.6f} {sig}")
+ print(f" KAFT vs BP: Δ{d2:+.1f}% p={p2:.6f} {sig2}")
with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
json.dump(results, f, indent=2)