summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 06:37:23 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 06:37:23 -0500
commitb4d276f6a4b20c7766e0bceb687e42ecd4869fef (patch)
tree8f5452bb6e38174a519abb2168541d31c6a45604 /experiments
parentafc2821acceb11d50b74d68584b1bf8378adc9c7 (diff)
Round 38: add --penalty_lam flag to cifar_resmlp.py for Mode 2 cross-method test
Patches: - main(): add --penalty_lam (separate from CB's bridge temperature args.lam) - train_dfa block update (line 195): add penalty_lam * (f_l**2).sum(-1).mean() - train_state_bridge block update (line 326): same penalty - train_credit_bridge block update (line 533): same penalty Codex round 38 GO STAGE: keep penalty separate from CB lam, blocks-only, sanity-check that hidden_norms remain nontrivial (not silencing the blocks). 2-epoch smoke (results/round38_smoke_sbcb_pen) passes the silencing check: SB ||h_L||=229, CB ||h_L||=1258, both nontrivial. Deep cosines positive across all layers for SB ([0.28, 0.25, 0.23]) and rising for CB ([0.04, 0.08, 0.13, 0.15]). Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cifar_resmlp.py9
1 files changed, 9 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py
index 4324e9e..7aba671 100644
--- a/experiments/cifar_resmlp.py
+++ b/experiments/cifar_resmlp.py
@@ -193,6 +193,8 @@ def train_dfa(model, train_loader, test_loader, device, args):
# Local surrogate
f_l = model.blocks[l](h_l)
local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean()
+ if getattr(args, 'penalty_lam', 0.0) > 0.0:
+ local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean()
block_opts[l].zero_grad()
local_loss.backward()
block_opts[l].step()
@@ -322,6 +324,8 @@ def train_state_bridge(model, train_loader, test_loader, device, args):
a_norm = a / rms
f_l = model.blocks[l](h_l)
local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ if getattr(args, 'penalty_lam', 0.0) > 0.0:
+ local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean()
block_opts[l].zero_grad()
local_loss.backward()
block_opts[l].step()
@@ -527,6 +531,8 @@ def train_credit_bridge(model, train_loader, test_loader, device, args):
a_norm = a / rms
f_l = model.blocks[l](h_l)
local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ if getattr(args, 'penalty_lam', 0.0) > 0.0:
+ local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean()
block_opts[l].zero_grad()
local_loss.backward()
block_opts[l].step()
@@ -791,6 +797,9 @@ def main():
help='Subset of methods to run.')
parser.add_argument('--random_targets', action='store_true',
help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).')
+ parser.add_argument('--penalty_lam', type=float, default=0.0,
+ help='Per-block residual-branch penalty strength: add penalty_lam * mean(||f_l(h_l)||^2) '
+ 'to each block local loss for DFA/SB/CB. Codex round 38 Mode 2 cross-method test.')
args = parser.parse_args()
run_experiment(args)