diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 17:27:08 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-03 17:27:08 -0500 |
| commit | 52b421fde3faa673e7007a456846f8195cb45942 (patch) | |
| tree | d60eba25ed0ca0880e3753def92bf7cd742205b2 /experiments | |
| parent | 58e859d83c77002d22571003075150d7e20d18a4 (diff) | |
Fix CNN compute_bp_grads: remove inter-layer detach so gradients flow to all layers
Old code detached hidden states between layers, making layers 0-2 disconnected
from the loss (gradient = None → 0). Fixed by keeping the forward graph connected.
BP CNN Gamma per-layer now: [0.985, 0.990, 0.987, 0.967] (was [0, 0, 0, 0.967])
But gradient norms are ~1e-17 (genuine numerical precision issue with CNN architecture).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cnn_baseline.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index 7cf9184..84f415e 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -697,17 +697,20 @@ def compute_bp_grads(model, x, y): # Re-run forward with requires_grad on intermediate activations # We build the forward manually to hook into each h_l + # Build forward graph keeping all h[l] connected so gradients flow through h = [None] * L inp = x for l in range(L): if l == 3: inp = inp.flatten(1) if inp.dim() > 2 else inp - h[l] = model.blocks[l](inp.detach().requires_grad_(False)) - h[l] = h[l].detach().requires_grad_(True) - inp = h[l] + out = model.blocks[l](inp) + h[l] = out # keep in graph (no detach between layers) + inp = out logits = model.out_head(h[3]) loss = F.cross_entropy(logits, y) + + # Compute gradient w.r.t. all hidden states gs = torch.autograd.grad(loss, h, allow_unused=True) return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h |
