"""
Synthetic Nonlinearity Ladder: teacher-student classification with controllable nonlinearity.

Teacher dynamics:
    h_{l+1}^* = h_l^* + W_l^* phi_alpha(h_l^*)
    phi_alpha(z) = (1-alpha)*z + alpha*tanh(z)

alpha=0 -> linear, alpha=1 -> fully nonlinear.
Sweep (alpha, L) to find where state bridge fails and credit bridge degrades.

Methods: BP, DFA, State Bridge, Credit Bridge (all reuse project conventions).
"""

import os
import sys
import json
import argparse
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
import copy

# Make the project root importable when this file is run as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP, ResidualBlock
from models.value_net import ValueNet, create_ema_model, update_ema
from models.state_bridge import StateBridgeNet
from metrics.credit_metrics import (
    cosine_similarity_batch, perturbation_correlation,
    nudging_test, offline_bp_cosine, feature_drift
)


# =============================================================================
# Teacher network and data generation
# =============================================================================

class TeacherNet:
    """Fixed teacher network with controllable nonlinearity.

    This is a plain container of tensors (not an ``nn.Module``); its weights
    are drawn once from ``np.random.RandomState(seed)`` and never trained.

    Args:
        d_hidden: Width of the hidden state (also the input width).
        num_blocks: Number of residual blocks.
        num_classes: Number of output logits.
        alpha: Mixing coefficient of ``phi``; 0 is linear, 1 is pure tanh.
        seed: Seed of the NumPy generator used to draw the weights.
    """

    def __init__(self, d_hidden, num_blocks, num_classes, alpha, seed=0):
        rng = np.random.RandomState(seed)
        self.d_hidden = d_hidden
        self.num_blocks = num_blocks
        self.num_classes = num_classes
        self.alpha = alpha

        # Teacher block weights: scaled for stability
        self.Ws = []
        for _ in range(num_blocks):
            W = rng.randn(d_hidden, d_hidden).astype(np.float32)
            # Scale so spectral norm of residual part is small (0.3)
            W = W / (np.linalg.norm(W, ord=2) + 1e-8) * 0.3
            self.Ws.append(torch.from_numpy(W))

        # Output projection, normalised to unit spectral norm
        U = rng.randn(num_classes, d_hidden).astype(np.float32)
        U = U / (np.linalg.norm(U, ord=2) + 1e-8)
        self.U = torch.from_numpy(U)

    def to(self, device):
        """Move all teacher tensors to ``device`` and return ``self``."""
        self.Ws = [W.to(device) for W in self.Ws]
        self.U = self.U.to(device)
        return self

    def phi(self, z):
        """Controllable activation: (1-alpha)*z + alpha*tanh(z)."""
        return (1 - self.alpha) * z + self.alpha * torch.tanh(z)

    def forward(self, h0):
        """Forward pass through teacher, returns logits and hidden states.

        Returns:
            ``(logits, hiddens)`` where ``hiddens`` has ``num_blocks + 1``
            entries: the input followed by the state after every block.
        """
        h = h0
        hiddens = [h]
        for W in self.Ws:
            h = h + F.linear(self.phi(h), W)
            hiddens.append(h)
        logits = F.linear(h, self.U)
        return logits, hiddens


def generate_dataset(teacher, num_samples, d_hidden, device, seed=0):
    """Generate classification dataset from teacher network.

    Inputs are standard-normal vectors; labels are the argmax of the teacher
    logits. Note that this reseeds the *global* torch RNG with ``seed``.

    Returns:
        ``(X, Y)`` with ``X`` of shape ``(num_samples, d_hidden)`` and ``Y``
        holding one class index per sample.
    """
    torch.manual_seed(seed)
    X = torch.randn(num_samples, d_hidden, device=device)
    with torch.no_grad():
        logits, _ = teacher.forward(X)
    Y = logits.argmax(dim=-1)
    return X, Y


# =============================================================================
# Student network (same architecture family as teacher)
# =============================================================================

