diff options
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cnn_baseline.py | 9 |
1 files changed, 6 insertions, 3 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py index 7cf9184..84f415e 100644 --- a/experiments/cnn_baseline.py +++ b/experiments/cnn_baseline.py @@ -697,17 +697,20 @@ def compute_bp_grads(model, x, y): # Re-run forward with requires_grad on intermediate activations # We build the forward manually to hook into each h_l + # Build forward graph keeping all h[l] connected so gradients flow through h = [None] * L inp = x for l in range(L): if l == 3: inp = inp.flatten(1) if inp.dim() > 2 else inp - h[l] = model.blocks[l](inp.detach().requires_grad_(False)) - h[l] = h[l].detach().requires_grad_(True) - inp = h[l] + out = model.blocks[l](inp) + h[l] = out # keep in graph (no detach between layers) + inp = out logits = model.out_head(h[3]) loss = F.cross_entropy(logits, y) + + # Compute gradient w.r.t. all hidden states gs = torch.autograd.grad(loss, h, allow_unused=True) return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h |
