summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-31 22:18:32 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-31 22:18:32 -0500
commit8b78fbb1308d31bedd74f4b1deb250f5e684a6d3 (patch)
treece24ce38c3e6425d32f5873667ea3169d7d970bb /experiments
parent0eddc70d8c89adb2ae7105b7d3be813310fd1b80 (diff)
Update naive StateErr v3: L2 norm ratio formula, with checkpoints saved
Formula: ||h_{L//2} - h_L||_2 / ||h_L||_2 (scalar L2 ratio) A1: 240 rows (3 alpha × 2 depth × 4 methods × 10 seeds) A2: 40 rows (4 methods including BP × 10 seeds) All model checkpoints saved in checkpoints_A1/ and checkpoints_A2/ Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/compute_naive_state_err.py45
1 files changed, 36 insertions, 9 deletions
diff --git a/experiments/compute_naive_state_err.py b/experiments/compute_naive_state_err.py
index 64c8790..09d0195 100644
--- a/experiments/compute_naive_state_err.py
+++ b/experiments/compute_naive_state_err.py
@@ -17,7 +17,7 @@ from models.state_bridge import StateBridgeNet
def compute_naive_state_err(model, dataloader, device, eval_layer=None):
- """Compute ||h_l - h_L|| / ||h_L|| averaged over data."""
+ """Compute ||h_l - h_L||_2 / ||h_L||_2 averaged over data (L2 norm ratio, scalar)."""
model.eval()
L = model.num_blocks
if eval_layer is None:
@@ -32,9 +32,10 @@ def compute_naive_state_err(model, dataloader, device, eval_layer=None):
_, hiddens = model(x, return_hidden=True)
h_l = hiddens[eval_layer]
h_L = hiddens[-1]
- norm_L = h_L.norm(dim=-1, keepdim=True).clamp(min=1.0)
- err = ((h_l - h_L) / norm_L).pow(2).sum(-1).mean()
- total_err += err.item() * x.size(0)
+ diff_norm = (h_l - h_L).norm(dim=-1) # (batch,)
+ hL_norm = h_L.norm(dim=-1).clamp(min=1e-8) # (batch,)
+ ratio = (diff_norm / hL_norm).mean() # scalar
+ total_err += ratio.item() * x.size(0)
n += x.size(0)
return total_err / n
@@ -213,9 +214,17 @@ def run_A1_naive(args):
with torch.no_grad(): eval_y = teacher(eval_x).argmax(-1)
eval_data = [(eval_x, eval_y)]
for method in methods:
+ ckpt_dir = os.path.join(args.output_dir, 'checkpoints_A1')
+ os.makedirs(ckpt_dir, exist_ok=True)
+ ckpt_path = os.path.join(ckpt_dir, f'a{alpha}_L{L}_{method}_s{seed}.pt')
torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
model = StudentNet(d, C, L, alpha=alpha).to(device)
- model = train_synth_method(method, model, teacher, device, d, C, L)
+ if os.path.exists(ckpt_path):
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
+ print(f" A1 alpha={alpha} L={L} {method} s={seed}: loaded checkpoint", flush=True)
+ else:
+ model = train_synth_method(method, model, teacher, device, d, C, L)
+ torch.save(model.state_dict(), ckpt_path)
nse = compute_naive_state_err(model, eval_data, device, eval_layer=L//2)
rows.append({'alpha': alpha, 'depth': L, 'method': method, 'seed': seed, 'naive_StateErr': nse})
print(f" A1 alpha={alpha} L={L} {method} s={seed}: naive_StateErr={nse:.6f}", flush=True)
@@ -237,7 +246,17 @@ def get_cifar10(bs=128):
def train_cifar_method(method, model, train_loader, test_loader, device, L, d, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01):
"""Train CIFAR model, return trained model."""
C = 10
- if method == 'dfa':
+ if method == 'bp':
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs+1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y); opt.zero_grad(); loss.backward(); opt.step()
+ sch.step()
+ return model
+ elif method == 'dfa':
Bs = [torch.randn(d, C, device=device)/np.sqrt(C) for _ in range(L)]
bops = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
eop = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
@@ -351,15 +370,23 @@ def run_A2_naive(args):
"""Compute naive state err for CIFAR methods."""
device = torch.device(f'cuda:{args.gpu}')
seeds = [42,123,456,789,1024,2048,3000,4000,5000,6000]
- L, d = 4, 256; methods = ['dfa', 'state_bridge', 'credit_bridge']
+ L, d = 4, 256; methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
train_loader, test_loader = get_cifar10()
rows = []
for seed in seeds:
for method in methods:
+ ckpt_dir = os.path.join(args.output_dir, 'checkpoints_A2')
+ os.makedirs(ckpt_dir, exist_ok=True)
+ ckpt_path = os.path.join(ckpt_dir, f'{method}_s{seed}.pt')
torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed)
model = ResidualMLP(3072, d, 10, L).to(device)
- print(f" A2 {method} s={seed}: training...", flush=True)
- model = train_cifar_method(method, model, train_loader, test_loader, device, L, d)
+ if os.path.exists(ckpt_path):
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
+ print(f" A2 {method} s={seed}: loaded checkpoint", flush=True)
+ else:
+ print(f" A2 {method} s={seed}: training...", flush=True)
+ model = train_cifar_method(method, model, train_loader, test_loader, device, L, d)
+ torch.save(model.state_dict(), ckpt_path)
nse = compute_naive_state_err(model, test_loader, device, eval_layer=L//2)
rows.append({'method': method, 'seed': seed, 'naive_StateErr': nse})
print(f" A2 {method} s={seed}: naive_StateErr={nse:.6f}", flush=True)