summaryrefslogtreecommitdiff
path: root/experiments/run_grad_reach_20seeds.py
diff options
context:
space:
mode:
Diffstat (limited to 'experiments/run_grad_reach_20seeds.py')
-rw-r--r--experiments/run_grad_reach_20seeds.py22
1 files changed, 11 insertions, 11 deletions
diff --git a/experiments/run_grad_reach_20seeds.py b/experiments/run_grad_reach_20seeds.py
index b5ad53b..b9ac8b9 100644
--- a/experiments/run_grad_reach_20seeds.py
+++ b/experiments/run_grad_reach_20seeds.py
@@ -12,7 +12,7 @@ import json
import os
from scipy import stats as scipy_stats
from src.data import load_dataset, spmm
-from src.trainers import BPTrainer, GraphGrAPETrainer
+from src.trainers import BPTrainer, KAFTTrainer
device = 'cuda:0'
ALL_SEEDS = list(range(20))
@@ -31,7 +31,7 @@ def measure_one(data, L, backbone, seed):
torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
bp = BPTrainer(**common)
torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
- gr = GraphGrAPETrainer(**common, **grape_extra)
+ gr = KAFTTrainer(**common, **grape_extra)
gr.align_mode = 'chain_norm'
for _ in range(EPOCHS):
@@ -46,10 +46,10 @@ def measure_one(data, L, backbone, seed):
loss.backward()
bp_norms = [bp.weights[l].grad.norm().item() for l in range(L)]
- # GRAFT feedback norms
+ # KAFT feedback norms
Z_gr, inter = gr.forward()
E0, E_bar = gr._output_error(Z_gr)
- graft_norms = []
+ kaft_norms = []
for l in range(L - 1):
power = min(L - l, gr.max_topo_power)
topo_E = E_bar
@@ -57,13 +57,13 @@ def measure_one(data, L, backbone, seed):
topo_E = spmm(A, topo_E)
fb = topo_E @ gr.Rs[l]
relu_gate = (inter['Zs'][l].detach() > 0).float()
- graft_norms.append((relu_gate * fb).norm().item())
+ kaft_norms.append((relu_gate * fb).norm().item())
bp_acc = bp.evaluate('test_mask')
gr_acc = gr.evaluate('test_mask')
del bp, gr; torch.cuda.empty_cache()
- return bp_norms, graft_norms, bp_acc, gr_acc
+ return bp_norms, kaft_norms, bp_acc, gr_acc
def main():
@@ -101,10 +101,10 @@ def main():
bn, gn, ba, ga = measure_one(data, L, backbone, seed)
per_seed_data[key][seed_key] = {
- 'bp_norms': bn, 'graft_norms': gn,
+ 'bp_norms': bn, 'kaft_norms': gn,
'bp_acc': ba, 'gr_acc': ga
}
- print(f" seed {seed}: BP {ba*100:.1f}% GRAFT {ga*100:.1f}%", flush=True)
+ print(f" seed {seed}: BP {ba*100:.1f}% KAFT {ga*100:.1f}%", flush=True)
# Save incrementally
with open(old_per_seed_file, 'w') as f:
@@ -121,7 +121,7 @@ def main():
t_stat, p_val = scipy_stats.ttest_rel(gr_accs, bp_accs)
avg_bp_norms = np.mean([sd[str(s)]['bp_norms'] for s in ALL_SEEDS], axis=0)
- avg_gr_norms = np.mean([sd[str(s)]['graft_norms'] for s in ALL_SEEDS], axis=0)
+ avg_gr_norms = np.mean([sd[str(s)]['kaft_norms'] for s in ALL_SEEDS], axis=0)
results[key] = {
'bp_acc_mean': float(bp_accs.mean()),
@@ -142,10 +142,10 @@ def main():
sig = '***' if p_val < 0.001 else ('**' if p_val < 0.01 else ('*' if p_val < 0.05 else 'ns'))
print(f"\n {key}:")
print(f" BP: {bp_accs.mean():.1f} ± {bp_accs.std():.1f}%")
- print(f" GRAFT: {gr_accs.mean():.1f} ± {gr_accs.std():.1f}%")
+ print(f" KAFT: {gr_accs.mean():.1f} ± {gr_accs.std():.1f}%")
print(f" Δ: {(gr_accs-bp_accs).mean():+.1f} ± {(gr_accs-bp_accs).std():.1f}% t={t_stat:.2f} p={p_val:.4f} {sig}")
print(f" BP norm L0: {avg_bp_norms[0]:.6f}")
- print(f" GRAFT norm L0: {avg_gr_norms[0]:.4f}")
+ print(f" KAFT norm L0: {avg_gr_norms[0]:.4f}")
with open(os.path.join(OUT_DIR, 'results.json'), 'w') as f:
json.dump(results, f, indent=2)