class StudentBlock(nn.Module):
    """Residual block with pre-LayerNorm + phi_alpha."""

    def __init__(self, d_hidden, alpha):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.w = nn.Linear(d_hidden, d_hidden, bias=False)
        self.alpha = alpha
        # Small init for stability
        nn.init.normal_(self.w.weight, std=0.01)

    def phi(self, z):
        """Controllable activation: (1-alpha)*z + alpha*tanh(z)."""
        return (1 - self.alpha) * z + self.alpha * torch.tanh(z)

    def forward(self, h):
        """Returns residual F_l(h), NOT h + F_l(h)."""
        return self.w(self.phi(self.ln(h)))


class StudentNet(nn.Module):
    """Student network: L residual blocks + linear output head."""

    def __init__(self, d_hidden, num_classes, num_blocks, alpha):
        super().__init__()
        self.blocks = nn.ModuleList(
            [StudentBlock(d_hidden, alpha) for _ in range(num_blocks)]
        )
        self.out_head = nn.Linear(d_hidden, num_classes)
        self.num_blocks = num_blocks
        self.d_hidden = d_hidden

    def forward(self, x, return_hidden=False):
        """Run the full network.

        Args:
            x: Input batch; used directly as the first hidden state.
            return_hidden: When true, also return the list of hidden states
                (input plus the state after each block).

        Returns:
            ``logits`` or ``(logits, hiddens)``.
        """
        h = x  # No embedding needed, input is already d_hidden
        hiddens = [h] if return_hidden else None
        for block in self.blocks:
            h = h + block(h)
            if return_hidden:
                hiddens.append(h)
        logits = self.out_head(h)
        if return_hidden:
            return logits, hiddens
        return logits

    def forward_from_layer(self, h, start_layer):
        """Run forward from a given layer to output."""
        for i in range(start_layer, self.num_blocks):
            h = h + self.blocks[i](h)
        return self.out_head(h)


# =============================================================================
# Shared helpers for the local-learning trainers
# =============================================================================

def _make_local_optimizers(model, args):
    """Build one AdamW per block, one for the head, and a cosine schedule for each.

    Returns:
        ``(block_opts, head_opt, schedulers)``; ``schedulers`` lists the block
        schedulers first and the head scheduler last.
    """
    block_opts = [
        optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
        for block in model.blocks
    ]
    head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
    schedulers = [
        optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
        for o in block_opts + [head_opt]
    ]
    return block_opts, head_opt, schedulers


def _terminal_error(logits, y):
    """Return ``softmax(logits) - onehot(y)`` (cross-entropy gradient w.r.t. logits)."""
    batch = logits.size(0)
    e_T = logits.softmax(dim=-1)
    e_T[torch.arange(batch), y] -= 1
    return e_T


def _layer_time(batch, l, L, device):
    """Return a ``(batch,)`` tensor filled with the normalised depth ``l / L``."""
    return torch.full((batch,), l / L, device=device)


def _update_head(model, head_opt, h_last, y):
    """Take one cross-entropy step on the output head from a detached final state."""
    loss_out = F.cross_entropy(model.out_head(h_last), y)
    head_opt.zero_grad()
    loss_out.backward()
    head_opt.step()


def _update_blocks(model, block_opts, hiddens, credits):
    """Update every block with the local surrogate ``<F_l(h_l), a_l / rms(a_l)>``.

    Each credit vector is RMS-normalised per sample, the block is evaluated on
    the detached input ``hiddens[l]``, and its gradient is clipped to norm 1.0
    before the optimizer step.
    """
    for l, (block, opt) in enumerate(zip(model.blocks, block_opts)):
        h_l = hiddens[l].detach()
        a = credits[l]
        rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
        a_norm = a / rms
        f_l = block(h_l)
        local_loss = (f_l * a_norm).sum(dim=-1).mean()
        opt.zero_grad()
        local_loss.backward()
        torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
        opt.step()


def _record_epoch(log, model, test_loader, device, total_loss, correct, total):
    """Append the epoch's mean train loss, train accuracy and test accuracy to ``log``."""
    log['train_loss'].append(total_loss / total)
    log['train_acc'].append(correct / total)
    log['test_acc'].append(evaluate(model, test_loader, device))


