""" Phase 3: Boundary-condition ablation on credit bridge. Test different terminal conditioning codes: s1 = e_T (current default, softmax error) s2 = delta_L (grad of CE w.r.t. h_L, output-layer-local) s3 = concat(e_T, proj(h_L)) -- h_L projected to smaller dim s4 = concat(delta_L, proj(h_L)) Also ablate: - terminal gradient matching weight: w_term in {0, 0.25, 1.0, 4.0} - warmup ratio: r_warm in {0, 0.05, 0.2, 0.5} Run on best regimes from Phase 1/2. """ import os import sys import json import argparse import time import numpy as np import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim from torch.utils.data import DataLoader, TensorDataset import copy sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema from models.state_bridge import StateBridgeNet from metrics.credit_metrics import ( cosine_similarity_batch, perturbation_correlation, nudging_test ) # ============================================================================= # Reuse teacher and student from synth ladder # ============================================================================= class TeacherNet: def __init__(self, d_hidden, num_blocks, num_classes, alpha, seed=0): rng = np.random.RandomState(seed) self.d_hidden = d_hidden self.num_blocks = num_blocks self.num_classes = num_classes self.alpha = alpha self.Ws = [] for l in range(num_blocks): W = rng.randn(d_hidden, d_hidden).astype(np.float32) W = W / (np.linalg.norm(W, ord=2) + 1e-8) * 0.3 self.Ws.append(torch.from_numpy(W)) U = rng.randn(num_classes, d_hidden).astype(np.float32) U = U / (np.linalg.norm(U, ord=2) + 1e-8) self.U = torch.from_numpy(U) def to(self, device): self.Ws = [W.to(device) for W in self.Ws] self.U = self.U.to(device) return self def phi(self, z): return (1 - self.alpha) * z + self.alpha * torch.tanh(z) def forward(self, h0): h = h0 hiddens = [h] for l in range(self.num_blocks): f = F.linear(self.phi(h), self.Ws[l]) h = h + f hiddens.append(h) logits = F.linear(h, self.U) return logits, hiddens def generate_dataset(teacher, num_samples, d_hidden, device, seed=0): torch.manual_seed(seed) X = torch.randn(num_samples, d_hidden, device=device) with torch.no_grad(): logits, _ = teacher.forward(X) Y = logits.argmax(dim=-1) return X, Y class StudentBlock(nn.Module): def __init__(self, d_hidden, alpha): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.w = nn.Linear(d_hidden, d_hidden, bias=False) self.alpha = alpha nn.init.normal_(self.w.weight, std=0.01) def phi(self, z): return (1 - self.alpha) * z + self.alpha * torch.tanh(z) def forward(self, h): return self.w(self.phi(self.ln(h))) class StudentNet(nn.Module): def __init__(self, d_hidden, num_classes, num_blocks, alpha): super().__init__() self.blocks = nn.ModuleList([StudentBlock(d_hidden, alpha) for _ in range(num_blocks)]) self.out_head = nn.Linear(d_hidden, num_classes) self.num_blocks = num_blocks self.d_hidden = d_hidden def forward(self, x, return_hidden=False): h = x hiddens = [h] if return_hidden else None for block in self.blocks: f = block(h) h = h + f if return_hidden: hiddens.append(h) logits = self.out_head(h) if return_hidden: return logits, hiddens return logits def forward_from_layer(self, h, start_layer): for i in range(start_layer, self.num_blocks): h = h + self.blocks[i](h) return self.out_head(h) # ============================================================================= # Extended ValueNet that supports different s_dim # ============================================================================= class ValueNetFlex(nn.Module): """Value net with flexible s_dim.""" def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): super().__init__() self.ln = nn.LayerNorm(d_hidden) self.time_embed = SinusoidalTimeEmbed(time_embed_dim) input_dim = d_hidden + time_embed_dim + s_dim layers = [] for i in range(num_layers): in_d = input_dim if i == 0 else hidden_dim layers.append(nn.Linear(in_d, hidden_dim)) layers.append(nn.GELU()) layers.append(nn.Linear(hidden_dim, 1)) self.net = nn.Sequential(*layers) def forward(self, h, t, s): h_normed = self.ln(h) t_emb = self.time_embed(t) inp = torch.cat([h_normed, t_emb, s], dim=-1) return self.net(inp).squeeze(-1) # ============================================================================= # Terminal conditioning code computation # ============================================================================= def compute_s(s_type, model, hiddens, logits, y, device, hL_proj=None): """ Compute terminal conditioning code s based on s_type. Args: s_type: 'eT', 'deltaL', 'eT_hL', 'deltaL_hL' model: student net hiddens: list of hidden states logits: model logits y: true labels device: torch device hL_proj: fixed random projection matrix for h_L (d_hidden x proj_dim) Returns: s: (batch, s_dim) """ batch = logits.shape[0] hL_det = hiddens[-1].detach() if s_type == 'eT': e_T = logits.softmax(dim=-1).detach() e_T[torch.arange(batch), y] -= 1 return e_T elif s_type == 'deltaL': # grad of CE w.r.t. h_L (output-layer-local) hL_req = hL_det.clone().requires_grad_(True) logits_local = model.out_head(hL_req) loss_local = F.cross_entropy(logits_local, y, reduction='sum') delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach() return delta_L elif s_type == 'eT_hL': e_T = logits.softmax(dim=-1).detach() e_T[torch.arange(batch), y] -= 1 hL_proj_emb = hL_det @ hL_proj # (batch, proj_dim) return torch.cat([e_T, hL_proj_emb], dim=-1) elif s_type == 'deltaL_hL': hL_req = hL_det.clone().requires_grad_(True) logits_local = model.out_head(hL_req) loss_local = F.cross_entropy(logits_local, y, reduction='sum') delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach() hL_proj_emb = hL_det @ hL_proj return torch.cat([delta_L, hL_proj_emb], dim=-1) else: raise ValueError(f"Unknown s_type: {s_type}") def get_s_dim(s_type, num_classes, d_hidden, proj_dim=32): if s_type == 'eT': return num_classes elif s_type == 'deltaL': return d_hidden elif s_type == 'eT_hL': return num_classes + proj_dim elif s_type == 'deltaL_hL': return d_hidden + proj_dim else: raise ValueError(f"Unknown s_type: {s_type}") # ============================================================================= # Credit bridge training with configurable boundary conditions # ============================================================================= def train_credit_bridge_ablation(model, train_loader, test_loader, device, args, s_type='eT', term_grad_weight=1.0, warmup_ratio=0.2, hL_proj=None): d = model.d_hidden L = model.num_blocks C = args.num_classes warmup_epochs = max(1, int(args.epochs * warmup_ratio)) s_dim = get_s_dim(s_type, C, d, proj_dim=32) value_net = ValueNetFlex(d_hidden=d, s_dim=s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3).to(device) value_net_ema = create_ema_model(value_net) Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) for block in model.blocks] head_opt = optim.AdamW(model.out_head.parameters(), lr=args.lr, weight_decay=args.wd) value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + [optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) lam = args.lam K_samples = args.K sigma_bridge = args.sigma_bridge ema_momentum = args.ema_momentum log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': [], 'term_loss': [], 'bridge_loss': [], 'tgrad_loss': []} for epoch in range(1, args.epochs + 1): model.train() value_net.train() total_loss, correct, total = 0, 0, 0 total_vloss = 0 if warmup_epochs == 0: credit_blend = 1.0 elif epoch <= warmup_epochs: credit_blend = 0.0 else: credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) for x, y in train_loader: x, y = x.to(device), y.to(device) batch = x.size(0) with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) loss_val = F.cross_entropy(logits, y) true_loss = F.cross_entropy(logits, y, reduction='none').detach() # Compute s with the specified type s = compute_s(s_type, model, hiddens, logits, y, device, hL_proj) hL_det = hiddens[-1].detach() # Also need e_T for DFA fallback with torch.no_grad(): e_T = logits.softmax(dim=-1) e_T[torch.arange(batch), y] -= 1 # Train value net t_L = torch.ones(batch, device=device) V_terminal = value_net(hL_det, t_L, s) loss_term = ((V_terminal - true_loss) ** 2).mean() loss_tgrad = torch.tensor(0.0, device=device) if term_grad_weight > 0: hL_req = hL_det.clone().requires_grad_(True) V_at_L = value_net(hL_req, t_L, s) grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] hL_req2 = hL_det.clone().requires_grad_(True) logits_tgt = model.out_head(hL_req2) ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() loss_bridge = 0.0 for l in range(L): h_l_det = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) t_l_next = torch.full((batch,), (l + 1) / L, device=device) V_l = value_net(h_l_det, t_l, s) with torch.no_grad(): h_next_det = hiddens[l + 1].detach() log_terms = [] for k in range(K_samples): noise = sigma_bridge * torch.randn_like(h_next_det) V_next = value_net_ema(h_next_det + noise, t_l_next, s) log_terms.append(-V_next / lam) log_stack = torch.stack(log_terms, dim=-1) V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples)) loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() loss_bridge = loss_bridge / L value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad value_opt.zero_grad() value_loss.backward() torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) value_opt.step() update_ema(value_net, value_net_ema, ema_momentum) total_vloss += value_loss.item() * batch # Compute credits cb_credits = [] for l in range(L): h_l_det = hiddens[l].detach().requires_grad_(True) t_l = torch.full((batch,), l / L, device=device) V_l = value_net(h_l_det, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] cb_credits.append(a_l.detach()) dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] credits = [] for l in range(L): if credit_blend >= 1.0: a = cb_credits[l] elif credit_blend <= 0.0: a = dfa_credits[l] else: cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) credits.append(a) # Update output head logits_out = model.out_head(hL_det) loss_out = F.cross_entropy(logits_out, y) head_opt.zero_grad() loss_out.backward() head_opt.step() # Update blocks for l in range(L): h_l = hiddens[l].detach() a = credits[l] rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 a_norm = a / rms f_l = model.blocks[l](h_l) local_loss = (f_l * a_norm).sum(dim=-1).mean() block_opts[l].zero_grad() local_loss.backward() torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) block_opts[l].step() total_loss += loss_val.item() * batch correct += (logits.argmax(1) == y).sum().item() total += batch for sch in all_schedulers: sch.step() log['train_loss'].append(total_loss / total) log['train_acc'].append(correct / total) test_acc = 0 model.eval() with torch.no_grad(): tc, tt = 0, 0 for x, y in test_loader: x, y = x.to(device), y.to(device) logits = model(x) tc += (logits.argmax(1) == y).sum().item() tt += x.size(0) test_acc = tc / tt log['test_acc'].append(test_acc) log['value_loss'].append(total_vloss / total) return log, value_net def compute_diagnostics(model, value_net, test_loader, device, args, s_type='eT', hL_proj=None): model.eval() value_net.eval() d = model.d_hidden L = model.num_blocks C = args.num_classes for x, y in test_loader: x, y = x.to(device), y.to(device) break batch = x.size(0) # BP gradients h = x.detach().requires_grad_(True) hiddens_bp = [h] for block in model.blocks: f = block(hiddens_bp[-1]) hiddens_bp.append(hiddens_bp[-1] + f) logits_bp = model.out_head(hiddens_bp[-1]) loss_bp = F.cross_entropy(logits_bp, y) grads = torch.autograd.grad(loss_bp, hiddens_bp, retain_graph=False) bp_grads = {l: grads[l].detach().clone() for l in range(L + 1)} with torch.no_grad(): logits, hiddens = model(x, return_hidden=True) s = compute_s(s_type, model, hiddens, logits, y, device, hL_proj) results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.01': []}} for l in range(L): h_l = hiddens[l].detach() t_l = torch.full((batch,), l / L, device=device) h_l_req = h_l.clone().requires_grad_(True) V_l = value_net(h_l_req, t_l, s) a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) results['bp_cosine'].append(bp_cos) def make_fwd_fn(start_l): def fwd_fn(h): with torch.no_grad(): curr = h for i in range(start_l, L): curr = curr + model.blocks[i](curr) out = model.out_head(curr) return F.cross_entropy(out, y, reduction='none') return fwd_fn fwd_fn = make_fwd_fn(l) rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) results['perturbation_rho'].append(rho) nud = nudging_test(h_l, a_l, fwd_fn, eta=0.01) results['nudging']['0.01'].append(nud) return results def run_ablation(args, device): d = args.d_hidden C = args.num_classes alpha = args.alpha L = args.L teacher = TeacherNet(d, L, C, alpha, seed=0).to(device) X_train, Y_train = generate_dataset(teacher, args.n_train, d, device, seed=args.seed) X_test, Y_test = generate_dataset(teacher, args.n_test, d, device, seed=args.seed + 10000) train_ds = TensorDataset(X_train, Y_train) test_ds = TensorDataset(X_test, Y_test) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True) test_loader = DataLoader(test_ds, batch_size=args.batch_size, shuffle=False) # h_L projection matrix (fixed random) proj_dim = 32 hL_proj = torch.randn(d, proj_dim, device=device) / np.sqrt(d) results = {} for s_type in args.s_types: for tgw in args.term_grad_weights: for wr in args.warmup_ratios: key = f"s_{s_type}_tgw{tgw}_wr{wr}" print(f"\n === {key} ===") t0 = time.time() torch.manual_seed(args.seed) model = StudentNet(d, C, L, alpha).to(device) log, vnet = train_credit_bridge_ablation( model, train_loader, test_loader, device, args, s_type=s_type, term_grad_weight=tgw, warmup_ratio=wr, hL_proj=hL_proj ) diag = compute_diagnostics(model, vnet, test_loader, device, args, s_type=s_type, hL_proj=hL_proj) mean_gamma = np.mean(diag['bp_cosine']) mean_rho = np.mean(diag['perturbation_rho']) mean_nudge = np.mean(diag['nudging']['0.01']) test_acc = log['test_acc'][-1] results[key] = { 'test_acc': test_acc, 'mean_bp_cosine': float(mean_gamma), 'mean_rho': float(mean_rho), 'mean_nudge': float(mean_nudge), 'bp_cosine_per_layer': [float(x) for x in diag['bp_cosine']], 'rho_per_layer': [float(x) for x in diag['perturbation_rho']], 'final_value_loss': log['value_loss'][-1], 's_type': s_type, 'term_grad_weight': tgw, 'warmup_ratio': wr, } elapsed = time.time() - t0 print(f" Done in {elapsed:.0f}s: acc={test_acc:.4f} Gamma={mean_gamma:.4f} " f"rho={mean_rho:.4f} nudge={mean_nudge:.6f}") return results def serialize(obj): if isinstance(obj, dict): return {str(k): serialize(v) for k, v in obj.items()} elif isinstance(obj, list): return [serialize(v) for v in obj] elif isinstance(obj, (np.floating, np.integer)): return float(obj) elif isinstance(obj, np.ndarray): return obj.tolist() elif isinstance(obj, torch.Tensor): return obj.cpu().numpy().tolist() return obj def main(): parser = argparse.ArgumentParser(description='Boundary Condition Ablation') parser.add_argument('--alpha', type=float, default=1.0) parser.add_argument('--L', type=int, default=4) parser.add_argument('--seed', type=int, default=42) parser.add_argument('--d_hidden', type=int, default=128) parser.add_argument('--num_classes', type=int, default=10) parser.add_argument('--n_train', type=int, default=10000) parser.add_argument('--n_test', type=int, default=2000) parser.add_argument('--batch_size', type=int, default=256) parser.add_argument('--epochs', type=int, default=80) parser.add_argument('--lr', type=float, default=1e-3) parser.add_argument('--lr_fb', type=float, default=1e-3) parser.add_argument('--wd', type=float, default=0.01) parser.add_argument('--lam', type=float, default=0.1) parser.add_argument('--K', type=int, default=4) parser.add_argument('--sigma_bridge', type=float, default=0.05) parser.add_argument('--ema_momentum', type=float, default=0.995) parser.add_argument('--s_types', type=str, nargs='+', default=['eT', 'deltaL', 'eT_hL', 'deltaL_hL']) parser.add_argument('--term_grad_weights', type=float, nargs='+', default=[0.0, 0.25, 1.0, 4.0]) parser.add_argument('--warmup_ratios', type=float, nargs='+', default=[0.0, 0.05, 0.2, 0.5]) parser.add_argument('--gpu', type=int, default=1) parser.add_argument('--output_dir', type=str, default='results/boundary_ablation') args = parser.parse_args() device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') print(f"Device: {device}") print(f"alpha={args.alpha}, L={args.L}, seed={args.seed}") print(f"s_types: {args.s_types}") print(f"term_grad_weights: {args.term_grad_weights}") print(f"warmup_ratios: {args.warmup_ratios}") os.makedirs(args.output_dir, exist_ok=True) results = run_ablation(args, device) out_path = os.path.join(args.output_dir, f'ablation_a{args.alpha}_L{args.L}_s{args.seed}.json') with open(out_path, 'w') as f: json.dump(serialize(results), f, indent=2) # Print summary print("\n" + "=" * 100) print("BOUNDARY CONDITION ABLATION SUMMARY") print("=" * 100) print(f"{'Config':<40} {'Acc':>8} {'Gamma':>8} {'rho':>8} {'nudge':>10}") print("-" * 100) for key in sorted(results.keys()): r = results[key] print(f"{key:<40} {r['test_acc']:>8.4f} {r['mean_bp_cosine']:>8.4f} " f"{r['mean_rho']:>8.4f} {r['mean_nudge']:>10.6f}") if __name__ == '__main__': main()