diff options
Diffstat (limited to 'experiments/cnn_baseline.py')
| -rw-r--r-- | experiments/cnn_baseline.py | 3 |
1 files changed, 2 insertions, 1 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index 75c3ff8..7cf9184 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -738,7 +738,8 @@ def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nud with torch.no_grad(): _, h_free = model(x, return_hidden=True) h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge) - credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] + # Negate: EP nudge moves h toward lower loss, opposite to BP grad direction + credits = [-flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)] elif method in ('state_bridge', 'credit_bridge'): with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) |