def _state_bridge_credit(model, state_pred, h, t, s, y):
    """Credit = gradient w.r.t. ``h`` of the summed CE loss on the predicted final state."""
    h_req = h.detach().requires_grad_(True)
    pred_hL = state_pred(h_req, t, s)
    pred_logits = model.out_head(pred_hL)
    pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
    return torch.autograd.grad(pred_loss, h_req, create_graph=False)[0].detach()


def _value_credit(value_net, h, t, s):
    """Credit = gradient of the summed value-net output w.r.t. ``h``."""
    h_req = h.detach().requires_grad_(True)
    V = value_net(h_req, t, s)
    return torch.autograd.grad(V.sum(), h_req, create_graph=False)[0].detach()


# =============================================================================
# Training methods
# =============================================================================

def evaluate(model, test_loader, device):
    """Return the classification accuracy of ``model`` over ``test_loader``.

    Puts the model in eval mode as a side effect.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def train_bp(model, train_loader, test_loader, device, args):
    """Standard BP training.

    Returns:
        ``log`` dict with per-epoch ``train_loss``, ``train_acc``, ``test_acc``.
    """
    optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            logits = model(x)
            loss = F.cross_entropy(logits, y)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item() * x.size(0)
            correct += (logits.argmax(1) == y).sum().item()
            total += x.size(0)
        scheduler.step()

        _record_epoch(log, model, test_loader, device, total_loss, correct, total)
        if epoch % 20 == 0 or epoch == 1:
            print(f" [BP] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
                  f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}")
    return log


def train_dfa(model, train_loader, test_loader, device, args):
    """DFA training with fixed random feedback matrices.

    Each block ``l`` receives the credit ``e_T @ B_l^T`` where ``e_T`` is the
    output error and ``B_l`` is a fixed random ``(d_hidden, num_classes)``
    matrix. The head is trained by ordinary cross-entropy on the detached
    final state.

    Returns:
        ``(log, Bs)`` - the training log and the list of feedback matrices.
    """
    d = model.d_hidden
    L = model.num_blocks
    C = args.num_classes

    Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
    block_opts, head_opt, all_schedulers = _make_local_optimizers(model, args)
    log = {'train_loss': [], 'train_acc': [], 'test_acc': []}

    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _terminal_error(logits, y)

            # Update output head (h_L detached)
            _update_head(model, head_opt, hiddens[-1].detach(), y)

            # Update each block with DFA local surrogate
            credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
            _update_blocks(model, block_opts, hiddens, credits)

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in all_schedulers:
            sch.step()

        _record_epoch(log, model, test_loader, device, total_loss, correct, total)
        if epoch % 20 == 0 or epoch == 1:
            print(f" [DFA] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
                  f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}")
    return log, Bs


def train_state_bridge(model, train_loader, test_loader, device, args):
    """State Bridge training.

    A ``StateBridgeNet`` is trained to regress the final hidden state ``h_L``
    from ``(h_l, l / L, e_T)``. The credit for block ``l`` is the gradient of
    the cross-entropy loss, evaluated on that predicted final state through
    the current output head, with respect to ``h_l``.

    Returns:
        ``(log, state_pred)`` - the log (with an extra ``state_pred_error``
        series) and the trained predictor.
    """
    d = model.d_hidden
    L = model.num_blocks
    C = args.num_classes

    state_pred = StateBridgeNet(d_hidden=d, s_dim=C, time_embed_dim=32,
                                hidden_dim=256, num_layers=3).to(device)
    block_opts, head_opt, all_schedulers = _make_local_optimizers(model, args)
    state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
    log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}

    for epoch in range(1, args.epochs + 1):
        model.train()
        state_pred.train()
        total_loss, correct, total = 0, 0, 0
        total_se = 0

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _terminal_error(logits, y)
            s = e_T.detach()
            hL_det = hiddens[-1].detach()

            # Train state predictor: relative squared error, averaged over layers
            state_loss = 0.0
            target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0)
            for l in range(L):
                pred_hL = state_pred(hiddens[l].detach(), _layer_time(batch, l, L, device), s)
                state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean()
            state_loss = state_loss / L
            state_opt.zero_grad()
            state_loss.backward()
            state_opt.step()
            total_se += state_loss.item() * batch

            # Compute credits (before the head is updated for this batch)
            credits = [
                _state_bridge_credit(model, state_pred, hiddens[l],
                                     _layer_time(batch, l, L, device), s, y)
                for l in range(L)
            ]

            # Update output head
            _update_head(model, head_opt, hL_det, y)

            # Update blocks
            _update_blocks(model, block_opts, hiddens, credits)

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in all_schedulers:
            sch.step()

        _record_epoch(log, model, test_loader, device, total_loss, correct, total)
        log['state_pred_error'].append(total_se / total)
        if epoch % 20 == 0 or epoch == 1:
            print(f" [SB] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
                  f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
                  f"se={log['state_pred_error'][-1]:.6f}")
    return log, state_pred


def train_credit_bridge(model, train_loader, test_loader, device, args):
    """Credit Bridge training with terminal gradient matching + bridge consistency.

    A value net ``V(h, t, s)`` is fitted with three terms:

    * terminal: ``V(h_L, 1, s)`` matches the per-sample cross-entropy loss;
    * terminal gradient (weight ``args.term_grad_weight``): ``dV/dh_L``
      matches the exact gradient of the summed CE loss through the head;
    * bridge: ``V(h_l, l/L, s)`` matches a soft-min (temperature ``args.lam``)
      of the EMA value net at ``args.K`` noisy copies of ``h_{l+1}``.

    Block credits are ``dV/dh_l``. For the first ``max(1, epochs // 5)`` epochs
    DFA credits are used instead, then the two are linearly blended over an
    equally long ramp.

    Returns:
        ``(log, value_net, value_net_ema)``.
    """
    d = model.d_hidden
    L = model.num_blocks
    C = args.num_classes
    warmup_epochs = max(1, args.epochs // 5)

    value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32,
                         hidden_dim=256, num_layers=3).to(device)
    value_net_ema = create_ema_model(value_net)
    Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]

    block_opts, head_opt, all_schedulers = _make_local_optimizers(model, args)
    value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)

    lam = args.lam
    K_samples = args.K
    sigma_bridge = args.sigma_bridge
    ema_momentum = args.ema_momentum
    term_grad_weight = args.term_grad_weight

    log = {'train_loss': [], 'train_acc': [], 'test_acc': [],
           'value_loss': [], 'term_loss': [], 'bridge_loss': [], 'tgrad_loss': []}

    for epoch in range(1, args.epochs + 1):
        model.train()
        value_net.train()
        total_loss, correct, total = 0, 0, 0
        total_vloss, total_term, total_bridge, total_tgrad = 0, 0, 0, 0

        if epoch <= warmup_epochs:
            credit_blend = 0.0
        else:
            credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))

        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _terminal_error(logits, y)
                s = e_T.detach()
                true_loss = F.cross_entropy(logits, y, reduction='none').detach()
            hL_det = hiddens[-1].detach()

            # ---- Train value net ----
            t_L = torch.ones(batch, device=device)
            V_terminal = value_net(hL_det, t_L, s)
            loss_term = ((V_terminal - true_loss) ** 2).mean()

            loss_tgrad = torch.tensor(0.0, device=device)
            if term_grad_weight > 0:
                # Gradient of V at the terminal state, kept differentiable so
                # the matching loss can train the value net.
                hL_req = hL_det.clone().requires_grad_(True)
                V_at_L = value_net(hL_req, t_L, s)
                grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]

                # Exact target: gradient of the summed CE loss through the head.
                hL_req2 = hL_det.clone().requires_grad_(True)
                logits_tgt = model.out_head(hL_req2)
                ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
                a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
                loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()

            loss_bridge = 0.0
            for l in range(L):
                h_l_det = hiddens[l].detach()
                t_l = _layer_time(batch, l, L, device)
                t_l_next = _layer_time(batch, l + 1, L, device)
                V_l = value_net(h_l_det, t_l, s)

                with torch.no_grad():
                    h_next_det = hiddens[l + 1].detach()
                    log_terms = []
                    for _ in range(K_samples):
                        noise = sigma_bridge * torch.randn_like(h_next_det)
                        V_next = value_net_ema(h_next_det + noise, t_l_next, s)
                        log_terms.append(-V_next / lam)
                    log_stack = torch.stack(log_terms, dim=-1)
                    # Soft-min over the K noisy samples (log-mean-exp form).
                    V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
                loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
            loss_bridge = loss_bridge / L

            value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
            value_opt.zero_grad()
            value_loss.backward()
            torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
            value_opt.step()
            update_ema(value_net, value_net_ema, ema_momentum)

            total_vloss += value_loss.item() * batch
            total_term += loss_term.item() * batch
            total_bridge += (loss_bridge.item() if isinstance(loss_bridge, torch.Tensor)
                             else loss_bridge) * batch
            total_tgrad += loss_tgrad.item() * batch

            # ---- Compute credits ----
            cb_credits = [
                _value_credit(value_net, hiddens[l], _layer_time(batch, l, L, device), s)
                for l in range(L)
            ]
            dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]

            credits = []
            for l in range(L):
                if credit_blend >= 1.0:
                    a = cb_credits[l]
                elif credit_blend <= 0.0:
                    a = dfa_credits[l]
                else:
                    # Blend RMS-normalised credits so neither source dominates by scale.
                    cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                    dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
                    a = (credit_blend * (cb_credits[l] / cb_rms)
                         + (1 - credit_blend) * (dfa_credits[l] / dfa_rms))
                credits.append(a)

            # ---- Update output head ----
            _update_head(model, head_opt, hL_det, y)

            # ---- Update blocks ----
            _update_blocks(model, block_opts, hiddens, credits)

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in all_schedulers:
            sch.step()

        _record_epoch(log, model, test_loader, device, total_loss, correct, total)
        log['value_loss'].append(total_vloss / total)
        log['term_loss'].append(total_term / total)
        log['bridge_loss'].append(total_bridge / total)
        log['tgrad_loss'].append(total_tgrad / total)

        if epoch % 20 == 0 or epoch == 1:
            phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
            print(f" [CB] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} "
                  f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
                  f"vloss={log['value_loss'][-1]:.6f}")
    return log, value_net, value_net_ema


# =============================================================================
# Diagnostics (per-layer)
# =============================================================================

def compute_diagnostics(model, method_name, test_loader, device, args,
                        value_net=None, state_predictor=None, dfa_Bs=None):
    """Compute per-layer diagnostic metrics.

    Uses only the first batch of ``test_loader``. For every layer the method's
    credit vector is compared against the true BP gradient (cosine), and
    passed to ``perturbation_correlation`` and ``nudging_test`` together with
    a function that maps a hidden state at that layer to per-sample CE loss.

    Args:
        model: Trained ``StudentNet``.
        method_name: One of ``'bp'``, ``'dfa'``, ``'state_bridge'``,
            ``'credit_bridge'``.
        test_loader: Loader yielding ``(x, y)`` batches.
        device: Device on which to run.
        args: Parsed arguments (kept for signature compatibility).
        value_net: Required for ``'credit_bridge'``.
        state_predictor: Required for ``'state_bridge'``.
        dfa_Bs: Required for ``'dfa'``.

    Returns:
        Dict with ``bp_cosine``, ``perturbation_rho`` (one entry per layer),
        ``nudging`` (keyed by step size as a string) and, for state bridge,
        ``state_pred_error_per_layer``.

    Raises:
        ValueError: On an unknown method, a missing auxiliary object for the
            chosen method, or an empty ``test_loader``.
    """
    required = {
        'bp': None,
        'dfa': ('dfa_Bs', dfa_Bs),
        'state_bridge': ('state_predictor', state_predictor),
        'credit_bridge': ('value_net', value_net),
    }
    if method_name not in required:
        raise ValueError(f"Unknown method: {method_name}")
    if required[method_name] is not None and required[method_name][1] is None:
        raise ValueError(f"method '{method_name}' requires {required[method_name][0]}")

    model.eval()
    if value_net is not None:
        value_net.eval()
    if state_predictor is not None:
        state_predictor.eval()

    L = model.num_blocks

    # Get one batch
    try:
        x, y = next(iter(test_loader))
    except StopIteration:
        raise ValueError("test_loader is empty; cannot compute diagnostics") from None
    x, y = x.to(device), y.to(device)
    batch = x.size(0)

    # BP gradients (for offline BP cosine)
    # Manual forward to get gradable hidden states
    hiddens_bp = [x.detach().requires_grad_(True)]
    for block in model.blocks:
        hiddens_bp.append(hiddens_bp[-1] + block(hiddens_bp[-1]))
    logits_bp = model.out_head(hiddens_bp[-1])
    loss_bp = F.cross_entropy(logits_bp, y)
    grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False)
    bp_grads = [g.detach().clone() for g in grads]  # L + 1 entries

    # Clean forward
    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = _terminal_error(logits, y)
    s = e_T.detach()

    results = {
        'bp_cosine': [],
        'perturbation_rho': [],
        'nudging': {'0.001': [], '0.003': [], '0.01': []},
    }

    # State bridge prediction error (if applicable)
    if method_name == 'state_bridge' and state_predictor is not None:
        hL_det = hiddens[-1].detach()
        state_pred_errors = []
        for l in range(L):
            with torch.no_grad():
                pred_hL = state_predictor(hiddens[l].detach(), _layer_time(batch, l, L, device), s)
            state_pred_errors.append(((pred_hL - hL_det) ** 2).sum(dim=-1).mean().item())
        results['state_pred_error_per_layer'] = state_pred_errors

    def make_fwd_fn(start_l):
        """Build h -> per-sample CE loss, running the model from layer ``start_l``."""
        def fwd_fn(h):
            with torch.no_grad():
                out = model.forward_from_layer(h, start_l)
                return F.cross_entropy(out, y, reduction='none')
        return fwd_fn

    for l in range(L):
        h_l = hiddens[l].detach()
        t_l = _layer_time(batch, l, L, device)

        if method_name == 'bp':
            a_l = bp_grads[l]
        elif method_name == 'dfa':
            a_l = (e_T @ dfa_Bs[l].T).detach()
        elif method_name == 'state_bridge':
            a_l = _state_bridge_credit(model, state_predictor, h_l, t_l, s, y)
        else:  # 'credit_bridge' (validated above)
            a_l = _value_credit(value_net, h_l, t_l, s)

        # BP cosine
        results['bp_cosine'].append(cosine_similarity_batch(a_l, bp_grads[l]))

        # Forward function for perturbation and nudging
        fwd_fn = make_fwd_fn(l)
        rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
        results['perturbation_rho'].append(rho)

        for eta in [0.001, 0.003, 0.01]:
            nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
            results['nudging'][str(eta)].append(nud)

    return results
# =============================================================================
# Main experiment runner
# =============================================================================

# Order in which methods are trained, summarised and printed.
METHODS = ['bp', 'dfa', 'state_bridge', 'credit_bridge']


def run_single(alpha, L, seed, args, device):
    """Run all methods for a single (alpha, L, seed) configuration.

    The teacher always uses seed 0; ``seed`` controls the sampled data
    (``seed`` for train, ``seed + 10000`` for test) and, via
    ``torch.manual_seed(seed)`` before each student is built, the student
    initialisation.

    Returns:
        Dict keyed by method name, each holding ``'log'`` and ``'diagnostics'``.
    """
    d = args.d_hidden
    C = args.num_classes

    print(f"\n === alpha={alpha}, L={L}, seed={seed} ===")
    t0 = time.time()

    # Generate data from teacher
    teacher = TeacherNet(d, L, C, alpha, seed=0).to(device)  # Fixed teacher seed=0
    X_train, Y_train = generate_dataset(teacher, args.n_train, d, device, seed=seed)
    X_test, Y_test = generate_dataset(teacher, args.n_test, d, device, seed=seed + 10000)

    train_ds = TensorDataset(X_train, Y_train)
    test_ds = TensorDataset(X_test, Y_test)
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
    test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)

    results = {}

    # ---- BP ----
    print(" --- BP ---")
    torch.manual_seed(seed)
    model_bp = StudentNet(d, C, L, alpha).to(device)
    bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
    bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
    results['bp'] = {'log': bp_log, 'diagnostics': bp_diag}

    # ---- DFA ----
    print(" --- DFA ---")
    torch.manual_seed(seed)
    model_dfa = StudentNet(d, C, L, alpha).to(device)
    dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
    dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args,
                                   dfa_Bs=dfa_Bs)
    results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag}

    # ---- State Bridge ----
    print(" --- State Bridge ---")
    torch.manual_seed(seed)
    model_sb = StudentNet(d, C, L, alpha).to(device)
    sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
    sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
                                  state_predictor=state_pred)
    results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag}

    # ---- Credit Bridge ----
    print(" --- Credit Bridge ---")
    torch.manual_seed(seed)
    model_cb = StudentNet(d, C, L, alpha).to(device)
    cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader,
                                                 device, args)
    cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
                                  value_net=vnet)
    results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag}

    elapsed = time.time() - t0
    print(f" === Done alpha={alpha}, L={L}, seed={seed} in {elapsed:.1f}s ===")

    # Summary
    for m in METHODS:
        diag = results[m]['diagnostics']
        test_acc = results[m]['log']['test_acc'][-1]
        mean_gamma = np.mean(diag['bp_cosine'])
        mean_rho = np.mean(diag['perturbation_rho'])
        mean_nudge = np.mean(diag['nudging']['0.01'])
        print(f" {m:20s}: acc={test_acc:.4f} Gamma={mean_gamma:.4f} "
              f"rho={mean_rho:.4f} nudge={mean_nudge:.6f}")

    return results


def serialize(obj):
    """Recursively convert ``obj`` into JSON-serialisable Python builtins.

    Dict keys are stringified; lists and tuples become lists; NumPy scalars
    become the matching Python scalar (bool / int / float); NumPy arrays and
    torch tensors become nested lists. Anything else is returned unchanged.
    """
    if isinstance(obj, dict):
        return {str(k): serialize(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [serialize(v) for v in obj]
    if isinstance(obj, np.bool_):
        return bool(obj)
    if isinstance(obj, np.integer):
        # Keep integers integral instead of emitting e.g. 3.0 in the JSON.
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, torch.Tensor):
        # detach() first: tensors that require grad cannot be converted directly.
        return obj.detach().cpu().tolist()
    return obj


def _summarize_result(result):
    """Condense one ``run_single`` result into per-method scalar and per-layer summaries."""
    summary = {}
    for method in METHODS:
        r = result[method]
        diag = r['diagnostics']
        entry = {
            'test_acc': r['log']['test_acc'][-1],
            'mean_bp_cosine': float(np.mean(diag['bp_cosine'])),
            'mean_rho': float(np.mean(diag['perturbation_rho'])),
            'mean_nudge_001': float(np.mean(diag['nudging']['0.001'])),
            'mean_nudge_003': float(np.mean(diag['nudging']['0.003'])),
            'mean_nudge_01': float(np.mean(diag['nudging']['0.01'])),
            'bp_cosine_per_layer': [float(v) for v in diag['bp_cosine']],
            'rho_per_layer': [float(v) for v in diag['perturbation_rho']],
            'nudge_per_layer': [float(v) for v in diag['nudging']['0.01']],
        }
        if method == 'state_bridge' and 'state_pred_error_per_layer' in diag:
            entry['state_pred_error_per_layer'] = [
                float(v) for v in diag['state_pred_error_per_layer']
            ]
            entry['mean_state_pred_error'] = float(np.mean(diag['state_pred_error_per_layer']))
        if method == 'credit_bridge':
            entry['final_value_loss'] = r['log']['value_loss'][-1]
            entry['final_term_loss'] = r['log']['term_loss'][-1]
            entry['final_bridge_loss'] = r['log']['bridge_loss'][-1]
            entry['final_tgrad_loss'] = r['log']['tgrad_loss'][-1]
        summary[method] = entry
    return summary


def main():
    """Parse arguments, run the (alpha, depth, seed) sweep and write JSON results.

    Writes one ``synth_<key>.json`` per configuration as it finishes, then
    ``summary.json`` and ``config.json``, all under ``--output_dir``.
    """
    parser = argparse.ArgumentParser(description='Synthetic Nonlinearity Ladder')
    parser.add_argument('--alphas', type=float, nargs='+', default=[0.0, 0.5, 1.0])
    parser.add_argument('--depths', type=int, nargs='+', default=[2, 8])
    parser.add_argument('--seeds', type=int, nargs='+', default=[42])
    parser.add_argument('--d_hidden', type=int, default=128)
    parser.add_argument('--num_classes', type=int, default=10)
    parser.add_argument('--n_train', type=int, default=10000)
    parser.add_argument('--n_test', type=int, default=2000)
    parser.add_argument('--batch_size', type=int, default=256)
    parser.add_argument('--epochs', type=int, default=60)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--lam', type=float, default=0.1)
    parser.add_argument('--K', type=int, default=4)
    parser.add_argument('--sigma_bridge', type=float, default=0.05)
    parser.add_argument('--ema_momentum', type=float, default=0.995)
    parser.add_argument('--term_grad_weight', type=float, default=1.0)
    parser.add_argument('--gpu', type=int, default=1)
    parser.add_argument('--output_dir', type=str, default='results/synth_ladder')
    args = parser.parse_args()

    if torch.cuda.is_available():
        gpu = args.gpu
        n_gpus = torch.cuda.device_count()
        if not 0 <= gpu < n_gpus:
            # The default index is 1, which does not exist on single-GPU hosts.
            print(f"Warning: --gpu {gpu} is not available "
                  f"({n_gpus} CUDA device(s) visible); using cuda:0")
            gpu = 0
        device = torch.device(f'cuda:{gpu}')
    else:
        device = torch.device('cpu')
    print(f"Device: {device}")
    print(f"Alphas: {args.alphas}")
    print(f"Depths: {args.depths}")
    print(f"Seeds: {args.seeds}")

    os.makedirs(args.output_dir, exist_ok=True)

    all_results = {}
    for alpha in args.alphas:
        for L in args.depths:
            for seed in args.seeds:
                key = f"a{alpha}_L{L}_s{seed}"
                result = run_single(alpha, L, seed, args, device)
                all_results[key] = result

                # Save incrementally
                out_path = os.path.join(args.output_dir, f'synth_{key}.json')
                with open(out_path, 'w') as f:
                    json.dump(serialize(result), f, indent=2)

    # Save combined summary
    summary = {key: _summarize_result(result) for key, result in all_results.items()}
    summary_path = os.path.join(args.output_dir, 'summary.json')
    with open(summary_path, 'w') as f:
        json.dump(serialize(summary), f, indent=2)
    print(f"\nSummary saved to {summary_path}")

    # Save config
    config = serialize(vars(args))
    config_path = os.path.join(args.output_dir, 'config.json')
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)

    # Print final summary table
    print("\n" + "=" * 100)
    print("SYNTHETIC NONLINEARITY LADDER - SUMMARY")
    print("=" * 100)
    print(f"{'Config':<20} {'Method':<20} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
    print("-" * 100)
    for key in sorted(summary.keys()):
        for method in METHODS:
            s = summary[key][method]
            print(f"{key:<20} {method:<20} {s['test_acc']:>8.4f} {s['mean_bp_cosine']:>8.4f} "
                  f"{s['mean_rho']:>8.4f} {s['mean_nudge_01']:>10.6f}")
        print("-" * 100)


if __name__ == '__main__':
    main()