diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 06:37:23 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 06:37:23 -0500 |
| commit | b4d276f6a4b20c7766e0bceb687e42ecd4869fef (patch) | |
| tree | 8f5452bb6e38174a519abb2168541d31c6a45604 | |
| parent | afc2821acceb11d50b74d68584b1bf8378adc9c7 (diff) | |
Round 38: add --penalty_lam flag to cifar_resmlp.py for Mode 2 cross-method test
Patches:
- main(): add --penalty_lam (separate from CB's bridge temperature args.lam)
- train_dfa block update (line 195): add penalty_lam * (f_l**2).sum(-1).mean()
- train_state_bridge block update (line 326): same penalty
- train_credit_bridge block update (line 533): same penalty
Codex round 38 GO STAGE: keep penalty separate from CB lam, blocks-only,
sanity-check that hidden_norms remain nontrivial (not silencing the blocks).
2-epoch smoke (results/round38_smoke_sbcb_pen) passes the silencing check:
SB ||h_L||=229, CB ||h_L||=1258, both nontrivial. Deep cosines positive across
all layers for SB ([0.28, 0.25, 0.23]) and rising for CB ([0.04, 0.08, 0.13, 0.15]).
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
| -rw-r--r-- | experiments/cifar_resmlp.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 4324e9e..7aba671 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -193,6 +193,8 @@ def train_dfa(model, train_loader, test_loader, device, args): # Local surrogate f_l = model.blocks[l](h_l) local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -322,6 +324,8 @@ def train_state_bridge(model, train_loader, test_loader, device, args): a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -527,6 +531,8 @@ def train_credit_bridge(model, train_loader, test_loader, device, args): a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -791,6 +797,9 @@ def main(): help='Subset of methods to run.') parser.add_argument('--random_targets', action='store_true', help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).') + parser.add_argument('--penalty_lam', type=float, default=0.0, + help='Per-block residual-branch penalty strength: add penalty_lam * mean(||f_l(h_l)||^2) ' + 'to each block local loss for DFA/SB/CB. Codex round 38 Mode 2 cross-method test.') args = parser.parse_args() run_experiment(args) |
