From 68c0e7e16f2ea59bcfec2b8b83673da8eb9c3921 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Fri, 3 Apr 2026 10:22:38 -0500 Subject: Fix EP credit sign in cnn_baseline.py Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/cnn_baseline.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) (limited to 'experiments/cnn_baseline.py') diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index 75c3ff8..7cf9184 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -738,7 +738,8 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud with torch.no_grad(): _, h_free = model(x, return_hidden=True) h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge) - credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] + # Negate: EP nudge moves h toward lower loss, opposite to BP grad direction + credits = [-flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] elif method in ('state_bridge', 'credit_bridge'): with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) -- cgit v1.2.3