summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 17:27:08 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-03 17:27:08 -0500
commit52b421fde3faa673e7007a456846f8195cb45942 (patch)
treed60eba25ed0ca0880e3753def92bf7cd742205b2 /experiments
parent58e859d83c77002d22571003075150d7e20d18a4 (diff)
Fix CNN compute_bp_grads: remove inter-layer detach so gradients flow to all layers
Old code detached hidden states between layers, making layers 0-2 disconnected from the loss (gradient = None → 0). Fixed by keeping the forward graph connected. BP CNN Gamma per-layer now: [0.985, 0.990, 0.987, 0.967] (was [0, 0, 0, 0.967]) But gradient norms are ~1e-17 (genuine numerical precision issue with CNN architecture). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cnn_baseline.py9
1 files changed, 6 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
index 7cf9184..84f415e 100644
--- a/experiments/cnn_baseline.py
+++ b/experiments/cnn_baseline.py
@@ -697,17 +697,20 @@ def compute_bp_grads(model, x, y):
# Re-run forward with requires_grad on intermediate activations
# We build the forward manually to hook into each h_l
+ # Build forward graph keeping all h[l] connected so gradients flow through
h = [None] * L
inp = x
for l in range(L):
if l == 3:
inp = inp.flatten(1) if inp.dim() > 2 else inp
- h[l] = model.blocks[l](inp.detach().requires_grad_(False))
- h[l] = h[l].detach().requires_grad_(True)
- inp = h[l]
+ out = model.blocks[l](inp)
+ h[l] = out # keep in graph (no detach between layers)
+ inp = out
logits = model.out_head(h[3])
loss = F.cross_entropy(logits, y)
+
+ # Compute gradient w.r.t. all hidden states
gs = torch.autograd.grad(loss, h, allow_unused=True)
return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h