diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 10:22:38 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 10:22:38 -0500 |
| commit | 68c0e7e16f2ea59bcfec2b8b83673da8eb9c3921 (patch) | |
| tree | 124496f4afbe36fca9b283d447fb8d5497c6e102 /experiments | |
| parent | 6cfb1f2825899630355cc9d0c0f0b81c64ce96b0 (diff) | |
Fix EP credit sign in cnn_baseline.py
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cnn_baseline.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index 75c3ff8..7cf9184 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -738,7 +738,8 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud with torch.no_grad(): _, h_free = model(x, return_hidden=True) h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge) - credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] + # Negate: EP nudge moves h toward lower loss, opposite to BP grad direction + credits = [-flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] elif method in ('state_bridge', 'credit_bridge'): with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) |
