diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-31 22:18:32 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-31 22:18:32 -0500 |
| commit | 8b78fbb1308d31bedd74f4b1deb250f5e684a6d3 (patch) | |
| tree | ce24ce38c3e6425d32f5873667ea3169d7d970bb /experiments | |
| parent | 0eddc70d8c89adb2ae7105b7d3be813310fd1b80 (diff) | |
Update naive StateErr v3: L2 norm ratio formula, with checkpoints saved
Formula: ||h_{L//2} - h_L||_2 / ||h_L||_2 (scalar L2 ratio)
A1: 240 rows (3 alpha × 2 depth × 4 methods × 10 seeds)
A2: 40 rows (4 methods including BP × 10 seeds)
All model checkpoints saved in checkpoints_A1/ and checkpoints_A2/
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/compute_naive_state_err.py | 45 |
1 files changed, 36 insertions, 9 deletions
diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py index 64c8790..09d0195 100644 --- a/experiments/compute_naive_state_err.py +++ b/experiments/compute_naive_state_err.py @@ -17,7 +17,7 @@ from models.state_bridge import StateBridgeNet def compute_naive_state_err(model, dataloader, device, eval_layer=None): - """Compute ||h_l - h_L|| / ||h_L|| averaged over data.""" + """Compute ||h_l - h_L||_2 / ||h_L||_2 averaged over data (L2 norm ratio, scalar).""" model.eval() L = model.num_blocks if eval_layer is None: @@ -32,9 +32,10 @@ def compute_naive_state_err(model, dataloader, device, eval_layer=None): _, hiddens = model(x, return_hidden=True) h_l = hiddens[eval_layer] h_L = hiddens[-1] - norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0) - err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean() - total_err += err.item() * x.size(0) + diff_norm = (h_l - h_L).norm(dim=-1) # (batch,) + hL_norm = h_L.norm(dim=-1).clamp(min=1e-8) # (batch,) + ratio = (diff_norm / hL_norm).mean() # scalar + total_err += ratio.item() * x.size(0) n += x.size(0) return total_err / n @@ -213,9 +214,17 @@ def run_A1_naive(args): with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1) eval_data = [(eval_x, eval_y)] for method in methods: + ckpt_dir = os.path.join(args.output_dir, 'checkpoints_A1') + os.makedirs(ckpt_dir, exist_ok=True) + ckpt_path = os.path.join(ckpt_dir, f'a{alpha}_L{L}_{method}_s{seed}.pt') torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) model = StudentNet(d, C, L, alpha=alpha).to(device) - model = train_synth_method(method, model, teacher, device, d, C, L) + if os.path.exists(ckpt_path): + model.load_state_dict(torch.load(ckpt_path, map_location=device)) + print(f" A1 alpha={alpha} L={L} {method} s={seed}: loaded checkpoint", flush=True) + else: + model = train_synth_method(method, model, teacher, device, d, C, L) + torch.save(model.state_dict(), ckpt_path) nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2) rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse}) print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) @@ -237,7 +246,17 @@ def get_cifar10(bs=128): def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01): """Train CIFAR model, return trained model.""" C = 10 - if method == 'dfa': + if method == 'bp': + opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs) + for ep in range(1, epochs+1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0),-1).to(device); y = y.to(device) + loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step() + sch.step() + return model + elif method == 'dfa': Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)] bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks] eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) @@ -351,15 +370,23 @@ def run_A2_naive(args): """Compute naive state err for CIFAR methods.""" device = torch.device(f'cuda:{args.gpu}') seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000] - L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge'] + L, d = 4, 256; methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] train_loader, test_loader = get_cifar10() rows = [] for seed in seeds: for method in methods: + ckpt_dir = os.path.join(args.output_dir, 'checkpoints_A2') + os.makedirs(ckpt_dir, exist_ok=True) + ckpt_path = os.path.join(ckpt_dir, f'{method}_s{seed}.pt') torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) model = ResidualMLP(3072, d, 10, L).to(device) - print(f" A2 {method} s={seed}: training...", flush=True) - model = train_cifar_method(method, model, train_loader, test_loader, device, L, d) + if os.path.exists(ckpt_path): + model.load_state_dict(torch.load(ckpt_path, map_location=device)) + print(f" A2 {method} s={seed}: loaded checkpoint", flush=True) + else: + print(f" A2 {method} s={seed}: training...", flush=True) + model = train_cifar_method(method, model, train_loader, test_loader, device, L, d) + torch.save(model.state_dict(), ckpt_path) nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2) rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse}) print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True) |
