diff options
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/cifar_depth_scan.py | 584 | ||||
| -rw-r--r-- | experiments/plot_synth_ladder.py | 449 | ||||
| -rw-r--r-- | experiments/synth_nonlinearity_ladder.py | 822 |
3 files changed, 1855 insertions, 0 deletions
diff --git a/experiments/cifar_depth_scan.py b/experiments/cifar_depth_scan.py new file mode 100644 index 0000000..0a16201 --- /dev/null +++ b/experiments/cifar_depth_scan.py @@ -0,0 +1,584 @@ +""" +Phase 2: CIFAR-10 Depth Scan. +Find the "Goldilocks regime" where Credit Bridge outperforms DFA. + +Sweep: L in {2, 4, 6, 8, 12}, d in {256, 512} +Methods: DFA (3 seeds), Credit Bridge (3 seeds), BP (1 seed as reference) + +Reuses training logic from cifar_resmlp.py. +""" +import os +import sys +import json +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, + offline_bp_cosine, feature_drift +) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader, 32 * 32 * 3, 10 + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +# ============================================================================= +# Training methods (adapted from cifar_resmlp.py) +# ============================================================================= +def train_bp(model, train_loader, test_loader, device, args): + optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + scheduler.step() + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + if epoch % 10 == 0 or epoch == 1: + print(f" [BP] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}") + return log + + +def train_dfa(model, train_loader, test_loader, device, args): + d = model.d_hidden + L = model.num_blocks + C = 10 + + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + + hL_det = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_dfa / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + a_0_dfa = (e_T @ Bs[0].T).detach() + rms_0 = (a_0_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0_dfa / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for s in all_schedulers: + s.step() + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + if epoch % 10 == 0 or epoch == 1: + print(f" [DFA] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}") + return log, Bs + + +def train_credit_bridge(model, train_loader, test_loader, device, args): + d = model.d_hidden + L = model.num_blocks + C = 10 + warmup_epochs = max(1, args.epochs // 5) + + value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + value_net_ema = create_ema_model(value_net) + + Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + lam = args.lam + K_samples = args.K + sigma_bridge = args.sigma_bridge + ema_momentum = args.ema_momentum + term_grad_weight = args.term_grad_weight + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []} + + for epoch in range(1, args.epochs + 1): + model.train() + value_net.train() + total_loss, correct, total = 0, 0, 0 + total_vloss = 0 + + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + + # Train value net + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + loss_tgrad = torch.tensor(0.0, device=device) + if term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req2)) + ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_l_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K_samples): + noise = sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_l_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples)) + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad + value_opt.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, ema_momentum) + total_vloss += value_loss.item() * batch + + # Compute credits + cb_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V_l = value_net(h_l_det, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + a = cb_credits[l] + elif credit_blend <= 0.0: + a = dfa_credits[l] + else: + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) + credits.append(a) + + # Update output head + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Update blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + # Update embedding + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0 / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + log['value_loss'].append(total_vloss / total) + if epoch % 10 == 0 or epoch == 1: + phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" [CB] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} " + f"vloss={log['value_loss'][-1]:.6f}") + return log, value_net, value_net_ema + + +# ============================================================================= +# Diagnostics +# ============================================================================= +def compute_diagnostics(model, method_name, test_loader, device, args, + value_net=None, dfa_Bs=None): + model.eval() + if value_net is not None: + value_net.eval() + + d = model.d_hidden + L = model.num_blocks + + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + break + + batch = x.size(0) + + # BP gradients via manual forward + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + results = { + 'bp_cosine': [], + 'perturbation_rho': [], + 'nudging': {'0.001': [], '0.003': [], '0.01': []}, + } + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + if method_name == 'bp': + a_l = bp_grads[l] + elif method_name == 'dfa': + a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'credit_bridge': + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + else: + raise ValueError(f"Unknown method: {method_name}") + + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(bp_cos) + + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) + results['perturbation_rho'].append(rho) + + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) + results['nudging'][str(eta)].append(nud) + + return results + + +# ============================================================================= +# Main +# ============================================================================= +def serialize(obj): + if isinstance(obj, dict): + return {str(k): serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [serialize(v) for v in obj] + elif isinstance(obj, (np.floating, np.integer)): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + return obj + + +def run_config(d_hidden, num_blocks, seed, methods, args, device): + """Run specified methods for a single (d, L, seed) config.""" + input_dim = 32 * 32 * 3 + num_classes = 10 + args.num_classes = num_classes + + train_loader, test_loader, _, _ = get_cifar10(args.batch_size) + + results = {} + + for method in methods: + torch.manual_seed(seed) + np.random.seed(seed) + torch.cuda.manual_seed_all(seed) + + model = ResidualMLP(input_dim, d_hidden, num_classes, num_blocks).to(device) + init_params = {n: p.clone().detach() for n, p in model.named_parameters()} + + print(f"\n d={d_hidden}, L={num_blocks}, seed={seed}, method={method}") + t0 = time.time() + + if method == 'bp': + log = train_bp(model, train_loader, test_loader, device, args) + diag = compute_diagnostics(model, 'bp', test_loader, device, args) + results['bp'] = {'log': log, 'diagnostics': diag} + + elif method == 'dfa': + log, Bs = train_dfa(model, train_loader, test_loader, device, args) + diag = compute_diagnostics(model, 'dfa', test_loader, device, args, dfa_Bs=Bs) + results['dfa'] = {'log': log, 'diagnostics': diag} + + elif method == 'credit_bridge': + log, vnet, _ = train_credit_bridge(model, train_loader, test_loader, device, args) + diag = compute_diagnostics(model, 'credit_bridge', test_loader, device, args, + value_net=vnet) + results['credit_bridge'] = {'log': log, 'diagnostics': diag} + + drift = feature_drift(init_params, {n: p.detach() for n, p in model.named_parameters()}) + results[method]['drift'] = drift + + elapsed = time.time() - t0 + test_acc = log['test_acc'][-1] + mean_gamma = np.mean(diag['bp_cosine']) + mean_rho = np.mean(diag['perturbation_rho']) + print(f" Done in {elapsed:.0f}s: acc={test_acc:.4f} Gamma={mean_gamma:.4f} rho={mean_rho:.4f}") + + return results + + +def main(): + parser = argparse.ArgumentParser(description='CIFAR-10 Depth Scan') + parser.add_argument('--depths', type=int, nargs='+', default=[2, 4, 6, 8, 12]) + parser.add_argument('--widths', type=int, nargs='+', default=[256, 512]) + parser.add_argument('--seeds', type=int, nargs='+', default=[42, 123, 456]) + parser.add_argument('--bp_seeds', type=int, nargs='+', default=[42], + help='Seeds for BP (reference only)') + parser.add_argument('--methods', type=str, nargs='+', + default=['bp', 'dfa', 'credit_bridge']) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/cifar_depth_scan') + args = parser.parse_args() + + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + print(f"Depths: {args.depths}, Widths: {args.widths}") + print(f"Seeds: {args.seeds}, BP seeds: {args.bp_seeds}") + os.makedirs(args.output_dir, exist_ok=True) + + all_summary = {} + + for d_hidden in args.widths: + for num_blocks in args.depths: + for seed in args.seeds: + key = f"d{d_hidden}_L{num_blocks}_s{seed}" + + # Determine which methods to run for this seed + methods = [] + for m in args.methods: + if m == 'bp' and seed not in args.bp_seeds: + continue + methods.append(m) + + if not methods: + continue + + print(f"\n{'='*60}") + print(f"Config: {key}, methods: {methods}") + print(f"{'='*60}") + + result = run_config(d_hidden, num_blocks, seed, methods, args, device) + + # Save per-config result + out_path = os.path.join(args.output_dir, f'{key}.json') + with open(out_path, 'w') as f: + json.dump(serialize(result), f, indent=2) + + # Summary + summary = {} + for method in result: + diag = result[method]['diagnostics'] + summary[method] = { + 'test_acc': result[method]['log']['test_acc'][-1], + 'mean_bp_cosine': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge_01': float(np.mean(diag['nudging']['0.01'])), + 'bp_cosine_per_layer': [float(x) for x in diag['bp_cosine']], + 'rho_per_layer': [float(x) for x in diag['perturbation_rho']], + } + all_summary[key] = summary + + # Save full summary + summary_path = os.path.join(args.output_dir, 'summary.json') + with open(summary_path, 'w') as f: + json.dump(all_summary, f, indent=2) + print(f"\nSummary saved to {summary_path}") + + # Print table + print("\n" + "=" * 90) + print("CIFAR-10 DEPTH SCAN SUMMARY") + print("=" * 90) + print(f"{'Config':<25} {'Method':<18} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}") + print("-" * 90) + for key in sorted(all_summary.keys()): + for method in sorted(all_summary[key].keys()): + s = all_summary[key][method] + print(f"{key:<25} {method:<18} {s['test_acc']:>8.4f} {s['mean_bp_cosine']:>8.4f} " + f"{s['mean_rho']:>8.4f} {s['mean_nudge_01']:>10.6f}") + print() + + +if __name__ == '__main__': + main() diff --git a/experiments/plot_synth_ladder.py b/experiments/plot_synth_ladder.py new file mode 100644 index 0000000..a3fb0b4 --- /dev/null +++ b/experiments/plot_synth_ladder.py @@ -0,0 +1,449 @@ +"""Generate phase diagram plots from synthetic nonlinearity ladder results.""" +import os +import sys +import json +import numpy as np +import matplotlib +matplotlib.use('Agg') +import matplotlib.pyplot as plt +from collections import defaultdict + +output_dir = 'report_explore' +os.makedirs(output_dir, exist_ok=True) + + +def load_summaries(result_dirs): + """Load and merge summary.json files from multiple result directories.""" + merged = {} + for d in result_dirs: + path = os.path.join(d, 'summary.json') + if os.path.exists(path): + with open(path) as f: + data = json.load(f) + merged.update(data) + return merged + + +def parse_key(key): + """Parse 'a0.5_L8_s42' -> (alpha, L, seed).""" + parts = key.split('_') + alpha = float(parts[0][1:]) + L = int(parts[1][1:]) + seed = int(parts[2][1:]) + return alpha, L, seed + + +def aggregate_seeds(summary): + """Aggregate over seeds: group by (alpha, L).""" + groups = defaultdict(list) + for key, data in summary.items(): + alpha, L, seed = parse_key(key) + groups[(alpha, L)].append(data) + + agg = {} + for (alpha, L), entries in groups.items(): + agg[(alpha, L)] = {} + for method in ['bp', 'dfa', 'state_bridge', 'credit_bridge']: + vals = {} + for metric in ['test_acc', 'mean_bp_cosine', 'mean_rho', 'mean_nudge_01', + 'mean_nudge_001', 'mean_nudge_003']: + arr = [e[method].get(metric, 0) for e in entries if method in e] + # Filter out blown-up values + arr = [v for v in arr if abs(v) < 1e6] + if arr: + vals[f'{metric}_mean'] = np.mean(arr) + vals[f'{metric}_std'] = np.std(arr) + else: + vals[f'{metric}_mean'] = np.nan + vals[f'{metric}_std'] = np.nan + + # State prediction error + if method == 'state_bridge': + arr = [e[method].get('mean_state_pred_error', 0) for e in entries if method in e] + arr = [v for v in arr if abs(v) < 1e6] + if arr: + vals['state_pred_error_mean'] = np.mean(arr) + vals['state_pred_error_std'] = np.std(arr) + else: + vals['state_pred_error_mean'] = np.nan + vals['state_pred_error_std'] = np.nan + + # Per-layer diagnostics (average over seeds) + for pl_metric in ['bp_cosine_per_layer', 'rho_per_layer', 'nudge_per_layer']: + arrays = [] + for e in entries: + if method in e and pl_metric in e[method]: + arr = e[method][pl_metric] + # Check for blowup + if all(abs(v) < 1e6 for v in arr): + arrays.append(arr) + if arrays: + arr2d = np.array(arrays) + vals[f'{pl_metric}_mean'] = arr2d.mean(axis=0).tolist() + vals[f'{pl_metric}_std'] = arr2d.std(axis=0).tolist() + + agg[(alpha, L)][method] = vals + return agg + + +def plot_phase_diagrams(agg, alphas, depths, save_prefix='synth'): + """Generate the 4 key phase diagram plots.""" + methods = ['dfa', 'state_bridge', 'credit_bridge'] + colors = {'bp': '#F44336', 'dfa': '#2196F3', 'state_bridge': '#FF9800', 'credit_bridge': '#4CAF50'} + labels = {'bp': 'BP', 'dfa': 'DFA', 'state_bridge': 'State Bridge', 'credit_bridge': 'Credit Bridge'} + + alphas = sorted(alphas) + depths = sorted(depths) + + # ======================================================================== + # Plot 1: Phase diagram heatmaps (Gamma, rho, nudge) for each method + # ======================================================================== + fig, axes = plt.subplots(3, 4, figsize=(20, 12)) + metrics_info = [ + ('mean_bp_cosine_mean', 'BP Cosine (Γ)', 'RdYlGn', -0.1, 1.0), + ('mean_rho_mean', 'Perturbation ρ', 'RdYlGn', -0.2, 1.0), + ('mean_nudge_01_mean', 'Nudge (η=0.01)', 'RdYlGn_r', None, None), + ] + all_methods = ['bp', 'dfa', 'state_bridge', 'credit_bridge'] + + for row, (metric_key, metric_name, cmap, vmin, vmax) in enumerate(metrics_info): + for col, method in enumerate(all_methods): + ax = axes[row, col] + grid = np.full((len(alphas), len(depths)), np.nan) + for i, alpha in enumerate(alphas): + for j, L in enumerate(depths): + if (alpha, L) in agg and method in agg[(alpha, L)]: + val = agg[(alpha, L)][method].get(metric_key, np.nan) + grid[i, j] = val + + if vmin is None: + valid = grid[~np.isnan(grid)] + if len(valid) > 0: + vmin_use = np.nanmin(grid) + vmax_use = min(0, np.nanmax(grid)) # nudge should be <=0 + else: + vmin_use, vmax_use = -1, 0 + else: + vmin_use, vmax_use = vmin, vmax + + im = ax.imshow(grid, cmap=cmap, vmin=vmin_use, vmax=vmax_use, aspect='auto', + origin='lower') + ax.set_xticks(range(len(depths))) + ax.set_xticklabels([str(d) for d in depths]) + ax.set_yticks(range(len(alphas))) + ax.set_yticklabels([str(a) for a in alphas]) + + # Annotate cells + for i in range(len(alphas)): + for j in range(len(depths)): + val = grid[i, j] + if not np.isnan(val): + txt = f'{val:.3f}' if abs(val) < 100 else f'{val:.1e}' + ax.text(j, i, txt, ha='center', va='center', fontsize=8, + color='black' if abs(val) < 0.5 else 'white') + else: + ax.text(j, i, 'X', ha='center', va='center', fontsize=10, + color='red', fontweight='bold') + + if row == 0: + ax.set_title(labels[method], fontsize=12) + if col == 0: + ax.set_ylabel(f'{metric_name}\nα (nonlinearity)', fontsize=10) + if row == len(metrics_info) - 1: + ax.set_xlabel('Depth (L)', fontsize=10) + plt.colorbar(im, ax=ax, fraction=0.046, pad=0.04) + + fig.suptitle('Synthetic Ladder: Phase Diagrams (X = blown up)', fontsize=14, y=1.02) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f'{save_prefix}_phase_heatmaps.png'), dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {save_prefix}_phase_heatmaps.png") + + # ======================================================================== + # Plot 2: Line plots - metric vs alpha for each depth + # ======================================================================== + fig, axes = plt.subplots(2, max(len(depths), 1), figsize=(6 * len(depths), 10), squeeze=False) + + for j, L in enumerate(depths): + # Row 0: Gamma + ax = axes[0, j] + for method in all_methods: + vals = [] + errs = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg and method in agg[(alpha, L)]: + v = agg[(alpha, L)][method].get('mean_bp_cosine_mean', np.nan) + e = agg[(alpha, L)][method].get('mean_bp_cosine_std', 0) + if not np.isnan(v): + vals.append(v) + errs.append(e) + valid_alphas.append(alpha) + if vals: + ax.errorbar(valid_alphas, vals, yerr=errs, marker='o', color=colors[method], + label=labels[method], capsize=3, linewidth=2, markersize=6) + ax.set_xlabel('α (nonlinearity)', fontsize=11) + ax.set_ylabel('Mean BP Cosine (Γ)', fontsize=11) + ax.set_title(f'L={L}', fontsize=13) + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.1, 1.05) + + # Row 1: rho + ax = axes[1, j] + for method in all_methods: + vals = [] + errs = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg and method in agg[(alpha, L)]: + v = agg[(alpha, L)][method].get('mean_rho_mean', np.nan) + e = agg[(alpha, L)][method].get('mean_rho_std', 0) + if not np.isnan(v): + vals.append(v) + errs.append(e) + valid_alphas.append(alpha) + if vals: + ax.errorbar(valid_alphas, vals, yerr=errs, marker='s', color=colors[method], + label=labels[method], capsize=3, linewidth=2, markersize=6) + ax.set_xlabel('α (nonlinearity)', fontsize=11) + ax.set_ylabel('Mean Perturbation ρ', fontsize=11) + ax.set_title(f'L={L}', fontsize=13) + ax.legend(fontsize=9) + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + + fig.suptitle('Synthetic Ladder: Credit Quality vs Nonlinearity', fontsize=14, y=1.02) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f'{save_prefix}_gamma_rho_vs_alpha.png'), dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {save_prefix}_gamma_rho_vs_alpha.png") + + # ======================================================================== + # Plot 3: State bridge prediction error vs credit quality + # ======================================================================== + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + # Left: State prediction error vs alpha + ax = axes[0] + for j, L in enumerate(depths): + vals = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg and 'state_bridge' in agg[(alpha, L)]: + v = agg[(alpha, L)]['state_bridge'].get('state_pred_error_mean', np.nan) + if not np.isnan(v): + vals.append(v) + valid_alphas.append(alpha) + if vals: + ax.plot(valid_alphas, vals, 'o-', label=f'L={L}', markersize=6, linewidth=2) + ax.set_xlabel('α (nonlinearity)', fontsize=11) + ax.set_ylabel('State Prediction Error', fontsize=11) + ax.set_title('State Bridge: Terminal Prediction Error', fontsize=12) + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_yscale('log') + + # Right: State bridge Gamma vs credit bridge Gamma + ax = axes[1] + for L in depths: + sb_gammas = [] + cb_gammas = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg: + sb_g = agg[(alpha, L)].get('state_bridge', {}).get('mean_bp_cosine_mean', np.nan) + cb_g = agg[(alpha, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan) + if not np.isnan(sb_g) and not np.isnan(cb_g): + sb_gammas.append(sb_g) + cb_gammas.append(cb_g) + valid_alphas.append(alpha) + if sb_gammas: + ax.scatter(sb_gammas, cb_gammas, s=80, label=f'L={L}', zorder=5) + for a, sx, sy in zip(valid_alphas, sb_gammas, cb_gammas): + ax.annotate(f'α={a}', (sx, sy), textcoords="offset points", + xytext=(5, 5), fontsize=8) + ax.plot([0, 1], [0, 1], 'k--', alpha=0.3, label='y=x') + ax.set_xlabel('State Bridge Γ', fontsize=11) + ax.set_ylabel('Credit Bridge Γ', fontsize=11) + ax.set_title('State Bridge vs Credit Bridge BP Cosine', fontsize=12) + ax.legend() + ax.grid(True, alpha=0.3) + ax.set_xlim(-0.1, 1.0) + ax.set_ylim(-0.1, 1.0) + + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f'{save_prefix}_state_vs_credit.png'), dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {save_prefix}_state_vs_credit.png") + + # ======================================================================== + # Plot 4: Nudging and accuracy comparison + # ======================================================================== + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + + # Left: Nudge vs alpha for each depth + ax = axes[0] + for method in ['dfa', 'state_bridge', 'credit_bridge']: + for L in depths: + vals = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg and method in agg[(alpha, L)]: + v = agg[(alpha, L)][method].get('mean_nudge_01_mean', np.nan) + if not np.isnan(v): + vals.append(v) + valid_alphas.append(alpha) + if vals: + ls = '-' if L == min(depths) else '--' + ax.plot(valid_alphas, vals, f'o{ls}', color=colors[method], + label=f'{labels[method]} L={L}', markersize=5) + ax.set_xlabel('α (nonlinearity)', fontsize=11) + ax.set_ylabel('Mean Nudge (η=0.01, negative=good)', fontsize=11) + ax.set_title('Nudging Test', fontsize=12) + ax.legend(fontsize=8, ncol=2) + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + + # Right: Test accuracy + ax = axes[1] + for method in all_methods: + for L in depths: + vals = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg and method in agg[(alpha, L)]: + v = agg[(alpha, L)][method].get('test_acc_mean', np.nan) + if not np.isnan(v): + vals.append(v) + valid_alphas.append(alpha) + if vals: + ls = '-' if L == min(depths) else '--' + ax.plot(valid_alphas, vals, f'o{ls}', color=colors[method], + label=f'{labels[method]} L={L}', markersize=5) + ax.set_xlabel('α (nonlinearity)', fontsize=11) + ax.set_ylabel('Test Accuracy', fontsize=11) + ax.set_title('Test Accuracy', fontsize=12) + ax.legend(fontsize=8, ncol=2) + ax.grid(True, alpha=0.3) + + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f'{save_prefix}_nudge_acc.png'), dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {save_prefix}_nudge_acc.png") + + # ======================================================================== + # Plot 5: Per-layer diagnostics for selected (alpha, L) combos + # ======================================================================== + combos = [(a, L) for a in alphas for L in depths if (a, L) in agg] + # Only plot non-blown-up combos + valid_combos = [] + for a, L in combos: + if not np.isnan(agg[(a, L)].get('credit_bridge', {}).get('mean_bp_cosine_mean', np.nan)): + valid_combos.append((a, L)) + + if valid_combos: + n_combos = len(valid_combos) + fig, axes = plt.subplots(2, n_combos, figsize=(5 * n_combos, 8), squeeze=False) + + for idx, (alpha, L) in enumerate(valid_combos): + # Row 0: BP cosine per layer + ax = axes[0, idx] + for method in all_methods: + if method in agg[(alpha, L)]: + pl = agg[(alpha, L)][method].get('bp_cosine_per_layer_mean', None) + if pl is not None: + ax.plot(range(len(pl)), pl, 'o-', color=colors[method], + label=labels[method], markersize=4) + ax.set_xlabel('Layer') + ax.set_ylabel('BP Cosine') + ax.set_title(f'α={alpha}, L={L}', fontsize=11) + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + ax.set_ylim(-0.3, 1.05) + + # Row 1: rho per layer + ax = axes[1, idx] + for method in all_methods: + if method in agg[(alpha, L)]: + pl = agg[(alpha, L)][method].get('rho_per_layer_mean', None) + if pl is not None: + ax.plot(range(len(pl)), pl, 's-', color=colors[method], + label=labels[method], markersize=4) + ax.set_xlabel('Layer') + ax.set_ylabel('Perturbation ρ') + ax.set_title(f'α={alpha}, L={L}', fontsize=11) + ax.legend(fontsize=8) + ax.grid(True, alpha=0.3) + ax.axhline(y=0, color='gray', linestyle='--', alpha=0.5) + + fig.suptitle('Per-Layer Diagnostics', fontsize=14, y=1.02) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f'{save_prefix}_per_layer.png'), dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {save_prefix}_per_layer.png") + + # ======================================================================== + # Plot 6: Advantage plots: CB - DFA + # ======================================================================== + fig, axes = plt.subplots(1, 3, figsize=(18, 5)) + metrics = [ + ('mean_bp_cosine_mean', 'Γ(CB) - Γ(DFA)'), + ('mean_rho_mean', 'ρ(CB) - ρ(DFA)'), + ('mean_nudge_01_mean', 'nudge(CB) - nudge(DFA)'), + ] + + for ax, (metric_key, ylabel) in zip(axes, metrics): + for L in depths: + diffs = [] + valid_alphas = [] + for alpha in alphas: + if (alpha, L) in agg: + cb_v = agg[(alpha, L)].get('credit_bridge', {}).get(metric_key, np.nan) + dfa_v = agg[(alpha, L)].get('dfa', {}).get(metric_key, np.nan) + if not np.isnan(cb_v) and not np.isnan(dfa_v): + diffs.append(cb_v - dfa_v) + valid_alphas.append(alpha) + if diffs: + ax.plot(valid_alphas, diffs, 'o-', label=f'L={L}', markersize=6, linewidth=2) + ax.set_xlabel('α (nonlinearity)', fontsize=11) + ax.set_ylabel(ylabel, fontsize=11) + ax.axhline(y=0, color='red', linestyle='--', alpha=0.7, label='zero (parity)') + ax.legend(fontsize=10) + ax.grid(True, alpha=0.3) + + fig.suptitle('Credit Bridge Advantage over DFA', fontsize=14, y=1.02) + fig.tight_layout() + fig.savefig(os.path.join(output_dir, f'{save_prefix}_cb_advantage.png'), dpi=150, bbox_inches='tight') + plt.close(fig) + print(f"Saved {save_prefix}_cb_advantage.png") + + +if __name__ == '__main__': + import sys + result_dirs = sys.argv[1:] if len(sys.argv) > 1 else ['results/synth_ladder_smoke'] + + print(f"Loading from: {result_dirs}") + summary = load_summaries(result_dirs) + + if not summary: + print("No results found!") + sys.exit(1) + + # Extract alphas and depths from keys + alphas = set() + depths = set() + for key in summary: + alpha, L, seed = parse_key(key) + alphas.add(alpha) + depths.add(L) + + alphas = sorted(alphas) + depths = sorted(depths) + print(f"Alphas: {alphas}") + print(f"Depths: {depths}") + print(f"Total configs: {len(summary)}") + + agg = aggregate_seeds(summary) + plot_phase_diagrams(agg, alphas, depths) + print("\nAll plots generated!") diff --git a/experiments/synth_nonlinearity_ladder.py b/experiments/synth_nonlinearity_ladder.py new file mode 100644 index 0000000..d5ed9aa --- /dev/null +++ b/experiments/synth_nonlinearity_ladder.py @@ -0,0 +1,822 @@ +""" +Synthetic Nonlinearity Ladder: teacher-student classification with controllable nonlinearity. + +Teacher dynamics: + h_{l+1}^* = h_l^* + W_l^* phi_alpha(h_l^*) + phi_alpha(z) = (1-alpha)*z + alpha*tanh(z) + +alpha=0 -> linear, alpha=1 -> fully nonlinear. +Sweep (alpha, L) to find where state bridge fails and credit bridge degrades. + +Methods: BP, DFA, State Bridge, Credit Bridge (all reuse project conventions). +""" +import os +import sys +import json +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP, ResidualBlock +from models.value_net import ValueNet, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, + offline_bp_cosine, feature_drift +) + + +# ============================================================================= +# Teacher network and data generation +# ============================================================================= +class TeacherNet: + """Fixed teacher network with controllable nonlinearity.""" + + def __init__(self, d_hidden, num_blocks, num_classes, alpha, seed=0): + rng = np.random.RandomState(seed) + self.d_hidden = d_hidden + self.num_blocks = num_blocks + self.num_classes = num_classes + self.alpha = alpha + + # Teacher block weights: scaled for stability + self.Ws = [] + for l in range(num_blocks): + W = rng.randn(d_hidden, d_hidden).astype(np.float32) + # Scale so spectral norm of residual part is small + W = W / (np.linalg.norm(W, ord=2) + 1e-8) * 0.3 + self.Ws.append(torch.from_numpy(W)) + + # Output projection + U = rng.randn(num_classes, d_hidden).astype(np.float32) + U = U / (np.linalg.norm(U, ord=2) + 1e-8) + self.U = torch.from_numpy(U) + + def to(self, device): + self.Ws = [W.to(device) for W in self.Ws] + self.U = self.U.to(device) + return self + + def phi(self, z): + """Controllable activation: (1-alpha)*z + alpha*tanh(z).""" + return (1 - self.alpha) * z + self.alpha * torch.tanh(z) + + def forward(self, h0): + """Forward pass through teacher, returns logits and hidden states.""" + h = h0 + hiddens = [h] + for l in range(self.num_blocks): + f = F.linear(self.phi(h), self.Ws[l]) + h = h + f + hiddens.append(h) + logits = F.linear(h, self.U) + return logits, hiddens + + +def generate_dataset(teacher, num_samples, d_hidden, device, seed=0): + """Generate classification dataset from teacher network.""" + torch.manual_seed(seed) + X = torch.randn(num_samples, d_hidden, device=device) + with torch.no_grad(): + logits, _ = teacher.forward(X) + Y = logits.argmax(dim=-1) + return X, Y + + +# ============================================================================= +# Student network (same architecture family as teacher) +# ============================================================================= +class StudentBlock(nn.Module): + """Residual block with pre-LayerNorm + phi_alpha.""" + + def __init__(self, d_hidden, alpha): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.w = nn.Linear(d_hidden, d_hidden, bias=False) + self.alpha = alpha + # Small init for stability + nn.init.normal_(self.w.weight, std=0.01) + + def phi(self, z): + return (1 - self.alpha) * z + self.alpha * torch.tanh(z) + + def forward(self, h): + """Returns residual F_l(h), NOT h + F_l(h).""" + return self.w(self.phi(self.ln(h))) + + +class StudentNet(nn.Module): + """Student network: L residual blocks + linear output head.""" + + def __init__(self, d_hidden, num_classes, num_blocks, alpha): + super().__init__() + self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)]) + self.out_head = nn.Linear(d_hidden, num_classes) + self.num_blocks = num_blocks + self.d_hidden = d_hidden + + def forward(self, x, return_hidden=False): + h = x # No embedding needed, input is already d_hidden + hiddens = [h] if return_hidden else None + for block in self.blocks: + f = block(h) + h = h + f + if return_hidden: + hiddens.append(h) + logits = self.out_head(h) + if return_hidden: + return logits, hiddens + return logits + + def forward_from_layer(self, h, start_layer): + """Run forward from a given layer to output.""" + for i in range(start_layer, self.num_blocks): + f = self.blocks[i](h) + h = h + f + return self.out_head(h) + + +# ============================================================================= +# Training methods +# ============================================================================= +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x, y = x.to(device), y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +def train_bp(model, train_loader, test_loader, device, args): + """Standard BP training.""" + optimizer = optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=args.epochs) + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + scheduler.step() + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + if epoch % 20 == 0 or epoch == 1: + print(f" [BP] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}") + return log + + +def train_dfa(model, train_loader, test_loader, device, args): + """DFA training with fixed random feedback matrices.""" + d = model.d_hidden + L = model.num_blocks + C = args.num_classes + + Bs = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': []} + + for epoch in range(1, args.epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x, y = x.to(device), y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + + # Update output head (h_L detached) + hL_det = hiddens[-1].detach() + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Update each block with DFA local surrogate + for l in range(L): + h_l = hiddens[l].detach() + a_dfa = (e_T @ Bs[l].T).detach() + rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a_dfa / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for s in all_schedulers: + s.step() + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + if epoch % 20 == 0 or epoch == 1: + print(f" [DFA] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f}") + return log, Bs + + +def train_state_bridge(model, train_loader, test_loader, device, args): + """State Bridge training.""" + d = model.d_hidden + L = model.num_blocks + C = args.num_classes + + state_pred = StateBridgeNet(d_hidden=d, s_dim=C, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd) + state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'state_pred_error': []} + + for epoch in range(1, args.epochs + 1): + model.train() + state_pred.train() + total_loss, correct, total = 0, 0, 0 + total_se = 0 + + for x, y in train_loader: + x, y = x.to(device), y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + hL_det = hiddens[-1].detach() + + # Train state predictor + state_loss = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + target = hL_det + target_norm = target.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss = state_loss + (((pred_hL - target) / target_norm) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + state_opt.zero_grad() + state_loss.backward() + state_opt.step() + total_se += state_loss.item() * batch + + # Compute credits + credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + pred_logits = model.out_head(pred_hL) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0] + credits.append(a_l.detach()) + + # Update output head + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Update blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + log['state_pred_error'].append(total_se / total) + if epoch % 20 == 0 or epoch == 1: + print(f" [SB] Ep {epoch}: loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} " + f"se={log['state_pred_error'][-1]:.6f}") + return log, state_pred + + +def train_credit_bridge(model, train_loader, test_loader, device, args): + """Credit Bridge training with terminal gradient matching + bridge consistency.""" + d = model.d_hidden + L = model.num_blocks + C = args.num_classes + warmup_epochs = max(1, args.epochs // 5) + + value_net = ValueNet(d_hidden=d, s_dim=C, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + value_net_ema = create_ema_model(value_net) + + Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + lam = args.lam + K_samples = args.K + sigma_bridge = args.sigma_bridge + ema_momentum = args.ema_momentum + term_grad_weight = args.term_grad_weight + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], + 'value_loss': [], 'term_loss': [], 'bridge_loss': [], 'tgrad_loss': []} + + for epoch in range(1, args.epochs + 1): + model.train() + value_net.train() + total_loss, correct, total = 0, 0, 0 + total_vloss, total_term, total_bridge, total_tgrad = 0, 0, 0, 0 + + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + for x, y in train_loader: + x, y = x.to(device), y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + + # ---- Train value net ---- + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + loss_tgrad = torch.tensor(0.0, device=device) + if term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(hL_req2) + ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_l_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K_samples): + noise = sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_l_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad + value_opt.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, ema_momentum) + + total_vloss += value_loss.item() * batch + total_term += loss_term.item() * batch + total_bridge += (loss_bridge.item() if isinstance(loss_bridge, torch.Tensor) else loss_bridge) * batch + total_tgrad += loss_tgrad.item() * batch + + # ---- Compute credits ---- + cb_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V_l = value_net(h_l_det, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + a = cb_credits[l] + elif credit_blend <= 0.0: + a = dfa_credits[l] + else: + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) + credits.append(a) + + # ---- Update output head ---- + logits_out = model.out_head(hL_det) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # ---- Update blocks ---- + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + log['value_loss'].append(total_vloss / total) + log['term_loss'].append(total_term / total) + log['bridge_loss'].append(total_bridge / total) + log['tgrad_loss'].append(total_tgrad / total) + if epoch % 20 == 0 or epoch == 1: + phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" [CB] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} " + f"vloss={log['value_loss'][-1]:.6f}") + return log, value_net, value_net_ema + + +# ============================================================================= +# Diagnostics (per-layer) +# ============================================================================= +def compute_diagnostics(model, method_name, test_loader, device, args, + value_net=None, state_predictor=None, dfa_Bs=None): + """Compute per-layer diagnostic metrics.""" + model.eval() + if value_net is not None: + value_net.eval() + if state_predictor is not None: + state_predictor.eval() + + d = model.d_hidden + L = model.num_blocks + C = args.num_classes + + # Get one batch + for x, y in test_loader: + x, y = x.to(device), y.to(device) + break + + batch = x.size(0) + + # BP gradients (for offline BP cosine) + # Manual forward to get gradable hidden states + h = x.detach().requires_grad_(True) + hiddens_bp = [h] + for block in model.blocks: + f = block(hiddens_bp[-1]) + h_next = hiddens_bp[-1] + f + hiddens_bp.append(h_next) + logits_bp = model.out_head(hiddens_bp[-1]) + loss_bp = F.cross_entropy(logits_bp, y) + bp_grads = {} + grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False) + for l in range(L + 1): + bp_grads[l] = grads[l].detach().clone() + + # Clean forward + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + + results = { + 'bp_cosine': [], + 'perturbation_rho': [], + 'nudging': {'0.001': [], '0.003': [], '0.01': []}, + } + + # State bridge prediction error (if applicable) + if method_name == 'state_bridge' and state_predictor is not None: + state_pred_errors = [] + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + with torch.no_grad(): + pred_hL = state_predictor(h_l_det, t_l, s) + hL_det = hiddens[-1].detach() + err = ((pred_hL - hL_det) ** 2).sum(dim=-1).mean().item() + state_pred_errors.append(err) + results['state_pred_error_per_layer'] = state_pred_errors + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + if method_name == 'bp': + a_l = bp_grads[l] + elif method_name == 'dfa': + a_l = (e_T @ dfa_Bs[l].T).detach() + elif method_name == 'state_bridge': + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = state_predictor(h_l_req, t_l, s) + pred_logits = model.out_head(pred_hL) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() + elif method_name == 'credit_bridge': + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + else: + raise ValueError(f"Unknown method: {method_name}") + + # BP cosine + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(bp_cos) + + # Forward function for perturbation and nudging + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(curr) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) + results['perturbation_rho'].append(rho) + + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) + results['nudging'][str(eta)].append(nud) + + return results + + +# ============================================================================= +# Main experiment runner +# ============================================================================= +def run_single(alpha, L, seed, args, device): + """Run all methods for a single (alpha, L, seed) configuration.""" + d = args.d_hidden + C = args.num_classes + + print(f"\n === alpha={alpha}, L={L}, seed={seed} ===") + t0 = time.time() + + # Generate data from teacher + teacher = TeacherNet(d, L, C, alpha, seed=0).to(device) # Fixed teacher seed=0 + X_train, Y_train = generate_dataset(teacher, args.n_train, d, device, seed=seed) + X_test, Y_test = generate_dataset(teacher, args.n_test, d, device, seed=seed + 10000) + + train_ds = TensorDataset(X_train, Y_train) + test_ds = TensorDataset(X_test, Y_test) + train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True) + test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False) + + results = {} + + # ---- BP ---- + print(" --- BP ---") + torch.manual_seed(seed) + model_bp = StudentNet(d, C, L, alpha).to(device) + bp_log = train_bp(model_bp, train_loader, test_loader, device, args) + bp_diag = compute_diagnostics(model_bp, 'bp', test_loader, device, args) + results['bp'] = {'log': bp_log, 'diagnostics': bp_diag} + + # ---- DFA ---- + print(" --- DFA ---") + torch.manual_seed(seed) + model_dfa = StudentNet(d, C, L, alpha).to(device) + dfa_log, dfa_Bs = train_dfa(model_dfa, train_loader, test_loader, device, args) + dfa_diag = compute_diagnostics(model_dfa, 'dfa', test_loader, device, args, dfa_Bs=dfa_Bs) + results['dfa'] = {'log': dfa_log, 'diagnostics': dfa_diag} + + # ---- State Bridge ---- + print(" --- State Bridge ---") + torch.manual_seed(seed) + model_sb = StudentNet(d, C, L, alpha).to(device) + sb_log, state_pred = train_state_bridge(model_sb, train_loader, test_loader, device, args) + sb_diag = compute_diagnostics(model_sb, 'state_bridge', test_loader, device, args, + state_predictor=state_pred) + results['state_bridge'] = {'log': sb_log, 'diagnostics': sb_diag} + + # ---- Credit Bridge ---- + print(" --- Credit Bridge ---") + torch.manual_seed(seed) + model_cb = StudentNet(d, C, L, alpha).to(device) + cb_log, vnet, vnet_ema = train_credit_bridge(model_cb, train_loader, test_loader, device, args) + cb_diag = compute_diagnostics(model_cb, 'credit_bridge', test_loader, device, args, + value_net=vnet) + results['credit_bridge'] = {'log': cb_log, 'diagnostics': cb_diag} + + elapsed = time.time() - t0 + print(f" === Done alpha={alpha}, L={L}, seed={seed} in {elapsed:.1f}s ===") + + # Summary + for m in ['bp', 'dfa', 'state_bridge', 'credit_bridge']: + test_acc = results[m]['log']['test_acc'][-1] + mean_gamma = np.mean(results[m]['diagnostics']['bp_cosine']) + mean_rho = np.mean(results[m]['diagnostics']['perturbation_rho']) + mean_nudge = np.mean(results[m]['diagnostics']['nudging']['0.01']) + print(f" {m:20s}: acc={test_acc:.4f} Gamma={mean_gamma:.4f} " + f"rho={mean_rho:.4f} nudge={mean_nudge:.6f}") + + return results + + +def serialize(obj): + if isinstance(obj, dict): + return {str(k): serialize(v) for k, v in obj.items()} + elif isinstance(obj, list): + return [serialize(v) for v in obj] + elif isinstance(obj, (np.floating, np.integer)): + return float(obj) + elif isinstance(obj, np.ndarray): + return obj.tolist() + elif isinstance(obj, torch.Tensor): + return obj.cpu().numpy().tolist() + return obj + + +def main(): + parser = argparse.ArgumentParser(description='Synthetic Nonlinearity Ladder') + parser.add_argument('--alphas', type=float, nargs='+', default=[0.0, 0.5, 1.0]) + parser.add_argument('--depths', type=int, nargs='+', default=[2, 8]) + parser.add_argument('--seeds', type=int, nargs='+', default=[42]) + parser.add_argument('--d_hidden', type=int, default=128) + parser.add_argument('--num_classes', type=int, default=10) + parser.add_argument('--n_train', type=int, default=10000) + parser.add_argument('--n_test', type=int, default=2000) + parser.add_argument('--batch_size', type=int, default=256) + parser.add_argument('--epochs', type=int, default=60) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/synth_ladder') + args = parser.parse_args() + + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + print(f"Alphas: {args.alphas}") + print(f"Depths: {args.depths}") + print(f"Seeds: {args.seeds}") + + os.makedirs(args.output_dir, exist_ok=True) + + all_results = {} + + for alpha in args.alphas: + for L in args.depths: + for seed in args.seeds: + key = f"a{alpha}_L{L}_s{seed}" + result = run_single(alpha, L, seed, args, device) + all_results[key] = result + + # Save incrementally + out_path = os.path.join(args.output_dir, f'synth_{key}.json') + with open(out_path, 'w') as f: + json.dump(serialize(result), f, indent=2) + + # Save combined summary + summary = {} + for key, result in all_results.items(): + s = {} + for method in ['bp', 'dfa', 'state_bridge', 'credit_bridge']: + r = result[method] + diag = r['diagnostics'] + s[method] = { + 'test_acc': r['log']['test_acc'][-1], + 'mean_bp_cosine': float(np.mean(diag['bp_cosine'])), + 'mean_rho': float(np.mean(diag['perturbation_rho'])), + 'mean_nudge_001': float(np.mean(diag['nudging']['0.001'])), + 'mean_nudge_003': float(np.mean(diag['nudging']['0.003'])), + 'mean_nudge_01': float(np.mean(diag['nudging']['0.01'])), + 'bp_cosine_per_layer': [float(x) for x in diag['bp_cosine']], + 'rho_per_layer': [float(x) for x in diag['perturbation_rho']], + 'nudge_per_layer': [float(x) for x in diag['nudging']['0.01']], + } + if method == 'state_bridge' and 'state_pred_error_per_layer' in diag: + s[method]['state_pred_error_per_layer'] = [float(x) for x in diag['state_pred_error_per_layer']] + s[method]['mean_state_pred_error'] = float(np.mean(diag['state_pred_error_per_layer'])) + if method == 'credit_bridge': + s[method]['final_value_loss'] = r['log']['value_loss'][-1] + s[method]['final_term_loss'] = r['log']['term_loss'][-1] + s[method]['final_bridge_loss'] = r['log']['bridge_loss'][-1] + s[method]['final_tgrad_loss'] = r['log']['tgrad_loss'][-1] + summary[key] = s + + summary_path = os.path.join(args.output_dir, 'summary.json') + with open(summary_path, 'w') as f: + json.dump(summary, f, indent=2) + print(f"\nSummary saved to {summary_path}") + + # Save config + config = serialize(vars(args)) + config_path = os.path.join(args.output_dir, 'config.json') + with open(config_path, 'w') as f: + json.dump(config, f, indent=2) + + # Print final summary table + print("\n" + "=" * 100) + print("SYNTHETIC NONLINEARITY LADDER - SUMMARY") + print("=" * 100) + print(f"{'Config':<20} {'Method':<20} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}") + print("-" * 100) + for key in sorted(summary.keys()): + for method in ['bp', 'dfa', 'state_bridge', 'credit_bridge']: + s = summary[key][method] + print(f"{key:<20} {method:<20} {s['test_acc']:>8.4f} {s['mean_bp_cosine']:>8.4f} " + f"{s['mean_rho']:>8.4f} {s['mean_nudge_01']:>10.6f}") + print("-" * 100) + + +if __name__ == '__main__': + main() |
