From b4d276f6a4b20c7766e0bceb687e42ecd4869fef Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Wed, 8 Apr 2026 06:37:23 -0500 Subject: Round 38: add --penalty_lam flag to cifar_resmlp.py for Mode 2 cross-method test Patches: - main(): add --penalty_lam (separate from CB's bridge temperature args.lam) - train_dfa block update (line 195): add penalty_lam * (f_l**2).sum(-1).mean() - train_state_bridge block update (line 326): same penalty - train_credit_bridge block update (line 533): same penalty Codex round 38 GO STAGE: keep penalty separate from CB lam, blocks-only, sanity-check that hidden_norms remain nontrivial (not silencing the blocks). 2-epoch smoke (results/round38_smoke_sbcb_pen) passes the silencing check: SB ||h_L||=229, CB ||h_L||=1258, both nontrivial. Deep cosines positive across all layers for SB ([0.28, 0.25, 0.23]) and rising for CB ([0.04, 0.08, 0.13, 0.15]). Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/cifar_resmlp.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 4324e9e..7aba671 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -193,6 +193,8 @@ def train_dfa(model, train_loader, test_loader, device, args): # Local surrogate f_l = model.blocks[l](h_l) local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -322,6 +324,8 @@ def train_state_bridge(model, train_loader, test_loader, device, args): a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -527,6 +531,8 @@ def train_credit_bridge(model, train_loader, test_loader, device, args): a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -791,6 +797,9 @@ def main(): help='Subset of methods to run.') parser.add_argument('--random_targets', action='store_true', help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).') + parser.add_argument('--penalty_lam', type=float, default=0.0, + help='Per-block residual-branch penalty strength: add penalty_lam * mean(||f_l(h_l)||^2) ' + 'to each block local loss for DFA/SB/CB. Codex round 38 Mode 2 cross-method test.') args = parser.parse_args() run_experiment(args) -- cgit v1.2.3