summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 10:22:38 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 10:22:38 -0500
commit68c0e7e16f2ea59bcfec2b8b83673da8eb9c3921 (patch)
tree124496f4afbe36fca9b283d447fb8d5497c6e102 /experiments
parent6cfb1f2825899630355cc9d0c0f0b81c64ce96b0 (diff)
Fix EP credit sign in cnn_baseline.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cnn_baseline.py3
1 files changed, 2 insertions, 1 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
index 75c3ff8..7cf9184 100644
--- a/experiments/cnn_baseline.py
+++ b/experiments/cnn_baseline.py
@@ -738,7 +738,8 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud
with torch.no_grad():
_, h_free = model(x, return_hidden=True)
h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)
- credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)]
+ # Negate: EP nudge moves h toward lower loss, opposite to BP grad direction
+ credits = [-flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)]
elif method in ('state_bridge', 'credit_bridge'):
with torch.no_grad():
logits, hiddens = model(x, return_hidden=True)