summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 21:04:30 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-23 21:04:30 -0500
commit9940a5e2d3a45fc97eba33fd504bf7b1123a50ab (patch)
tree2084246416ea67b064f01c88cbf2f133e096f2bb /experiments
parent3012cba6032ee04cc0b82c178fbf8df8e47c7d2f (diff)
Add Phase 2 explore experiments: synthetic nonlinearity ladder + CIFAR depth scan
- synth_nonlinearity_ladder.py: teacher-student with phi_alpha(z) = (1-a)z + a*tanh(z) Sweeps alpha x depth to find where state bridge / credit bridge fail - cifar_depth_scan.py: CIFAR-10 with L={2,4,6,8,12}, d={256,512} Finds Goldilocks regime for credit bridge vs DFA - plot_synth_ladder.py: phase diagram visualization Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/cifar_depth_scan.py584
-rw-r--r--experiments/plot_synth_ladder.py449
-rw-r--r--experiments/synth_nonlinearity_ladder.py822
3 files changed, 1855 insertions, 0 deletions
diff --git a/experiments/cifar_depth_scan.py b/experiments/cifar_depth_scan.py
new file mode 100644
index 0000000..0a16201
--- /dev/null
+++ b/experiments/cifar_depth_scan.py
@@ -0,0 +1,584 @@
+"""
+Phase 2: CIFAR-10 Depth Scan.
+Find the "Goldilocks regime" where Credit Bridge outperforms DFA.
+
+Sweep: L in {2, 4, 6, 8, 12}, d in {256, 512}
+Methods: DFA (3 seeds), Credit Bridge (3 seeds), BP (1 seed as reference)
+
+Reuses training logic from cifar_resmlp.py.
+"""
+import os
+import sys
+import json
+import argparse
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test,
+ offline_bp_cosine, feature_drift
+)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader, 32 * 32 * 3, 10
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+# =============================================================================
+# Training methods (adapted from cifar_resmlp.py)
+# =============================================================================
+def train_bp(model, train_loader, test_loader, device, args):
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ scheduler.step()
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [BP] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}")
+ return log
+
+
+def train_dfa(model, train_loader, test_loader, device, args):
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ a_0_dfa = (e_T @ Bs[0].T).detach()
+ rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0_dfa / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for s in all_schedulers:
+ s.step()
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ if epoch % 10 == 0 or epoch == 1:
+ print(f" [DFA] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}")
+ return log, Bs
+
+
+def train_credit_bridge(model, train_loader, test_loader, device, args):
+ d = model.d_hidden
+ L = model.num_blocks
+ C = 10
+ warmup_epochs = max(1, args.epochs // 5)
+
+ value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd
+ )
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+ term_grad_weight = args.term_grad_weight
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss = 0
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Train value net
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+ total_vloss += value_loss.item() * batch
+
+ # Compute credits
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # Update output head
+ logits_out = model.out_head(model.out_ln(hL_det))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embedding
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_0_norm = a_0 / rms_0
+ h0 = model.embed(x)
+ embed_loss = (h0 * a_0_norm).sum(dim=-1).mean()
+ embed_opt.zero_grad()
+ embed_loss.backward()
+ embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ log['value_loss'].append(total_vloss / total)
+ if epoch % 10 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [CB] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
+ f"vloss={log['value_loss'][-1]:.6f}")
+ return log, value_net, value_net_ema
+
+
+# =============================================================================
+# Diagnostics
+# =============================================================================
+def compute_diagnostics(model, method_name, test_loader, device, args,
+ value_net=None, dfa_Bs=None):
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # BP gradients via manual forward
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {
+ 'bp_cosine': [],
+ 'perturbation_rho': [],
+ 'nudging': {'0.001': [], '0.003': [], '0.01': []},
+ }
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if method_name == 'bp':
+ a_l = bp_grads[l]
+ elif method_name == 'dfa':
+ a_l = (e_T @ dfa_Bs[l].T).detach()
+ elif method_name == 'credit_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ else:
+ raise ValueError(f"Unknown method: {method_name}")
+
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(model.out_ln(curr))
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+ results['nudging'][str(eta)].append(nud)
+
+ return results
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def serialize(obj):
+ if isinstance(obj, dict):
+ return {str(k): serialize(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [serialize(v) for v in obj]
+ elif isinstance(obj, (np.floating, np.integer)):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, torch.Tensor):
+ return obj.cpu().numpy().tolist()
+ return obj
+
+
+def run_config(d_hidden, num_blocks, seed, methods, args, device):
+ """Run specified methods for a single (d, L, seed) config."""
+ input_dim = 32 * 32 * 3
+ num_classes = 10
+ args.num_classes = num_classes
+
+ train_loader, test_loader, _, _ = get_cifar10(args.batch_size)
+
+ results = {}
+
+ for method in methods:
+ torch.manual_seed(seed)
+ np.random.seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+ model = ResidualMLP(input_dim, d_hidden, num_classes, num_blocks).to(device)
+ init_params = {n: p.clone().detach() for n, p in model.named_parameters()}
+
+ print(f"\n d={d_hidden}, L={num_blocks}, seed={seed}, method={method}")
+ t0 = time.time()
+
+ if method == 'bp':
+ log = train_bp(model, train_loader, test_loader, device, args)
+ diag = compute_diagnostics(model, 'bp', test_loader, device, args)
+ results['bp'] = {'log': log, 'diagnostics': diag}
+
+ elif method == 'dfa':
+ log, Bs = train_dfa(model, train_loader, test_loader, device, args)
+ diag = compute_diagnostics(model, 'dfa', test_loader, device, args, dfa_Bs=Bs)
+ results['dfa'] = {'log': log, 'diagnostics': diag}
+
+ elif method == 'credit_bridge':
+ log, vnet, _ = train_credit_bridge(model, train_loader, test_loader, device, args)
+ diag = compute_diagnostics(model, 'credit_bridge', test_loader, device, args,
+ value_net=vnet)
+ results['credit_bridge'] = {'log': log, 'diagnostics': diag}
+
+ drift = feature_drift(init_params, {n: p.detach() for n, p in model.named_parameters()})
+ results[method]['drift'] = drift
+
+ elapsed = time.time() - t0
+ test_acc = log['test_acc'][-1]
+ mean_gamma = np.mean(diag['bp_cosine'])
+ mean_rho = np.mean(diag['perturbation_rho'])
+ print(f" Done in {elapsed:.0f}s: acc={test_acc:.4f} Gamma={mean_gamma:.4f} rho={mean_rho:.4f}")
+
+ return results
+
+
+def main():
+ parser = argparse.ArgumentParser(description='CIFAR-10 Depth Scan')
+ parser.add_argument('--depths', type=int, nargs='+', default=[2, 4, 6, 8, 12])
+ parser.add_argument('--widths', type=int, nargs='+', default=[256, 512])
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456])
+ parser.add_argument('--bp_seeds', type=int, nargs='+', default=[42],
+ help='Seeds for BP (reference only)')
+ parser.add_argument('--methods', type=str, nargs='+',
+ default=['bp', 'dfa', 'credit_bridge'])
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/cifar_depth_scan')
+ args = parser.parse_args()
+
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Device: {device}")
+ print(f"Depths: {args.depths}, Widths: {args.widths}")
+ print(f"Seeds: {args.seeds}, BP seeds: {args.bp_seeds}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_summary = {}
+
+ for d_hidden in args.widths:
+ for num_blocks in args.depths:
+ for seed in args.seeds:
+ key = f"d{d_hidden}_L{num_blocks}_s{seed}"
+
+ # Determine which methods to run for this seed
+ methods = []
+ for m in args.methods:
+ if m == 'bp' and seed not in args.bp_seeds:
+ continue
+ methods.append(m)
+
+ if not methods:
+ continue
+
+ print(f"\n{'='*60}")
+ print(f"Config: {key}, methods: {methods}")
+ print(f"{'='*60}")
+
+ result = run_config(d_hidden, num_blocks, seed, methods, args, device)
+
+ # Save per-config result
+ out_path = os.path.join(args.output_dir, f'{key}.json')
+ with open(out_path, 'w') as f:
+ json.dump(serialize(result), f, indent=2)
+
+ # Summary
+ summary = {}
+ for method in result:
+ diag = result[method]['diagnostics']
+ summary[method] = {
+ 'test_acc': result[method]['log']['test_acc'][-1],
+ 'mean_bp_cosine': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge_01': float(np.mean(diag['nudging']['0.01'])),
+ 'bp_cosine_per_layer': [float(x) for x in diag['bp_cosine']],
+ 'rho_per_layer': [float(x) for x in diag['perturbation_rho']],
+ }
+ all_summary[key] = summary
+
+ # Save full summary
+ summary_path = os.path.join(args.output_dir, 'summary.json')
+ with open(summary_path, 'w') as f:
+ json.dump(all_summary, f, indent=2)
+ print(f"\nSummary saved to {summary_path}")
+
+ # Print table
+ print("\n" + "=" * 90)
+ print("CIFAR-10 DEPTH SCAN SUMMARY")
+ print("=" * 90)
+ print(f"{'Config':<25} {'Method':<18} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
+ print("-" * 90)
+ for key in sorted(all_summary.keys()):
+ for method in sorted(all_summary[key].keys()):
+ s = all_summary[key][method]
+ print(f"{key:<25} {method:<18} {s['test_acc']:>8.4f} {s['mean_bp_cosine']:>8.4f} "
+ f"{s['mean_rho']:>8.4f} {s['mean_nudge_01']:>10.6f}")
+ print()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/plot_synth_ladder.py b/experiments/plot_synth_ladder.py
new file mode 100644
index 0000000..a3fb0b4
--- /dev/null
+++ b/experiments/plot_synth_ladder.py
@@ -0,0 +1,449 @@
+"""Generate phase diagram plots from synthetic nonlinearity ladder results."""
+import os
+import sys
+import json
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from collections import defaultdict
+
+output_dir = 'report_explore'
+os.makedirs(output_dir, exist_ok=True)
+
+
+def load_summaries(result_dirs):
+ """Load and merge summary.json files from multiple result directories."""
+ merged = {}
+ for d in result_dirs:
+ path = os.path.join(d, 'summary.json')
+ if os.path.exists(path):
+ with open(path) as f:
+ data = json.load(f)
+ merged.update(data)
+ return merged
+
+
+def parse_key(key):
+ """Parse 'a0.5_L8_s42' -> (alpha, L, seed)."""
+ parts = key.split('_')
+ alpha = float(parts[0][1:])
+ L = int(parts[1][1:])
+ seed = int(parts[2][1:])
+ return alpha, L, seed
+
+
+def aggregate_seeds(summary):
+ """Aggregate over seeds: group by (alpha, L)."""
+ groups = defaultdict(list)
+ for key, data in summary.items():
+ alpha, L, seed = parse_key(key)
+ groups[(alpha, L)].append(data)
+
+ agg = {}
+ for (alpha, L), entries in groups.items():
+ agg[(alpha, L)] = {}
+ for method in ['bp', 'dfa', 'state_bridge', 'credit_bridge']:
+ vals = {}
+ for metric in ['test_acc', 'mean_bp_cosine', 'mean_rho', 'mean_nudge_01',
+ 'mean_nudge_001', 'mean_nudge_003']:
+ arr = [e[method].get(metric, 0) for e in entries if method in e]
+ # Filter out blown-up values
+ arr = [v for v in arr if abs(v) < 1e6]
+ if arr:
+ vals[f'{metric}_mean'] = np.mean(arr)
+ vals[f'{metric}_std'] = np.std(arr)
+ else:
+ vals[f'{metric}_mean'] = np.nan
+ vals[f'{metric}_std'] = np.nan
+
+ # State prediction error
+ if method == 'state_bridge':
+ arr = [e[method].get('mean_state_pred_error', 0) for e in entries if method in e]
+ arr = [v for v in arr if abs(v) < 1e6]
+ if arr:
+ vals['state_pred_error_mean'] = np.mean(arr)
+ vals['state_pred_error_std'] = np.std(arr)
+ else:
+ vals['state_pred_error_mean'] = np.nan
+ vals['state_pred_error_std'] = np.nan
+
+ # Per-layer diagnostics (average over seeds)
+ for pl_metric in ['bp_cosine_per_layer', 'rho_per_layer', 'nudge_per_layer']:
+ arrays = []
+ for e in entries:
+ if method in e and pl_metric in e[method]:
+ arr = e[method][pl_metric]
+ # Check for blowup
+ if all(abs(v) < 1e6 for v in arr):
+ arrays.append(arr)
+ if arrays:
+ arr2d = np.array(arrays)
+ vals[f'{pl_metric}_mean'] = arr2d.mean(axis=0).tolist()
+ vals[f'{pl_metric}_std'] = arr2d.std(axis=0).tolist()
+
+ agg[(alpha, L)][method] = vals
+ return agg
+
+
+def plot_phase_diagrams(agg, alphas, depths, save_prefix='synth'):
+ """Generate the 4 key phase diagram plots."""
+ methods = ['dfa', 'state_bridge', 'credit_bridge']
+ colors = {'bp': '#F44336', 'dfa': '#2196F3', 'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'}
+ labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'}
+
+ alphas = sorted(alphas)
+ depths = sorted(depths)
+
+ # ========================================================================
+ # Plot 1: Phase diagram heatmaps (Gamma, rho, nudge) for each method
+ # ========================================================================
+ fig, axes = plt.subplots(3, 4, figsize=(20, 12))
+ metrics_info = [
+ ('mean_bp_cosine_mean', 'BP Cosine (Γ)', 'RdYlGn', -0.1, 1.0),
+ ('mean_rho_mean', 'Perturbation ρ', 'RdYlGn', -0.2, 1.0),
+ ('mean_nudge_01_mean', 'Nudge (η=0.01)', 'RdYlGn_r', None, None),
+ ]
+ all_methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge']
+
+ for row, (metric_key, metric_name, cmap, vmin, vmax) in enumerate(metrics_info):
+ for col, method in enumerate(all_methods):
+ ax = axes[row, col]
+ grid = np.full((len(alphas), len(depths)), np.nan)
+ for i, alpha in enumerate(alphas):
+ for j, L in enumerate(depths):
+ if (alpha, L) in agg and method in agg[(alpha, L)]:
+ val = agg[(alpha, L)][method].get(metric_key, np.nan)
+ grid[i, j] = val
+
+ if vmin is None:
+ valid = grid[~np.isnan(grid)]
+ if len(valid) > 0:
+ vmin_use = np.nanmin(grid)
+ vmax_use = min(0, np.nanmax(grid)) # nudge should be <=0
+ else:
+ vmin_use, vmax_use = -1, 0
+ else:
+ vmin_use, vmax_use = vmin, vmax
+
+ im = ax.imshow(grid, cmap=cmap, vmin=vmin_use, vmax=vmax_use, aspect='auto',
+ origin='lower')
+ ax.set_xticks(range(len(depths)))
+ ax.set_xticklabels([str(d) for d in depths])
+ ax.set_yticks(range(len(alphas)))
+ ax.set_yticklabels([str(a) for a in alphas])
+
+ # Annotate cells
+ for i in range(len(alphas)):
+ for j in range(len(depths)):
+ val = grid[i, j]
+ if not np.isnan(val):
+ txt = f'{val:.3f}' if abs(val) < 100 else f'{val:.1e}'
+ ax.text(j, i, txt, ha='center', va='center', fontsize=8,
+ color='black' if abs(val) < 0.5 else 'white')
+ else:
+ ax.text(j, i, 'X', ha='center', va='center', fontsize=10,
+ color='red', fontweight='bold')
+
+ if row == 0:
+ ax.set_title(labels[method], fontsize=12)
+ if col == 0:
+ ax.set_ylabel(f'{metric_name}\nα (nonlinearity)', fontsize=10)
+ if row == len(metrics_info) - 1:
+ ax.set_xlabel('Depth (L)', fontsize=10)
+ plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04)
+
+ fig.suptitle('Synthetic Ladder: Phase Diagrams (X = blown up)', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, f'{save_prefix}_phase_heatmaps.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print(f"Saved {save_prefix}_phase_heatmaps.png")
+
+ # ========================================================================
+ # Plot 2: Line plots - metric vs alpha for each depth
+ # ========================================================================
+ fig, axes = plt.subplots(2, max(len(depths), 1), figsize=(6 * len(depths), 10), squeeze=False)
+
+ for j, L in enumerate(depths):
+ # Row 0: Gamma
+ ax = axes[0, j]
+ for method in all_methods:
+ vals = []
+ errs = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg and method in agg[(alpha, L)]:
+ v = agg[(alpha, L)][method].get('mean_bp_cosine_mean', np.nan)
+ e = agg[(alpha, L)][method].get('mean_bp_cosine_std', 0)
+ if not np.isnan(v):
+ vals.append(v)
+ errs.append(e)
+ valid_alphas.append(alpha)
+ if vals:
+ ax.errorbar(valid_alphas, vals, yerr=errs, marker='o', color=colors[method],
+ label=labels[method], capsize=3, linewidth=2, markersize=6)
+ ax.set_xlabel('α (nonlinearity)', fontsize=11)
+ ax.set_ylabel('Mean BP Cosine (Γ)', fontsize=11)
+ ax.set_title(f'L={L}', fontsize=13)
+ ax.legend(fontsize=9)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.1, 1.05)
+
+ # Row 1: rho
+ ax = axes[1, j]
+ for method in all_methods:
+ vals = []
+ errs = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg and method in agg[(alpha, L)]:
+ v = agg[(alpha, L)][method].get('mean_rho_mean', np.nan)
+ e = agg[(alpha, L)][method].get('mean_rho_std', 0)
+ if not np.isnan(v):
+ vals.append(v)
+ errs.append(e)
+ valid_alphas.append(alpha)
+ if vals:
+ ax.errorbar(valid_alphas, vals, yerr=errs, marker='s', color=colors[method],
+ label=labels[method], capsize=3, linewidth=2, markersize=6)
+ ax.set_xlabel('α (nonlinearity)', fontsize=11)
+ ax.set_ylabel('Mean Perturbation ρ', fontsize=11)
+ ax.set_title(f'L={L}', fontsize=13)
+ ax.legend(fontsize=9)
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ fig.suptitle('Synthetic Ladder: Credit Quality vs Nonlinearity', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, f'{save_prefix}_gamma_rho_vs_alpha.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print(f"Saved {save_prefix}_gamma_rho_vs_alpha.png")
+
+ # ========================================================================
+ # Plot 3: State bridge prediction error vs credit quality
+ # ========================================================================
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+ # Left: State prediction error vs alpha
+ ax = axes[0]
+ for j, L in enumerate(depths):
+ vals = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg and 'state_bridge' in agg[(alpha, L)]:
+ v = agg[(alpha, L)]['state_bridge'].get('state_pred_error_mean', np.nan)
+ if not np.isnan(v):
+ vals.append(v)
+ valid_alphas.append(alpha)
+ if vals:
+ ax.plot(valid_alphas, vals, 'o-', label=f'L={L}', markersize=6, linewidth=2)
+ ax.set_xlabel('α (nonlinearity)', fontsize=11)
+ ax.set_ylabel('State Prediction Error', fontsize=11)
+ ax.set_title('State Bridge: Terminal Prediction Error', fontsize=12)
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_yscale('log')
+
+ # Right: State bridge Gamma vs credit bridge Gamma
+ ax = axes[1]
+ for L in depths:
+ sb_gammas = []
+ cb_gammas = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg:
+ sb_g = agg[(alpha, L)].get('state_bridge', {}).get('mean_bp_cosine_mean', np.nan)
+ cb_g = agg[(alpha, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan)
+ if not np.isnan(sb_g) and not np.isnan(cb_g):
+ sb_gammas.append(sb_g)
+ cb_gammas.append(cb_g)
+ valid_alphas.append(alpha)
+ if sb_gammas:
+ ax.scatter(sb_gammas, cb_gammas, s=80, label=f'L={L}', zorder=5)
+ for a, sx, sy in zip(valid_alphas, sb_gammas, cb_gammas):
+ ax.annotate(f'α={a}', (sx, sy), textcoords="offset points",
+ xytext=(5, 5), fontsize=8)
+ ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x')
+ ax.set_xlabel('State Bridge Γ', fontsize=11)
+ ax.set_ylabel('Credit Bridge Γ', fontsize=11)
+ ax.set_title('State Bridge vs Credit Bridge BP Cosine', fontsize=12)
+ ax.legend()
+ ax.grid(True, alpha=0.3)
+ ax.set_xlim(-0.1, 1.0)
+ ax.set_ylim(-0.1, 1.0)
+
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, f'{save_prefix}_state_vs_credit.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print(f"Saved {save_prefix}_state_vs_credit.png")
+
+ # ========================================================================
+ # Plot 4: Nudging and accuracy comparison
+ # ========================================================================
+ fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+
+ # Left: Nudge vs alpha for each depth
+ ax = axes[0]
+ for method in ['dfa', 'state_bridge', 'credit_bridge']:
+ for L in depths:
+ vals = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg and method in agg[(alpha, L)]:
+ v = agg[(alpha, L)][method].get('mean_nudge_01_mean', np.nan)
+ if not np.isnan(v):
+ vals.append(v)
+ valid_alphas.append(alpha)
+ if vals:
+ ls = '-' if L == min(depths) else '--'
+ ax.plot(valid_alphas, vals, f'o{ls}', color=colors[method],
+ label=f'{labels[method]} L={L}', markersize=5)
+ ax.set_xlabel('α (nonlinearity)', fontsize=11)
+ ax.set_ylabel('Mean Nudge (η=0.01, negative=good)', fontsize=11)
+ ax.set_title('Nudging Test', fontsize=12)
+ ax.legend(fontsize=8, ncol=2)
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ # Right: Test accuracy
+ ax = axes[1]
+ for method in all_methods:
+ for L in depths:
+ vals = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg and method in agg[(alpha, L)]:
+ v = agg[(alpha, L)][method].get('test_acc_mean', np.nan)
+ if not np.isnan(v):
+ vals.append(v)
+ valid_alphas.append(alpha)
+ if vals:
+ ls = '-' if L == min(depths) else '--'
+ ax.plot(valid_alphas, vals, f'o{ls}', color=colors[method],
+ label=f'{labels[method]} L={L}', markersize=5)
+ ax.set_xlabel('α (nonlinearity)', fontsize=11)
+ ax.set_ylabel('Test Accuracy', fontsize=11)
+ ax.set_title('Test Accuracy', fontsize=12)
+ ax.legend(fontsize=8, ncol=2)
+ ax.grid(True, alpha=0.3)
+
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, f'{save_prefix}_nudge_acc.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print(f"Saved {save_prefix}_nudge_acc.png")
+
+ # ========================================================================
+ # Plot 5: Per-layer diagnostics for selected (alpha, L) combos
+ # ========================================================================
+ combos = [(a, L) for a in alphas for L in depths if (a, L) in agg]
+ # Only plot non-blown-up combos
+ valid_combos = []
+ for a, L in combos:
+ if not np.isnan(agg[(a, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan)):
+ valid_combos.append((a, L))
+
+ if valid_combos:
+ n_combos = len(valid_combos)
+ fig, axes = plt.subplots(2, n_combos, figsize=(5 * n_combos, 8), squeeze=False)
+
+ for idx, (alpha, L) in enumerate(valid_combos):
+ # Row 0: BP cosine per layer
+ ax = axes[0, idx]
+ for method in all_methods:
+ if method in agg[(alpha, L)]:
+ pl = agg[(alpha, L)][method].get('bp_cosine_per_layer_mean', None)
+ if pl is not None:
+ ax.plot(range(len(pl)), pl, 'o-', color=colors[method],
+ label=labels[method], markersize=4)
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('BP Cosine')
+ ax.set_title(f'α={alpha}, L={L}', fontsize=11)
+ ax.legend(fontsize=8)
+ ax.grid(True, alpha=0.3)
+ ax.set_ylim(-0.3, 1.05)
+
+ # Row 1: rho per layer
+ ax = axes[1, idx]
+ for method in all_methods:
+ if method in agg[(alpha, L)]:
+ pl = agg[(alpha, L)][method].get('rho_per_layer_mean', None)
+ if pl is not None:
+ ax.plot(range(len(pl)), pl, 's-', color=colors[method],
+ label=labels[method], markersize=4)
+ ax.set_xlabel('Layer')
+ ax.set_ylabel('Perturbation ρ')
+ ax.set_title(f'α={alpha}, L={L}', fontsize=11)
+ ax.legend(fontsize=8)
+ ax.grid(True, alpha=0.3)
+ ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5)
+
+ fig.suptitle('Per-Layer Diagnostics', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, f'{save_prefix}_per_layer.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print(f"Saved {save_prefix}_per_layer.png")
+
+ # ========================================================================
+ # Plot 6: Advantage plots: CB - DFA
+ # ========================================================================
+ fig, axes = plt.subplots(1, 3, figsize=(18, 5))
+ metrics = [
+ ('mean_bp_cosine_mean', 'Γ(CB) - Γ(DFA)'),
+ ('mean_rho_mean', 'ρ(CB) - ρ(DFA)'),
+ ('mean_nudge_01_mean', 'nudge(CB) - nudge(DFA)'),
+ ]
+
+ for ax, (metric_key, ylabel) in zip(axes, metrics):
+ for L in depths:
+ diffs = []
+ valid_alphas = []
+ for alpha in alphas:
+ if (alpha, L) in agg:
+ cb_v = agg[(alpha, L)].get('credit_bridge', {}).get(metric_key, np.nan)
+ dfa_v = agg[(alpha, L)].get('dfa', {}).get(metric_key, np.nan)
+ if not np.isnan(cb_v) and not np.isnan(dfa_v):
+ diffs.append(cb_v - dfa_v)
+ valid_alphas.append(alpha)
+ if diffs:
+ ax.plot(valid_alphas, diffs, 'o-', label=f'L={L}', markersize=6, linewidth=2)
+ ax.set_xlabel('α (nonlinearity)', fontsize=11)
+ ax.set_ylabel(ylabel, fontsize=11)
+ ax.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='zero (parity)')
+ ax.legend(fontsize=10)
+ ax.grid(True, alpha=0.3)
+
+ fig.suptitle('Credit Bridge Advantage over DFA', fontsize=14, y=1.02)
+ fig.tight_layout()
+ fig.savefig(os.path.join(output_dir, f'{save_prefix}_cb_advantage.png'), dpi=150, bbox_inches='tight')
+ plt.close(fig)
+ print(f"Saved {save_prefix}_cb_advantage.png")
+
+
+if __name__ == '__main__':
+ import sys
+ result_dirs = sys.argv[1:] if len(sys.argv) > 1 else ['results/synth_ladder_smoke']
+
+ print(f"Loading from: {result_dirs}")
+ summary = load_summaries(result_dirs)
+
+ if not summary:
+ print("No results found!")
+ sys.exit(1)
+
+ # Extract alphas and depths from keys
+ alphas = set()
+ depths = set()
+ for key in summary:
+ alpha, L, seed = parse_key(key)
+ alphas.add(alpha)
+ depths.add(L)
+
+ alphas = sorted(alphas)
+ depths = sorted(depths)
+ print(f"Alphas: {alphas}")
+ print(f"Depths: {depths}")
+ print(f"Total configs: {len(summary)}")
+
+ agg = aggregate_seeds(summary)
+ plot_phase_diagrams(agg, alphas, depths)
+ print("\nAll plots generated!")
diff --git a/experiments/synth_nonlinearity_ladder.py b/experiments/synth_nonlinearity_ladder.py
new file mode 100644
index 0000000..d5ed9aa
--- /dev/null
+++ b/experiments/synth_nonlinearity_ladder.py
@@ -0,0 +1,822 @@
+"""
+Synthetic Nonlinearity Ladder: teacher-student classification with controllable nonlinearity.
+
+Teacher dynamics:
+ h_{l+1}^* = h_l^* + W_l^* phi_alpha(h_l^*)
+ phi_alpha(z) = (1-alpha)*z + alpha*tanh(z)
+
+alpha=0 -> linear, alpha=1 -> fully nonlinear.
+Sweep (alpha, L) to find where state bridge fails and credit bridge degrades.
+
+Methods: BP, DFA, State Bridge, Credit Bridge (all reuse project conventions).
+"""
+import os
+import sys
+import json
+import argparse
+import time
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP, ResidualBlock
+from models.value_net import ValueNet, create_ema_model, update_ema
+from models.state_bridge import StateBridgeNet
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test,
+ offline_bp_cosine, feature_drift
+)
+
+
+# =============================================================================
+# Teacher network and data generation
+# =============================================================================
+class TeacherNet:
+ """Fixed teacher network with controllable nonlinearity."""
+
+ def __init__(self, d_hidden, num_blocks, num_classes, alpha, seed=0):
+ rng = np.random.RandomState(seed)
+ self.d_hidden = d_hidden
+ self.num_blocks = num_blocks
+ self.num_classes = num_classes
+ self.alpha = alpha
+
+ # Teacher block weights: scaled for stability
+ self.Ws = []
+ for l in range(num_blocks):
+ W = rng.randn(d_hidden, d_hidden).astype(np.float32)
+ # Scale so spectral norm of residual part is small
+ W = W / (np.linalg.norm(W, ord=2) + 1e-8) * 0.3
+ self.Ws.append(torch.from_numpy(W))
+
+ # Output projection
+ U = rng.randn(num_classes, d_hidden).astype(np.float32)
+ U = U / (np.linalg.norm(U, ord=2) + 1e-8)
+ self.U = torch.from_numpy(U)
+
+ def to(self, device):
+ self.Ws = [W.to(device) for W in self.Ws]
+ self.U = self.U.to(device)
+ return self
+
+ def phi(self, z):
+ """Controllable activation: (1-alpha)*z + alpha*tanh(z)."""
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, h0):
+ """Forward pass through teacher, returns logits and hidden states."""
+ h = h0
+ hiddens = [h]
+ for l in range(self.num_blocks):
+ f = F.linear(self.phi(h), self.Ws[l])
+ h = h + f
+ hiddens.append(h)
+ logits = F.linear(h, self.U)
+ return logits, hiddens
+
+
+def generate_dataset(teacher, num_samples, d_hidden, device, seed=0):
+ """Generate classification dataset from teacher network."""
+ torch.manual_seed(seed)
+ X = torch.randn(num_samples, d_hidden, device=device)
+ with torch.no_grad():
+ logits, _ = teacher.forward(X)
+ Y = logits.argmax(dim=-1)
+ return X, Y
+
+
+# =============================================================================
+# Student network (same architecture family as teacher)
+# =============================================================================
+class StudentBlock(nn.Module):
+ """Residual block with pre-LayerNorm + phi_alpha."""
+
+ def __init__(self, d_hidden, alpha):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.w = nn.Linear(d_hidden, d_hidden, bias=False)
+ self.alpha = alpha
+ # Small init for stability
+ nn.init.normal_(self.w.weight, std=0.01)
+
+ def phi(self, z):
+ return (1 - self.alpha) * z + self.alpha * torch.tanh(z)
+
+ def forward(self, h):
+ """Returns residual F_l(h), NOT h + F_l(h)."""
+ return self.w(self.phi(self.ln(h)))
+
+
+class StudentNet(nn.Module):
+ """Student network: L residual blocks + linear output head."""
+
+ def __init__(self, d_hidden, num_classes, num_blocks, alpha):
+ super().__init__()
+ self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)])
+ self.out_head = nn.Linear(d_hidden, num_classes)
+ self.num_blocks = num_blocks
+ self.d_hidden = d_hidden
+
+ def forward(self, x, return_hidden=False):
+ h = x # No embedding needed, input is already d_hidden
+ hiddens = [h] if return_hidden else None
+ for block in self.blocks:
+ f = block(h)
+ h = h + f
+ if return_hidden:
+ hiddens.append(h)
+ logits = self.out_head(h)
+ if return_hidden:
+ return logits, hiddens
+ return logits
+
+ def forward_from_layer(self, h, start_layer):
+ """Run forward from a given layer to output."""
+ for i in range(start_layer, self.num_blocks):
+ f = self.blocks[i](h)
+ h = h + f
+ return self.out_head(h)
+
+
+# =============================================================================
+# Training methods
+# =============================================================================
+def evaluate(model, test_loader, device):
+ model.eval()
+ correct, total = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x, y = x.to(device), y.to(device)
+ logits = model(x)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return correct / total
+
+
+def train_bp(model, train_loader, test_loader, device, args):
+ """Standard BP training."""
+ optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs)
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ total_loss += loss.item() * x.size(0)
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ scheduler.step()
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [BP] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}")
+ return log
+
+
+def train_dfa(model, train_loader, test_loader, device, args):
+ """DFA training with fixed random feedback matrices."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = args.num_classes
+
+ Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+
+ # Update output head (h_L detached)
+ hL_det = hiddens[-1].detach()
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update each block with DFA local surrogate
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_dfa = (e_T @ Bs[l].T).detach()
+ rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a_dfa / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for s in all_schedulers:
+ s.step()
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [DFA] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}")
+ return log, Bs
+
+
+def train_state_bridge(model, train_loader, test_loader, device, args):
+ """State Bridge training."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = args.num_classes
+
+ state_pred = StateBridgeNet(d_hidden=d, s_dim=C, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
+ state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ state_pred.train()
+ total_loss, correct, total = 0, 0, 0
+ total_se = 0
+
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # Train state predictor
+ state_loss = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ target = hL_det
+ target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0)
+ state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean()
+ state_loss = state_loss / L
+ state_opt.zero_grad()
+ state_loss.backward()
+ state_opt.step()
+ total_se += state_loss.item() * batch
+
+ # Compute credits
+ credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ pred_hL = state_pred(h_l_det, t_l, s)
+ pred_logits = model.out_head(pred_hL)
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0]
+ credits.append(a_l.detach())
+
+ # Update output head
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ log['state_pred_error'].append(total_se / total)
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [SB] Ep {epoch}: loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
+ f"se={log['state_pred_error'][-1]:.6f}")
+ return log, state_pred
+
+
+def train_credit_bridge(model, train_loader, test_loader, device, args):
+ """Credit Bridge training with terminal gradient matching + bridge consistency."""
+ d = model.d_hidden
+ L = model.num_blocks
+ C = args.num_classes
+ warmup_epochs = max(1, args.epochs // 5)
+
+ value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+
+ Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)]
+
+ block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd)
+ for block in model.blocks]
+ head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd)
+ value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb)
+
+ all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts]
+ + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)])
+
+ lam = args.lam
+ K_samples = args.K
+ sigma_bridge = args.sigma_bridge
+ ema_momentum = args.ema_momentum
+ term_grad_weight = args.term_grad_weight
+
+ log = {'train_loss': [], 'train_acc': [], 'test_acc': [],
+ 'value_loss': [], 'term_loss': [], 'bridge_loss': [], 'tgrad_loss': []}
+
+ for epoch in range(1, args.epochs + 1):
+ model.train()
+ value_net.train()
+ total_loss, correct, total = 0, 0, 0
+ total_vloss, total_term, total_bridge, total_tgrad = 0, 0, 0, 0
+
+ if epoch <= warmup_epochs:
+ credit_blend = 0.0
+ else:
+ credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs))
+
+ for x, y in train_loader:
+ x, y = x.to(device), y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+
+ hL_det = hiddens[-1].detach()
+
+ # ---- Train value net ----
+ t_L = torch.ones(batch, device=device)
+ V_terminal = value_net(hL_det, t_L, s)
+ loss_term = ((V_terminal - true_loss) ** 2).mean()
+
+ loss_tgrad = torch.tensor(0.0, device=device)
+ if term_grad_weight > 0:
+ hL_req = hL_det.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL_det.clone().requires_grad_(True)
+ logits_tgt = model.out_head(hL_req2)
+ ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_l_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+
+ with torch.no_grad():
+ h_next_det = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(K_samples):
+ noise = sigma_bridge * torch.randn_like(h_next_det)
+ V_next = value_net_ema(h_next_det + noise, t_l_next, s)
+ log_terms.append(-V_next / lam)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples))
+
+ loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge = loss_bridge / L
+
+ value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad
+ value_opt.zero_grad()
+ value_loss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, ema_momentum)
+
+ total_vloss += value_loss.item() * batch
+ total_term += loss_term.item() * batch
+ total_bridge += (loss_bridge.item() if isinstance(loss_bridge, torch.Tensor) else loss_bridge) * batch
+ total_tgrad += loss_tgrad.item() * batch
+
+ # ---- Compute credits ----
+ cb_credits = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V_l = value_net(h_l_det, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0]
+ cb_credits.append(a_l.detach())
+
+ dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if credit_blend >= 1.0:
+ a = cb_credits[l]
+ elif credit_blend <= 0.0:
+ a = dfa_credits[l]
+ else:
+ cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms)
+ credits.append(a)
+
+ # ---- Update output head ----
+ logits_out = model.out_head(hL_det)
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad()
+ loss_out.backward()
+ head_opt.step()
+
+ # ---- Update blocks ----
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_opts[l].zero_grad()
+ local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in all_schedulers:
+ sch.step()
+
+ log['train_loss'].append(total_loss / total)
+ log['train_acc'].append(correct / total)
+ log['test_acc'].append(evaluate(model, test_loader, device))
+ log['value_loss'].append(total_vloss / total)
+ log['term_loss'].append(total_term / total)
+ log['bridge_loss'].append(total_bridge / total)
+ log['tgrad_loss'].append(total_tgrad / total)
+ if epoch % 20 == 0 or epoch == 1:
+ phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}"
+ print(f" [CB] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} "
+ f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} "
+ f"vloss={log['value_loss'][-1]:.6f}")
+ return log, value_net, value_net_ema
+
+
+# =============================================================================
+# Diagnostics (per-layer)
+# =============================================================================
+def compute_diagnostics(model, method_name, test_loader, device, args,
+ value_net=None, state_predictor=None, dfa_Bs=None):
+ """Compute per-layer diagnostic metrics."""
+ model.eval()
+ if value_net is not None:
+ value_net.eval()
+ if state_predictor is not None:
+ state_predictor.eval()
+
+ d = model.d_hidden
+ L = model.num_blocks
+ C = args.num_classes
+
+ # Get one batch
+ for x, y in test_loader:
+ x, y = x.to(device), y.to(device)
+ break
+
+ batch = x.size(0)
+
+ # BP gradients (for offline BP cosine)
+ # Manual forward to get gradable hidden states
+ h = x.detach().requires_grad_(True)
+ hiddens_bp = [h]
+ for block in model.blocks:
+ f = block(hiddens_bp[-1])
+ h_next = hiddens_bp[-1] + f
+ hiddens_bp.append(h_next)
+ logits_bp = model.out_head(hiddens_bp[-1])
+ loss_bp = F.cross_entropy(logits_bp, y)
+ bp_grads = {}
+ grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False)
+ for l in range(L + 1):
+ bp_grads[l] = grads[l].detach().clone()
+
+ # Clean forward
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ results = {
+ 'bp_cosine': [],
+ 'perturbation_rho': [],
+ 'nudging': {'0.001': [], '0.003': [], '0.01': []},
+ }
+
+ # State bridge prediction error (if applicable)
+ if method_name == 'state_bridge' and state_predictor is not None:
+ state_pred_errors = []
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ with torch.no_grad():
+ pred_hL = state_predictor(h_l_det, t_l, s)
+ hL_det = hiddens[-1].detach()
+ err = ((pred_hL - hL_det) ** 2).sum(dim=-1).mean().item()
+ state_pred_errors.append(err)
+ results['state_pred_error_per_layer'] = state_pred_errors
+
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if method_name == 'bp':
+ a_l = bp_grads[l]
+ elif method_name == 'dfa':
+ a_l = (e_T @ dfa_Bs[l].T).detach()
+ elif method_name == 'state_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ pred_hL = state_predictor(h_l_req, t_l, s)
+ pred_logits = model.out_head(pred_hL)
+ pred_loss = F.cross_entropy(pred_logits, y, reduction='sum')
+ a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach()
+ elif method_name == 'credit_bridge':
+ h_l_req = h_l.clone().requires_grad_(True)
+ V_l = value_net(h_l_req, t_l, s)
+ a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach()
+ else:
+ raise ValueError(f"Unknown method: {method_name}")
+
+ # BP cosine
+ bp_cos = cosine_similarity_batch(a_l, bp_grads[l])
+ results['bp_cosine'].append(bp_cos)
+
+ # Forward function for perturbation and nudging
+ def make_fwd_fn(start_l):
+ def fwd_fn(h):
+ with torch.no_grad():
+ curr = h
+ for i in range(start_l, L):
+ curr = curr + model.blocks[i](curr)
+ out = model.out_head(curr)
+ return F.cross_entropy(out, y, reduction='none')
+ return fwd_fn
+
+ fwd_fn = make_fwd_fn(l)
+ rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16)
+ results['perturbation_rho'].append(rho)
+
+ for eta in [0.001, 0.003, 0.01]:
+ nud = nudging_test(h_l, a_l, fwd_fn, eta=eta)
+ results['nudging'][str(eta)].append(nud)
+
+ return results
+
+
+# =============================================================================
+# Main experiment runner
+# =============================================================================
+def run_single(alpha, L, seed, args, device):
+ """Run all methods for a single (alpha, L, seed) configuration."""
+ d = args.d_hidden
+ C = args.num_classes
+
+ print(f"\n === alpha={alpha}, L={L}, seed={seed} ===")
+ t0 = time.time()
+
+ # Generate data from teacher
+ teacher = TeacherNet(d, L, C, alpha, seed=0).to(device) # Fixed teacher seed=0
+ X_train, Y_train = generate_dataset(teacher, args.n_train, d, device, seed=seed)
+ X_test, Y_test = generate_dataset(teacher, args.n_test, d, device, seed=seed + 10000)
+
+ train_ds = TensorDataset(X_train, Y_train)
+ test_ds = TensorDataset(X_test, Y_test)
+ train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True)
+ test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False)
+
+ results = {}
+
+ # ---- BP ----
+ print(" --- BP ---")
+ torch.manual_seed(seed)
+ model_bp = StudentNet(d, C, L, alpha).to(device)
+ bp_log = train_bp(model_bp, train_loader, test_loader, device, args)
+ bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args)
+ results['bp'] = {'log': bp_log, 'diagnostics': bp_diag}
+
+ # ---- DFA ----
+ print(" --- DFA ---")
+ torch.manual_seed(seed)
+ model_dfa = StudentNet(d, C, L, alpha).to(device)
+ dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args)
+ dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs)
+ results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag}
+
+ # ---- State Bridge ----
+ print(" --- State Bridge ---")
+ torch.manual_seed(seed)
+ model_sb = StudentNet(d, C, L, alpha).to(device)
+ sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args)
+ sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args,
+ state_predictor=state_pred)
+ results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag}
+
+ # ---- Credit Bridge ----
+ print(" --- Credit Bridge ---")
+ torch.manual_seed(seed)
+ model_cb = StudentNet(d, C, L, alpha).to(device)
+ cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args)
+ cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args,
+ value_net=vnet)
+ results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag}
+
+ elapsed = time.time() - t0
+ print(f" === Done alpha={alpha}, L={L}, seed={seed} in {elapsed:.1f}s ===")
+
+ # Summary
+ for m in ['bp', 'dfa', 'state_bridge', 'credit_bridge']:
+ test_acc = results[m]['log']['test_acc'][-1]
+ mean_gamma = np.mean(results[m]['diagnostics']['bp_cosine'])
+ mean_rho = np.mean(results[m]['diagnostics']['perturbation_rho'])
+ mean_nudge = np.mean(results[m]['diagnostics']['nudging']['0.01'])
+ print(f" {m:20s}: acc={test_acc:.4f} Gamma={mean_gamma:.4f} "
+ f"rho={mean_rho:.4f} nudge={mean_nudge:.6f}")
+
+ return results
+
+
+def serialize(obj):
+ if isinstance(obj, dict):
+ return {str(k): serialize(v) for k, v in obj.items()}
+ elif isinstance(obj, list):
+ return [serialize(v) for v in obj]
+ elif isinstance(obj, (np.floating, np.integer)):
+ return float(obj)
+ elif isinstance(obj, np.ndarray):
+ return obj.tolist()
+ elif isinstance(obj, torch.Tensor):
+ return obj.cpu().numpy().tolist()
+ return obj
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Synthetic Nonlinearity Ladder')
+ parser.add_argument('--alphas', type=float, nargs='+', default=[0.0, 0.5, 1.0])
+ parser.add_argument('--depths', type=int, nargs='+', default=[2, 8])
+ parser.add_argument('--seeds', type=int, nargs='+', default=[42])
+ parser.add_argument('--d_hidden', type=int, default=128)
+ parser.add_argument('--num_classes', type=int, default=10)
+ parser.add_argument('--n_train', type=int, default=10000)
+ parser.add_argument('--n_test', type=int, default=2000)
+ parser.add_argument('--batch_size', type=int, default=256)
+ parser.add_argument('--epochs', type=int, default=60)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--lam', type=float, default=0.1)
+ parser.add_argument('--K', type=int, default=4)
+ parser.add_argument('--sigma_bridge', type=float, default=0.05)
+ parser.add_argument('--ema_momentum', type=float, default=0.995)
+ parser.add_argument('--term_grad_weight', type=float, default=1.0)
+ parser.add_argument('--gpu', type=int, default=1)
+ parser.add_argument('--output_dir', type=str, default='results/synth_ladder')
+ args = parser.parse_args()
+
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Device: {device}")
+ print(f"Alphas: {args.alphas}")
+ print(f"Depths: {args.depths}")
+ print(f"Seeds: {args.seeds}")
+
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ all_results = {}
+
+ for alpha in args.alphas:
+ for L in args.depths:
+ for seed in args.seeds:
+ key = f"a{alpha}_L{L}_s{seed}"
+ result = run_single(alpha, L, seed, args, device)
+ all_results[key] = result
+
+ # Save incrementally
+ out_path = os.path.join(args.output_dir, f'synth_{key}.json')
+ with open(out_path, 'w') as f:
+ json.dump(serialize(result), f, indent=2)
+
+ # Save combined summary
+ summary = {}
+ for key, result in all_results.items():
+ s = {}
+ for method in ['bp', 'dfa', 'state_bridge', 'credit_bridge']:
+ r = result[method]
+ diag = r['diagnostics']
+ s[method] = {
+ 'test_acc': r['log']['test_acc'][-1],
+ 'mean_bp_cosine': float(np.mean(diag['bp_cosine'])),
+ 'mean_rho': float(np.mean(diag['perturbation_rho'])),
+ 'mean_nudge_001': float(np.mean(diag['nudging']['0.001'])),
+ 'mean_nudge_003': float(np.mean(diag['nudging']['0.003'])),
+ 'mean_nudge_01': float(np.mean(diag['nudging']['0.01'])),
+ 'bp_cosine_per_layer': [float(x) for x in diag['bp_cosine']],
+ 'rho_per_layer': [float(x) for x in diag['perturbation_rho']],
+ 'nudge_per_layer': [float(x) for x in diag['nudging']['0.01']],
+ }
+ if method == 'state_bridge' and 'state_pred_error_per_layer' in diag:
+ s[method]['state_pred_error_per_layer'] = [float(x) for x in diag['state_pred_error_per_layer']]
+ s[method]['mean_state_pred_error'] = float(np.mean(diag['state_pred_error_per_layer']))
+ if method == 'credit_bridge':
+ s[method]['final_value_loss'] = r['log']['value_loss'][-1]
+ s[method]['final_term_loss'] = r['log']['term_loss'][-1]
+ s[method]['final_bridge_loss'] = r['log']['bridge_loss'][-1]
+ s[method]['final_tgrad_loss'] = r['log']['tgrad_loss'][-1]
+ summary[key] = s
+
+ summary_path = os.path.join(args.output_dir, 'summary.json')
+ with open(summary_path, 'w') as f:
+ json.dump(summary, f, indent=2)
+ print(f"\nSummary saved to {summary_path}")
+
+ # Save config
+ config = serialize(vars(args))
+ config_path = os.path.join(args.output_dir, 'config.json')
+ with open(config_path, 'w') as f:
+ json.dump(config, f, indent=2)
+
+ # Print final summary table
+ print("\n" + "=" * 100)
+ print("SYNTHETIC NONLINEARITY LADDER - SUMMARY")
+ print("=" * 100)
+ print(f"{'Config':<20} {'Method':<20} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}")
+ print("-" * 100)
+ for key in sorted(summary.keys()):
+ for method in ['bp', 'dfa', 'state_bridge', 'credit_bridge']:
+ s = summary[key][method]
+ print(f"{key:<20} {method:<20} {s['test_acc']:>8.4f} {s['mean_bp_cosine']:>8.4f} "
+ f"{s['mean_rho']:>8.4f} {s['mean_nudge_01']:>10.6f}")
+ print("-" * 100)
+
+
+if __name__ == '__main__':
+ main()