diff options
Diffstat (limited to 'experiments/resmlp_frozen_blocks_baseline.py')
| -rw-r--r-- | experiments/resmlp_frozen_blocks_baseline.py | 14 |
1 files changed, 8 insertions, 6 deletions
diff --git a/experiments/resmlp_frozen_blocks_baseline.py b/experiments/resmlp_frozen_blocks_baseline.py index c330be2..6040bd4 100644 --- a/experiments/resmlp_frozen_blocks_baseline.py +++ b/experiments/resmlp_frozen_blocks_baseline.py @@ -137,6 +137,7 @@ def main(): parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--num_blocks', type=int, default=4) args = parser.parse_args() dev = torch.device('cuda:0') @@ -156,10 +157,11 @@ def main(): results['bp_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL BP-shallow: {results['bp_shallow']:.4f}", flush=True) - # Condition 2: BP frozen-blocks (num_blocks=4 frozen) - print(f"\n=== BP frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + # Condition 2: BP frozen-blocks (blocks frozen at random init) + L = args.num_blocks + print(f"\n=== BP frozen-blocks (ResMLP num_blocks={L}, blocks frozen), seed={args.seed} ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) train_bp(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'BP-frozen') @@ -175,10 +177,10 @@ def main(): results['dfa_shallow'] = evaluate(m, test_loader, dev) print(f"FINAL DFA-shallow: {results['dfa_shallow']:.4f}", flush=True) - # Condition 4: DFA frozen-blocks (num_blocks=4 frozen) - print(f"\n=== DFA frozen-blocks (ResMLP num_blocks=4, blocks frozen), seed={args.seed} ===", flush=True) + # Condition 4: DFA frozen-blocks (blocks frozen at random init) + print(f"\n=== DFA frozen-blocks (ResMLP num_blocks={L}, blocks frozen), seed={args.seed} ===", flush=True) torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed) - m = ResidualMLP(input_dim, args.d_hidden, C, 4).to(dev) + m = ResidualMLP(input_dim, args.d_hidden, C, L).to(dev) freeze_blocks(m) print(f" n_params: {sum(p.numel() for p in m.parameters())} ({sum(p.numel() for p in m.parameters() if p.requires_grad)} trainable)", flush=True) train_dfa(m, train_loader, test_loader, dev, args.epochs, args.lr, args.wd, 'DFA-frozen') |
