From 52b421fde3faa673e7007a456846f8195cb45942 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Fri, 3 Apr 2026 17:27:08 -0500 Subject: Fix CNN compute_bp_grads: remove inter-layer detach so gradients flow to all layers MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Old code detached hidden states between layers, making layers 0-2 disconnected from the loss (gradient = None → 0). Fixed by keeping the forward graph connected. BP CNN Gamma per-layer now: [0.985, 0.990, 0.987, 0.967] (was [0, 0, 0, 0.967]) But gradient norms are ~1e-17 (genuine numerical precision issue with CNN architecture). Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/cnn_baseline.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'experiments/cnn_baseline.py') diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index 7cf9184..84f415e 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -697,17 +697,20 @@ def compute_bp_grads(model, x, y): # Re-run forward with requires_grad on intermediate activations # We build the forward manually to hook into each h_l + # Build forward graph keeping all h[l] connected so gradients flow through h = [None] * L inp = x for l in range(L): if l == 3: inp = inp.flatten(1) if inp.dim() > 2 else inp - h[l] = model.blocks[l](inp.detach().requires_grad_(False)) - h[l] = h[l].detach().requires_grad_(True) - inp = h[l] + out = model.blocks[l](inp) + h[l] = out # keep in graph (no detach between layers) + inp = out logits = model.out_head(h[3]) loss = F.cross_entropy(logits, y) + + # Compute gradient w.r.t. all hidden states gs = torch.autograd.grad(loss, h, allow_unused=True) return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h -- cgit v1.2.3