summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--experiments/cnn_baseline.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
index 7cf9184..84f415e 100644
--- a/experiments/cnn_baseline.py
+++ b/experiments/cnn_baseline.py
@@ -697,17 +697,20 @@ def compute_bp_grads(model, x, y):
# Re-run forward with requires_grad on intermediate activations
# We build the forward manually to hook into each h_l
+ # Build forward graph keeping all h[l] connected so gradients flow through
h = [None] * L
inp = x
for l in range(L):
if l == 3:
inp = inp.flatten(1) if inp.dim() > 2 else inp
- h[l] = model.blocks[l](inp.detach().requires_grad_(False))
- h[l] = h[l].detach().requires_grad_(True)
- inp = h[l]
+ out = model.blocks[l](inp)
+ h[l] = out # keep in graph (no detach between layers)
+ inp = out
logits = model.out_head(h[3])
loss = F.cross_entropy(logits, y)
+
+ # Compute gradient w.r.t. all hidden states
gs = torch.autograd.grad(loss, h, allow_unused=True)
return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h