""" Phase B: Online shallow CIFAR with better conditioning. Goal: Find a minimal positive-signal regime on real CIFAR-10 with shallow depth. Sweep L={4,6}, d={256,512}, methods={DFA, CB_eT, CB_deltaL, SB_eT}, warmup_ratio={0.0, 0.05, 0.2}, term_grad_weight={1.0, 4.0}. Single-seed smoke test first. Only expand to 3 seeds for configs with S1>0 and S2>0. """ import os import sys import json import argparse import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader import torchvision import torchvision.transforms as transforms import copy import time sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP from models.value_net import ValueNet, create_ema_model, update_ema from models.state_bridge import StateBridgeNet from metrics.credit_metrics import ( cosine_similarity_batch, perturbation_correlation, nudging_test, offline_bp_cosine ) def get_cifar10(batch_size=128): transform_train = transforms.Compose([ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) transform_test = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) return train_loader, test_loader def evaluate(model, test_loader, device): model.eval() correct, total = 0, 0 with torch.no_grad(): for x, y in test_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) logits = model(x) correct += (logits.argmax(1) == y).sum().item() total += x.size(0) return correct / total # ============================================================================= # Training methods # ============================================================================= def train_dfa(model, train_loader, test_loader, device, epochs, lr, wd): """DFA training.""" d = model.d_hidden num_classes = 10 L = model.num_blocks Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW( list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd ) all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'train_loss': [], 'test_acc': []} for epoch in range(1, epochs + 1): model.train() total_loss, correct, total = 0, 0, 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 hL_det = hiddens[-1].detach() logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() for l in range(L): h_l = hiddens[l].detach() a_dfa = (e_T @ Bs[l].T).detach() rms = (a_dfa ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a_dfa / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() a_0 = (e_T @ Bs[0].T).detach() rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for s in all_schedulers: s.step() train_loss = total_loss / total test_acc = evaluate(model, test_loader, device) log['train_loss'].append(train_loss) log['test_acc'].append(test_acc) if epoch % 20 == 0 or epoch == 1: print(f" [DFA] Ep {epoch}: loss={train_loss:.4f}, test={test_acc:.4f}") return log, Bs def train_state_bridge_online(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd): """State bridge online training.""" d = model.d_hidden num_classes = 10 L = model.num_blocks state_pred = StateBridgeNet( d_hidden=d, s_dim=num_classes, time_embed_dim=32, hidden_dim=256, num_layers=3 ).to(device) block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW( list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd ) state_opt = optim.Adam(state_pred.parameters(), lr=lr_fb) all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'train_loss': [], 'test_acc': [], 'state_pred_error': []} for epoch in range(1, epochs + 1): model.train() state_pred.train() total_loss, correct, total = 0, 0, 0 total_se = 0 for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s = e_T.detach() hL_det = hiddens[-1].detach() # Train state predictor state_loss = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() state_loss = state_loss / L state_opt.zero_grad() state_loss.backward() state_opt.step() total_se += state_loss.item() * batch # Compute credits credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) pred_hL = state_pred(h_l_det, t_l, s) pred_logits = model.out_head(model.out_ln(pred_hL)) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_det, create_graph=False)[0] credits.append(a_l.detach()) # Update output head logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() # Update embedding a_0 = credits[0] rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for sch in all_schedulers: sch.step() train_loss = total_loss / total test_acc = evaluate(model, test_loader, device) se = total_se / total log['train_loss'].append(train_loss) log['test_acc'].append(test_acc) log['state_pred_error'].append(se) if epoch % 20 == 0 or epoch == 1: print(f" [SB] Ep {epoch}: loss={train_loss:.4f}, test={test_acc:.4f}, se={se:.4f}") return log, state_pred def train_credit_bridge_online(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, s_type='eT', warmup_ratio=0.2, term_grad_weight=1.0, lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995): """Credit bridge online training with configurable s_type, warmup, tgw.""" d = model.d_hidden num_classes = 10 L = model.num_blocks warmup_epochs = max(1, int(epochs * warmup_ratio)) s_dim = num_classes if s_type == 'eT' else d value_net = ValueNet( d_hidden=d, s_dim=s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3 ).to(device) value_net_ema = create_ema_model(value_net) # DFA fallback for warmup Bs_fallback = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=lr, weight_decay=wd) for block in model.blocks] embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) head_opt = optim.AdamW( list(model.out_head.parameters()) + list(model.out_ln.parameters()), lr=lr, weight_decay=wd ) value_opt = optim.Adam(value_net.parameters(), lr=lr_fb) all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]) log = {'train_loss': [], 'test_acc': [], 'value_loss': []} for epoch in range(1, epochs + 1): model.train() value_net.train() total_loss, correct, total = 0, 0, 0 total_vloss = 0 if epoch <= warmup_epochs: credit_blend = 0.0 else: credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) for x, y in train_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 true_loss = F.cross_entropy(logits, y, reduction='none').detach() hL_det = hiddens[-1].detach() # Compute s if s_type == 'eT': s = e_T.detach() elif s_type == 'deltaL': hL_req = hL_det.clone().requires_grad_(True) logits_for_s = model.out_head(model.out_ln(hL_req)) ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum') delta_L = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach() s = delta_L # Train value net t_L = torch.ones(batch, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() loss_tgrad = torch.tensor(0.0, device=device) if term_grad_weight > 0: hL_req = hL_det.clone().requires_grad_(True) V_at_L = value_net(hL_req, t_L, s) grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] hL_req2 = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(model.out_ln(hL_req2)) ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) t_l_next = torch.full((batch,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K): noise = sigma_bridge * torch.randn_like(h_next_det) V_next = value_net_ema(h_next_det + noise, t_l_next, s) log_terms.append(-V_next / lam) log_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K)) loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() loss_bridge = loss_bridge / L value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad value_opt.zero_grad() value_loss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) value_opt.step() update_ema(value_net, value_net_ema, ema_momentum) total_vloss += value_loss.item() * batch # Compute CB credits cb_credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) V_l = value_net(h_l_det, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] cb_credits.append(a_l.detach()) dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] # Blend credits = [] for l in range(L): if credit_blend >= 1.0: a = cb_credits[l] elif credit_blend <= 0.0: a = dfa_credits[l] else: cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) credits.append(a) # Update head logits_out = model.out_head(model.out_ln(hL_det)) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() # Update embedding a_0 = credits[0] rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 h0 = model.embed(x) embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean() embed_opt.zero_grad() embed_loss.backward() embed_opt.step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for sch in all_schedulers: sch.step() train_loss = total_loss / total test_acc = evaluate(model, test_loader, device) vloss = total_vloss / total log['train_loss'].append(train_loss) log['test_acc'].append(test_acc) log['value_loss'].append(vloss) if epoch % 20 == 0 or epoch == 1: phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" print(f" [CB_{s_type}] Ep {epoch} ({phase}): loss={train_loss:.4f}, test={test_acc:.4f}") return log, value_net, value_net_ema # ============================================================================= # Diagnostics # ============================================================================= def compute_diagnostics(model, method_name, test_loader, device, value_net=None, state_pred=None, dfa_Bs=None, s_type='eT'): """Compute Gamma, rho, nudging per layer.""" model.eval() if value_net is not None: value_net.eval() if state_pred is not None: state_pred.eval() d = model.d_hidden L = model.num_blocks num_classes = 10 # Get one batch for x, y in test_loader: x = x.view(x.size(0), -1).to(device) y = y.to(device) break batch = x.size(0) # BP gradients (evaluation only) logits_bp, hiddens_bp = model(x, return_hidden=True) for l in range(L + 1): hiddens_bp[l].retain_grad() loss_bp = F.cross_entropy(logits_bp, y) loss_bp.backward() bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} # Clean forward with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 s_eT = e_T.detach() hL_det = hiddens[-1].detach() # delta_L for deltaL conditioning hL_req = hL_det.clone().requires_grad_(True) logits_for_delta = model.out_head(model.out_ln(hL_req)) ce_for_delta = F.cross_entropy(logits_for_delta, y, reduction='sum') delta_L = torch.autograd.grad(ce_for_delta, hL_req, create_graph=False)[0].detach() results = { 'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.001': [], '0.003': [], '0.01': []}, } for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) if method_name == 'dfa': a_l = (s_eT @ dfa_Bs[l].T).detach() elif method_name == 'state_bridge': h_l_req = h_l.clone().requires_grad_(True) pred_hL = state_pred(h_l_req, t_l, s_eT) pred_logits = model.out_head(model.out_ln(pred_hL)) pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() elif method_name.startswith('cb_'): s = s_eT if s_type == 'eT' else delta_L h_l_req = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() else: raise ValueError(f"Unknown method: {method_name}") bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) results['bp_cosine'].append(float(bp_cos)) def make_fwd_fn(start_l): def fwd_fn(h): with torch.no_grad(): curr = h for i in range(start_l, L): curr = curr + model.blocks[i](curr) out = model.out_head(model.out_ln(curr)) return F.cross_entropy(out, y, reduction='none') return fwd_fn fwd_fn = make_fwd_fn(l) rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) results['perturbation_rho'].append(float(rho)) for eta in [0.001, 0.003, 0.01]: nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) results['nudging'][str(eta)].append(float(nud)) return results # ============================================================================= # Single config runner # ============================================================================= def run_config(L, d, method, seed, train_loader, test_loader, device, epochs=100, lr=1e-3, lr_fb=1e-3, wd=0.01, warmup_ratio=0.2, term_grad_weight=1.0, lam=0.1, K=4, sigma_bridge=0.05, ema_momentum=0.995): """Run a single (L, d, method, seed) config and return results.""" input_dim = 32 * 32 * 3 num_classes = 10 torch.manual_seed(seed) np.random.seed(seed) torch.cuda.manual_seed_all(seed) model = ResidualMLP(input_dim, d, num_classes, L).to(device) config_str = f"L={L}, d={d}, method={method}, seed={seed}" if 'cb_' in method: config_str += f", wr={warmup_ratio}, tgw={term_grad_weight}" print(f"\n --- {config_str} ---") if method == 'dfa': log, Bs = train_dfa(model, train_loader, test_loader, device, epochs, lr, wd) diag = compute_diagnostics(model, 'dfa', test_loader, device, dfa_Bs=Bs) elif method == 'sb_eT': log, sp = train_state_bridge_online(model, train_loader, test_loader, device, epochs, lr, lr_fb, wd) diag = compute_diagnostics(model, 'state_bridge', test_loader, device, state_pred=sp) elif method == 'cb_eT': log, vnet, _ = train_credit_bridge_online( model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, s_type='eT', warmup_ratio=warmup_ratio, term_grad_weight=term_grad_weight, lam=lam, K=K, sigma_bridge=sigma_bridge, ema_momentum=ema_momentum ) diag = compute_diagnostics(model, 'cb_eT', test_loader, device, value_net=vnet, s_type='eT') elif method == 'cb_deltaL': log, vnet, _ = train_credit_bridge_online( model, train_loader, test_loader, device, epochs, lr, lr_fb, wd, s_type='deltaL', warmup_ratio=warmup_ratio, term_grad_weight=term_grad_weight, lam=lam, K=K, sigma_bridge=sigma_bridge, ema_momentum=ema_momentum ) diag = compute_diagnostics(model, 'cb_deltaL', test_loader, device, value_net=vnet, s_type='deltaL') else: raise ValueError(f"Unknown method: {method}") result = { 'method': method, 'L': L, 'd_hidden': d, 'seed': seed, 'warmup_ratio': warmup_ratio, 'term_grad_weight': term_grad_weight, 'test_acc': log['test_acc'][-1], 'mean_gamma': float(np.mean(diag['bp_cosine'])), 'mean_rho': float(np.mean(diag['perturbation_rho'])), 'mean_nudge': float(np.mean(diag['nudging']['0.003'])), 'per_layer_gamma': diag['bp_cosine'], 'per_layer_rho': diag['perturbation_rho'], 'per_layer_nudge': diag['nudging']['0.003'], } print(f" Result: acc={result['test_acc']:.4f}, Gamma={result['mean_gamma']:.4f}, " f"rho={result['mean_rho']:.4f}, nudge={result['mean_nudge']:.6f}") return result # ============================================================================= # Main # ============================================================================= def main(): parser = argparse.ArgumentParser(description='Phase B: Online shallow CIFAR conditioning') parser.add_argument('--depths', type=int, nargs='+', default=[4, 6]) parser.add_argument('--widths', type=int, nargs='+', default=[256, 512]) parser.add_argument('--methods', type=str, nargs='+', default=['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL']) parser.add_argument('--warmup_ratios', type=float, nargs='+', default=[0.0, 0.05, 0.2]) parser.add_argument('--tgws', type=float, nargs='+', default=[1.0, 4.0]) parser.add_argument('--epochs', type=int, default=100) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--lr_fb', type=float, default=1e-3) parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--lam', type=float, default=0.1) parser.add_argument('--K', type=int, default=4) parser.add_argument('--sigma_bridge', type=float, default=0.05) parser.add_argument('--ema_momentum', type=float, default=0.995) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--batch_size', type=int, default=128) parser.add_argument('--gpu', type=int, default=2) parser.add_argument('--output_dir', type=str, default='results/online_shallow') args = parser.parse_args() device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"Using device: {device}") os.makedirs(args.output_dir, exist_ok=True) train_loader, test_loader = get_cifar10(batch_size=args.batch_size) all_results = [] for L in args.depths: for d in args.widths: for method in args.methods: if method in ['dfa', 'sb_eT']: # No warmup/tgw sweep for DFA and SB result = run_config( L, d, method, args.seed, train_loader, test_loader, device, epochs=args.epochs, lr=args.lr, lr_fb=args.lr_fb, wd=args.wd ) all_results.append(result) else: # Sweep warmup and tgw for CB methods for wr in args.warmup_ratios: for tgw in args.tgws: result = run_config( L, d, method, args.seed, train_loader, test_loader, device, epochs=args.epochs, lr=args.lr, lr_fb=args.lr_fb, wd=args.wd, warmup_ratio=wr, term_grad_weight=tgw, lam=args.lam, K=args.K, sigma_bridge=args.sigma_bridge, ema_momentum=args.ema_momentum ) all_results.append(result) # Summary table print(f"\n{'='*80}") print("SUMMARY") print(f"{'='*80}") # Find DFA baselines for S1, S2 computation dfa_baselines = {} for r in all_results: if r['method'] == 'dfa': dfa_baselines[(r['L'], r['d_hidden'])] = r print(f"\n{'Method':<20} {'L':>3} {'d':>4} {'wr':>5} {'tgw':>5} {'Acc':>6} " f"{'Gamma':>7} {'rho':>7} {'nudge':>10} {'S1':>7} {'S2':>7}") print("-" * 95) positive_configs = [] for r in all_results: key = (r['L'], r['d_hidden']) dfa_ref = dfa_baselines.get(key) S1 = r['mean_gamma'] - (dfa_ref['mean_gamma'] if dfa_ref else 0) S2 = r['mean_rho'] - (dfa_ref['mean_rho'] if dfa_ref else 0) wr_str = f"{r.get('warmup_ratio', '-'):>5.2f}" if r['method'].startswith('cb_') else " -" tgw_str = f"{r.get('term_grad_weight', '-'):>5.1f}" if r['method'].startswith('cb_') else " -" print(f"{r['method']:<20} {r['L']:>3} {r['d_hidden']:>4} {wr_str} {tgw_str} " f"{r['test_acc']:>6.4f} {r['mean_gamma']:>7.4f} {r['mean_rho']:>7.4f} " f"{r['mean_nudge']:>10.6f} {S1:>7.4f} {S2:>7.4f}") if r['method'].startswith('cb_') and S1 > 0 and S2 > 0: nudge_better = r['mean_nudge'] < (dfa_ref['mean_nudge'] if dfa_ref else 0) positive_configs.append({**r, 'S1': S1, 'S2': S2, 'nudge_better': nudge_better}) if positive_configs: print(f"\nPOSITIVE CONFIGS (S1>0 AND S2>0):") for pc in positive_configs: print(f" {pc['method']} L={pc['L']} d={pc['d_hidden']} wr={pc.get('warmup_ratio','-')} " f"tgw={pc.get('term_grad_weight','-')}: S1={pc['S1']:.4f} S2={pc['S2']:.4f} " f"nudge_better={pc['nudge_better']}") else: print(f"\nNO POSITIVE CONFIGS FOUND. All CB variants have S1<=0 or S2<=0.") # Save out_path = os.path.join(args.output_dir, f'scan_s{args.seed}.json') with open(out_path, 'w') as f: json.dump(all_results, f, indent=2) print(f"\nResults saved to {out_path}") if __name__ == '__main__': main()