diff options
| -rw-r--r-- | experiments/cifar_resmlp.py | 9 |
1 files changed, 9 insertions, 0 deletions
diff --git a/experiments/cifar_resmlp.py b/experiments/cifar_resmlp.py index 4324e9e..7aba671 100644 --- a/experiments/cifar_resmlp.py +++ b/experiments/cifar_resmlp.py @@ -193,6 +193,8 @@ def train_dfa(model, train_loader, test_loader, device, args): # Local surrogate f_l = model.blocks[l](h_l) local_loss = (f_l * a_dfa_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -322,6 +324,8 @@ def train_state_bridge(model, train_loader, test_loader, device, args): a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -527,6 +531,8 @@ def train_credit_bridge(model, train_loader, test_loader, device, args): a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() + if getattr(args, 'penalty_lam', 0.0) > 0.0: + local_loss = local_loss + args.penalty_lam * (f_l ** 2).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() block_opts[l].step() @@ -791,6 +797,9 @@ def main(): help='Subset of methods to run.') parser.add_argument('--random_targets', action='store_true', help='Replace each minibatch label with i.i.d. random class targets (Mode 1 data-agnostic test).') + parser.add_argument('--penalty_lam', type=float, default=0.0, + help='Per-block residual-branch penalty strength: add penalty_lam * mean(||f_l(h_l)||^2) ' + 'to each block local loss for DFA/SB/CB. Codex round 38 Mode 2 cross-method test.') args = parser.parse_args() run_experiment(args) |
