summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 11:28:13 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-02 11:28:13 -0500
commitef80d52840a1c6fb7f9a22985784ce311edc59a4 (patch)
tree00e674c9bf71f31cecd330a115d75cb8a417cea8 /experiments
parent61204b6010e403b4c61b093f2a208a881b20fa11 (diff)
Add CNN baseline: SmallCNN with BP/DFA/EP on CIFAR-10
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cnn_baseline.py600
1 files changed, 600 insertions, 0 deletions
diff --git a/experiments/cnn_baseline.py b/experiments/cnn_baseline.py
new file mode 100644
index 0000000..f55b77b
--- /dev/null
+++ b/experiments/cnn_baseline.py
@@ -0,0 +1,600 @@
+"""
+CNN baseline for CIFAR-10: BP / DFA / EP on a small ConvNet.
+One method+seed per invocation for clean process isolation.
+
+Architecture:
+ Conv2d(3,32,3,padding=1) -> ReLU
+ Conv2d(32,64,3,padding=1) -> ReLU -> MaxPool(2) [32->16]
+ Conv2d(64,128,3,padding=1) -> ReLU -> MaxPool(2) [16->8]
+ flatten -> FC(128*8*8=8192, 256) -> ReLU -> FC(256, 10)
+
+Blocks (for local update):
+ block 0 : Conv1 (Conv2d 3->32)
+ block 1 : Conv2 (Conv2d 32->64) + MaxPool
+ block 2 : Conv3 (Conv2d 64->128) + MaxPool
+ block 3 : FC1 (Linear 8192->256)
+ block 4 : FC2 (Linear 256->10) -- output head, always trained with loss
+
+Hidden states (post-activation, for credit):
+ h0 : (B, 32, 32, 32) after Conv1+ReLU
+ h1 : (B, 64, 16, 16) after Conv2+ReLU+MaxPool
+ h2 : (B, 128, 8, 8) after Conv3+ReLU+MaxPool
+ h3 : (B, 256) after flatten+FC1+ReLU
+
+DFA: flatten each h_l to (B, d_l), random feedback B_l: (d_l, 10)
+EP: energy E = sum_l 0.5 ||h_{l+1} - F_l(h_l)||^2 adapted for CNN
+
+Usage: python cnn_baseline.py --method bp --seed 42 --gpu 0
+Output: results/cnn_baseline/{method}_s{seed}.json + .pt checkpoint
+"""
+
+import os, sys, json, argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+import torchvision, torchvision.transforms as transforms
+
+
+# ---------------------------------------------------------------------------
+# Data
+# ---------------------------------------------------------------------------
+
+def get_cifar10(bs=128):
+ tt = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ tv = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trl = DataLoader(
+ torchvision.datasets.CIFAR10('./data', True, download=True, transform=tt),
+ bs, True, num_workers=4, pin_memory=True)
+ tel = DataLoader(
+ torchvision.datasets.CIFAR10('./data', False, download=True, transform=tv),
+ bs, False, num_workers=4, pin_memory=True)
+ return trl, tel
+
+
+# ---------------------------------------------------------------------------
+# Model
+# ---------------------------------------------------------------------------
+
+class SmallCNN(nn.Module):
+ """
+ A small 3-conv CNN for CIFAR-10.
+
+ Blocks (nn.Module list, mirrors the 5-block treatment):
+ blocks[0] : Conv1 layer (Conv2d 3->32, BN, ReLU)
+ blocks[1] : Conv2 layer (Conv2d 32->64, BN, ReLU, MaxPool)
+ blocks[2] : Conv3 layer (Conv2d 64->128, BN, ReLU, MaxPool)
+ blocks[3] : FC1 layer (Linear 8192->256, ReLU)
+ out_head : FC2 layer (Linear 256->10)
+
+ forward(x, return_hidden=False):
+ returns logits, or (logits, [h0, h1, h2, h3]) when return_hidden=True.
+ h_l are post-activation tensors; h3 is (B,256) flat.
+ """
+ # flat dim of each hidden state
+ FLAT_DIMS = [32 * 32 * 32, 64 * 16 * 16, 128 * 8 * 8, 256]
+ NUM_BLOCKS = 4 # conv1, conv2, conv3, fc1 (out_head is separate)
+
+ def __init__(self):
+ super().__init__()
+ self.blocks = nn.ModuleList([
+ # block 0: Conv1
+ nn.Sequential(
+ nn.Conv2d(3, 32, 3, padding=1),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True),
+ ),
+ # block 1: Conv2 + MaxPool
+ nn.Sequential(
+ nn.Conv2d(32, 64, 3, padding=1),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(2),
+ ),
+ # block 2: Conv3 + MaxPool
+ nn.Sequential(
+ nn.Conv2d(64, 128, 3, padding=1),
+ nn.BatchNorm2d(128),
+ nn.ReLU(inplace=True),
+ nn.MaxPool2d(2),
+ ),
+ # block 3: FC1
+ nn.Sequential(
+ nn.Linear(128 * 8 * 8, 256),
+ nn.ReLU(inplace=True),
+ ),
+ ])
+ self.out_head = nn.Linear(256, 10)
+ self.num_blocks = self.NUM_BLOCKS
+ self.flat_dims = self.FLAT_DIMS
+
+ def forward(self, x, return_hidden=False):
+ """
+ x: (B, 3, 32, 32)
+ Returns logits (B,10), optionally with list of 4 hidden states.
+ h0: (B,32,32,32) h1: (B,64,16,16) h2: (B,128,8,8) h3: (B,256)
+ """
+ h0 = self.blocks[0](x) # (B, 32, 32, 32)
+ h1 = self.blocks[1](h0) # (B, 64, 16, 16)
+ h2 = self.blocks[2](h1) # (B, 128, 8, 8)
+ h3 = self.blocks[3](h2.flatten(1)) # (B, 256)
+ logits = self.out_head(h3) # (B, 10)
+ if return_hidden:
+ return logits, [h0, h1, h2, h3]
+ return logits
+
+ def forward_from(self, h, layer_idx):
+ """
+ Run the network from hidden state h at layer `layer_idx` to logits.
+ layer_idx in {0, 1, 2, 3} (0=after block0, 3=after block3).
+ h should be the post-activation tensor at that layer.
+ """
+ c = h
+ for i in range(layer_idx + 1, self.num_blocks):
+ if i == 3:
+ c = self.blocks[i](c.flatten(1) if c.dim() > 2 else c)
+ else:
+ c = self.blocks[i](c)
+ if c.dim() > 2:
+ c = c.flatten(1)
+ logits = self.out_head(c if c.dim() == 2 else c.flatten(1))
+ return logits
+
+
+def evaluate(model, loader, dev):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in loader:
+ x, y = x.to(dev), y.to(dev)
+ correct += (model(x).argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+# ---------------------------------------------------------------------------
+# Helper: flatten hidden state for credit computation
+# ---------------------------------------------------------------------------
+
+def flat(h):
+ """Flatten spatial dims: (B, C, H, W) -> (B, C*H*W) or (B, D) -> (B, D)."""
+ return h.flatten(1) if h.dim() > 2 else h
+
+
+# ---------------------------------------------------------------------------
+# Training: BP
+# ---------------------------------------------------------------------------
+
+def train_bp(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
+ opt = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=epochs)
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in trl:
+ x, y = x.to(dev), y.to(dev)
+ F.cross_entropy(model(x), y).backward()
+ opt.step()
+ opt.zero_grad()
+ sch.step()
+ if ep % 20 == 0:
+ print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
+ return model
+
+
+# ---------------------------------------------------------------------------
+# Training: DFA
+# ---------------------------------------------------------------------------
+
+def train_dfa(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01):
+ """
+ Direct Feedback Alignment for CNN.
+
+ For each block l, a random matrix B_l: (flat_dim_l, 10) maps the global
+ error signal e_T (softmax-CE gradient at output) back to the hidden space.
+ The local surrogate loss is:
+ L_l = < F_l(h_{l-1}), a_l / ||a_l||_rms >
+ where a_l = B_l @ e_T (flattened credit, then reshaped if needed).
+ The out_head is trained with standard cross-entropy on the final hidden state.
+ """
+ L = model.num_blocks # 4 blocks (conv1, conv2, conv3, fc1)
+ C = 10
+ flat_dims = model.flat_dims # [32768, 16384, 8192, 256]
+
+ # Random feedback matrices (fixed, not trained)
+ Bs = [torch.randn(flat_dims[l], C, device=dev) / np.sqrt(C) for l in range(L)]
+
+ # Per-block optimizers + head optimizer
+ block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd) for l in range(L)]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_opts = block_opts + [head_opt]
+ schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in trl:
+ x, y = x.to(dev), y.to(dev)
+ B = x.size(0)
+
+ # Forward pass (no grad) to get hidden states and global error
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ probs = logits.softmax(-1) # (B, 10)
+ e_T = probs.clone()
+ e_T[torch.arange(B), y] -= 1.0 # (B, 10)
+
+ # --- Train out_head with standard CE on detached h3 ---
+ h3_det = hiddens[3].detach()
+ ce_loss = F.cross_entropy(model.out_head(h3_det), y)
+ head_opt.zero_grad()
+ ce_loss.backward()
+ head_opt.step()
+
+ # --- Train each block with DFA local surrogate ---
+ # For conv blocks (l=0,1,2) we need to re-run the block forward
+ # starting from the *previous* hidden state.
+ # The "input" to block l is:
+ # l=0: x (raw input image)
+ # l=1: h0
+ # l=2: h1
+ # l=3: h2 (flattened)
+
+ inputs = [x, hiddens[0].detach(), hiddens[1].detach(), hiddens[2].detach()]
+
+ for l in range(L):
+ # Compute DFA credit signal (flattened)
+ a_l_flat = (e_T @ Bs[l].T).detach() # (B, flat_dim_l)
+ rms = (a_l_flat ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l_norm = a_l_flat / rms # (B, flat_dim_l)
+
+ # Forward through block l with grad
+ inp = inputs[l].detach()
+ if l == 3:
+ out_l = model.blocks[l](inp.flatten(1) if inp.dim() > 2 else inp)
+ else:
+ out_l = model.blocks[l](inp)
+
+ # Local surrogate: <F_l(inp), a_l_norm> (summed over spatial, averaged over batch)
+ out_flat = flat(out_l) # (B, flat_dim_l)
+ local_loss = (out_flat * a_l_norm).sum(-1).mean()
+
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ for s in schedulers:
+ s.step()
+ if ep % 20 == 0:
+ print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
+
+ return model
+
+
+# ---------------------------------------------------------------------------
+# Training: EP (Equilibrium Propagation adapted for CNN)
+# ---------------------------------------------------------------------------
+
+def ep_energy_cnn(model, hiddens, x):
+ """
+ CNN EP energy: E = sum_l 0.5 ||h_l - F_l(inp_l)||^2 (flattened).
+
+ hiddens[0] = h0 (B,32,32,32) -- target for block 0 applied to x
+ hiddens[1] = h1 (B,64,16,16) -- target for block 1 applied to h0
+ hiddens[2] = h2 (B,128,8,8) -- target for block 2 applied to h1
+ hiddens[3] = h3 (B,256) -- target for block 3 applied to h2.flatten
+ """
+ inputs = [x, hiddens[0], hiddens[1], hiddens[2]]
+ E = 0.0
+ for l in range(model.num_blocks):
+ inp = inputs[l]
+ if l == 3:
+ pred = model.blocks[l](inp.flatten(1) if inp.dim() > 2 else inp)
+ else:
+ pred = model.blocks[l](inp)
+ # Compare flattened versions
+ pred_f = flat(pred)
+ h_f = flat(hiddens[l])
+ residual = h_f - pred_f # (B, d_l)
+ E = E + 0.5 * (residual ** 2).sum(-1) # (B,)
+ return E
+
+
+def ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge):
+ """
+ Nudged phase: minimize E(h) + beta * CE(out_head(h3), y)
+ w.r.t. h0, h1, h2, h3 (all free hidden states).
+ x is fixed (pixel input, not a hidden state).
+ """
+ L = model.num_blocks
+ # Initialise from free phase
+ h_nudged = [h.clone().detach() for h in h_free]
+ for i in range(L):
+ h_nudged[i].requires_grad_(True)
+
+ inner_opt = optim.SGD(h_nudged, lr=alpha_nudge)
+
+ for _ in range(T_nudge):
+ E = ep_energy_cnn(model, h_nudged, x) # (B,)
+ logits = model.out_head(h_nudged[3]) # (B, 10)
+ C_loss = F.cross_entropy(logits, y, reduction='none') # (B,)
+ total = (E + beta * C_loss).mean()
+ inner_opt.zero_grad()
+ total.backward()
+ inner_opt.step()
+
+ return [h.detach() for h in h_nudged]
+
+
+def train_ep(model, trl, tel, dev, epochs=100, lr=1e-3, wd=0.01,
+ beta=0.5, T_nudge=20, alpha_nudge=0.05):
+ """
+ Equilibrium Propagation for the small CNN.
+ Weight update rule:
+ Δθ ∝ (dE_nudged/dθ - dE_free/dθ) / beta
+ For the out_head: standard CE on nudged output (no dE/dtheta_head term).
+ """
+ L = model.num_blocks
+
+ block_opts = [optim.AdamW(model.blocks[l].parameters(), lr=lr, weight_decay=wd) for l in range(L)]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=lr, weight_decay=wd)
+ all_opts = block_opts + [head_opt]
+ schedulers = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in all_opts]
+
+ for ep in range(1, epochs + 1):
+ model.train()
+ for x, y in trl:
+ x, y = x.to(dev), y.to(dev)
+
+ # Free phase: standard forward pass
+ with torch.no_grad():
+ _, h_free = model(x, return_hidden=True)
+
+ # Nudged phase
+ h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)
+
+ # Zero all grads
+ for o in all_opts:
+ o.zero_grad()
+
+ # EP weight update per block:
+ # dE/dtheta_l = -residual_l * dF_l/dtheta_l (same as MLP EP)
+ inputs_free = [x, h_free[0].detach(), h_free[1].detach(), h_free[2].detach()]
+ inputs_nudge = [x, h_nudged[0].detach(), h_nudged[1].detach(), h_nudged[2].detach()]
+
+ for l in range(L):
+ inp_f = inputs_free[l].detach()
+ inp_n = inputs_nudge[l].detach()
+
+ if l == 3:
+ f_free = model.blocks[l](inp_f.flatten(1) if inp_f.dim() > 2 else inp_f)
+ f_nudge = model.blocks[l](inp_n.flatten(1) if inp_n.dim() > 2 else inp_n)
+ else:
+ f_free = model.blocks[l](inp_f)
+ f_nudge = model.blocks[l](inp_n)
+
+ # residuals (detached target - computed output)
+ res_free = (flat(h_free[l]).detach() - flat(f_free).detach()) # (B, d_l)
+ res_nudge = (flat(h_nudged[l]).detach() - flat(f_nudge).detach())
+
+ # dE/dtheta = -(res * dF/dtheta) => gradient via autograd trick
+ # loss_free_l = -(res_free * f_l_free).sum() gives dE_free/dtheta
+ # loss_nudge_l = -(res_nudge * f_l_nudge).sum() gives dE_nudge/dtheta
+ loss_free_l = -(res_free * flat(f_free)).sum()
+ loss_nudge_l = -(res_nudge * flat(f_nudge)).sum()
+
+ ep_loss_l = (loss_nudge_l - loss_free_l) / beta
+ ep_loss_l.backward()
+
+ # Head: CE on nudged h3
+ head_loss = F.cross_entropy(model.out_head(h_nudged[3].detach()), y)
+ head_loss.backward()
+
+ torch.nn.utils.clip_grad_norm_(list(model.parameters()), 1.0)
+ for o in all_opts:
+ o.step()
+
+ for s in schedulers:
+ s.step()
+ if ep % 20 == 0:
+ print(f" Ep {ep}: acc={evaluate(model, tel, dev):.4f}", flush=True)
+
+ return model
+
+
+# ---------------------------------------------------------------------------
+# Diagnostics
+# ---------------------------------------------------------------------------
+
+def compute_bp_grads(model, x, y):
+ """
+ Compute BP gradients w.r.t. each hidden state h_l via autograd.
+ Returns list of grad tensors (same shape as h_l), and the hidden states.
+ """
+ model.eval()
+ L = model.num_blocks
+
+ # Re-run forward with requires_grad on intermediate activations
+ # We build the forward manually to hook into each h_l
+ h = [None] * L
+ inp = x
+ for l in range(L):
+ if l == 3:
+ inp = inp.flatten(1) if inp.dim() > 2 else inp
+ h[l] = model.blocks[l](inp.detach().requires_grad_(False))
+ h[l] = h[l].detach().requires_grad_(True)
+ inp = h[l]
+
+ logits = model.out_head(h[3])
+ loss = F.cross_entropy(logits, y)
+ gs = torch.autograd.grad(loss, h, allow_unused=True)
+ return [g.detach() if g is not None else torch.zeros_like(h[i]) for i, g in enumerate(gs)], h
+
+
+def compute_diagnostics(model, tel, dev, method, beta=0.5, T_nudge=20, alpha_nudge=0.05):
+ model.eval()
+ L = model.num_blocks
+
+ # Grab one batch
+ for x, y in tel:
+ x, y = x.to(dev), y.to(dev)
+ break
+
+ # BP gradients
+ bp_grads, h_bp = compute_bp_grads(model, x, y)
+
+ # Credit signals depending on method
+ if method == 'ep':
+ with torch.no_grad():
+ _, h_free = model(x, return_hidden=True)
+ h_nudged = ep_nudged_phase_cnn(model, x, y, h_free, beta, T_nudge, alpha_nudge)
+ credits = [flat((h_nudged[l] - h_free[l]) / beta) for l in range(L)]
+ else:
+ # For BP and DFA, use BP grads directly (BP self-cosine = 1 by definition)
+ credits = [flat(bp_grads[l]) for l in range(L)]
+
+ bp_grads_flat = [flat(g) for g in bp_grads]
+
+ # Gamma: cosine similarity between credit and BP grad
+ gammas = []
+ for l in range(L):
+ g = cosine_similarity_batch(credits[l], bp_grads_flat[l])
+ gammas.append(float(g))
+
+ # rho: perturbation correlation using forward_from
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ rhos = []
+ for l in range(L):
+ h_l = flat(hiddens[l].detach()) # (B, d_l)
+ a_l = credits[l].detach() # (B, d_l)
+
+ # forward_fn: perturbed flat h_l -> per-sample CE loss
+ # we need to run from layer l+1 onward
+ def make_forward_fn(layer_idx):
+ def forward_fn(h_flat):
+ """h_flat: (B, d_l) flat tensor at layer layer_idx output."""
+ with torch.no_grad():
+ # Reshape back to spatial if needed
+ c = h_flat
+ for i in range(layer_idx + 1, L):
+ if i == 3:
+ c = model.blocks[i](c.flatten(1) if c.dim() > 2 else c)
+ else:
+ # blocks 1,2 expect spatial input; but c here is flat
+ # only happens for i=1 (in_dim 32*32*32->spatial 32,32,32)
+ # and i=2 (64,16,16). Since layer_idx<i we reshape.
+ if layer_idx < 3:
+ # Reconstruct spatial shape from flat
+ shapes = [(32, 32, 32), (64, 16, 16), (128, 8, 8)]
+ C_s, H_s, W_s = shapes[i - 1]
+ c = c.view(c.size(0), C_s, H_s, W_s)
+ c = model.blocks[i](c)
+ if c.dim() > 2:
+ c = c.flatten(1)
+ logits = model.out_head(c)
+ return F.cross_entropy(logits, y, reduction='none')
+ return forward_fn
+
+ rho = perturbation_correlation(h_l, a_l, make_forward_fn(l), epsilon=1e-3, M=16)
+ rhos.append(float(rho))
+
+ return {
+ 'Gamma': float(np.mean(gammas)),
+ 'rho': float(np.mean(rhos)),
+ 'gammas_per_layer': gammas,
+ 'rhos_per_layer': rhos,
+ }
+
+
+# ---------------------------------------------------------------------------
+# Main
+# ---------------------------------------------------------------------------
+
+def main():
+ p = argparse.ArgumentParser(description='CNN baseline for CIFAR-10')
+ p.add_argument('--method', type=str, required=True, choices=['bp', 'dfa', 'ep'])
+ p.add_argument('--seed', type=int, required=True)
+ p.add_argument('--gpu', type=int, default=0)
+ p.add_argument('--output_dir', type=str, default='results/cnn_baseline')
+ p.add_argument('--epochs', type=int, default=100)
+ p.add_argument('--lr', type=float, default=1e-3)
+ p.add_argument('--wd', type=float, default=0.01)
+ # EP hyperparameters
+ p.add_argument('--beta', type=float, default=0.5)
+ p.add_argument('--T_nudge', type=int, default=20)
+ p.add_argument('--alpha_nudge', type=float, default=0.05)
+ args = p.parse_args()
+
+ os.makedirs(args.output_dir, exist_ok=True)
+ dev = torch.device(f'cuda:{args.gpu}')
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ trl, tel = get_cifar10()
+ model = SmallCNN().to(dev)
+
+ print(f"[{args.method} s={args.seed}] Training CNN on CIFAR-10 for {args.epochs} epochs...", flush=True)
+
+ if args.method == 'bp':
+ model = train_bp(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
+ elif args.method == 'dfa':
+ model = train_dfa(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd)
+ elif args.method == 'ep':
+ model = train_ep(model, trl, tel, dev, epochs=args.epochs, lr=args.lr, wd=args.wd,
+ beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
+
+ acc = evaluate(model, tel, dev)
+ diag = compute_diagnostics(model, tel, dev, args.method,
+ beta=args.beta, T_nudge=args.T_nudge, alpha_nudge=args.alpha_nudge)
+
+ # Save checkpoint
+ ckpt_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.pt')
+ torch.save(model.state_dict(), ckpt_path)
+
+ result = {
+ 'method': args.method,
+ 'seed': args.seed,
+ 'acc': float(acc),
+ 'Gamma': diag['Gamma'],
+ 'rho': diag['rho'],
+ 'gammas_per_layer': diag['gammas_per_layer'],
+ 'rhos_per_layer': diag['rhos_per_layer'],
+ 'epochs': args.epochs,
+ 'lr': args.lr,
+ 'wd': args.wd,
+ 'beta': args.beta,
+ 'T_nudge': args.T_nudge,
+ 'alpha_nudge': args.alpha_nudge,
+ }
+
+ json_path = os.path.join(args.output_dir, f'{args.method}_s{args.seed}.json')
+ with open(json_path, 'w') as f:
+ json.dump(result, f, indent=2, default=float)
+
+ print(
+ f"[{args.method} s={args.seed}] acc={acc:.4f} "
+ f"Gamma={diag['Gamma']:.4f} rho={diag['rho']:.4f}",
+ flush=True,
+ )
+ print(f" gammas_per_layer={[f'{g:.4f}' for g in diag['gammas_per_layer']]}", flush=True)
+ print(f" rhos_per_layer ={[f'{r:.4f}' for r in diag['rhos_per_layer']]}", flush=True)
+ print(f" Saved: {json_path}", flush=True)
+
+
+if __name__ == '__main__':
+ main()