diff options
| -rw-r--r-- | NOTE.md | 12 | ||||
| -rw-r--r-- | experiments/cifar_deltaL_test.py | 393 | ||||
| -rw-r--r-- | report_explore/MEMO_pivot_vector_field.md | 128 |
3 files changed, 533 insertions, 0 deletions
@@ -153,3 +153,15 @@ wr=0.5 -> worst Gamma (0.23) but best acc (0.66). Clear tradeoff between credit quality and accuracy. Best single config: deltaL + tgw=1.0 + wr=0.05 -> **Gamma=0.768, rho=0.691** + +### CIFAR deltaL Test +deltaL conditioning (s=grad_{h_L} CE, dim=512) on CIFAR L=4: FAILED. +Acc=17.2%, Gamma≈0, rho≈0. The 512-dim conditioning is too high-dimensional +for the value net. Confirms the scalar V approach has a dimensionality bottleneck. + +### Pivot Recommendation: Direct Vector Credit Field +See `report_explore/MEMO_pivot_vector_field.md`. +Instead of V_phi -> grad_h V, learn a_phi(h_l, t_l, s) -> R^d directly. +Train with perturbation-based target: match <a, v> to actual loss change. +Still satisfies no hidden BP anchor constraint. +Minimal test: synthetic alpha=1.0, L=4 with M=4 perturbation directions. diff --git a/experiments/cifar_deltaL_test.py b/experiments/cifar_deltaL_test.py new file mode 100644 index 0000000..c085489 --- /dev/null +++ b/experiments/cifar_deltaL_test.py @@ -0,0 +1,393 @@ +""" +Quick test: Credit Bridge on CIFAR-10 with s=deltaL conditioning. +deltaL = grad_{h_L} CE(out_head(h_L), y) -- output-layer-local, dim=d_hidden. +This gives 512-dim conditioning instead of 10-dim e_T. +""" +import os +import sys +import json +import argparse +import time +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test +) + + +class ValueNetLargeS(nn.Module): + """Value net with larger s_dim (for deltaL conditioning).""" + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + # Compress s to a fixed dim to keep value net manageable + self.s_compress = nn.Linear(s_dim, 64) + input_dim = d_hidden + time_embed_dim + 64 + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, 1)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + s_compressed = self.s_compress(s) + inp = torch.cat([h_normed, t_emb, s_compressed], dim=-1) + return self.net(inp).squeeze(-1) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +def compute_deltaL(model, hL_det, y): + """Compute delta_L = grad_{h_L} CE(out_head(out_ln(h_L)), y). Output-layer-local.""" + hL_req = hL_det.clone().requires_grad_(True) + logits_local = model.out_head(model.out_ln(hL_req)) + loss_local = F.cross_entropy(logits_local, y, reduction='sum') + delta_L = torch.autograd.grad(loss_local, hL_req, create_graph=False)[0].detach() + return delta_L + + +def train_cb_deltaL(model, train_loader, test_loader, device, args): + """Credit bridge with s=deltaL conditioning.""" + d = model.d_hidden + L = model.num_blocks + C = 10 + warmup_epochs = max(1, args.epochs // 5) + + value_net = ValueNetLargeS(d_hidden=d, s_dim=d, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + value_net_ema = create_ema_model(value_net) + + Bs_fallback = [torch.randn(d, C, device=device) / np.sqrt(C) for _ in range(L)] + + block_opts = [optim.AdamW(block.parameters(), lr=args.lr, weight_decay=args.wd) + for block in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=args.lr, weight_decay=args.wd + ) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + all_schedulers = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]) + + log = {'train_loss': [], 'train_acc': [], 'test_acc': [], 'value_loss': []} + + for epoch in range(1, args.epochs + 1): + model.train() + value_net.train() + total_loss, correct, total = 0, 0, 0 + total_vloss = 0 + + if epoch <= warmup_epochs: + credit_blend = 0.0 + else: + credit_blend = min(1.0, (epoch - warmup_epochs) / max(1, warmup_epochs)) + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + loss_val = F.cross_entropy(logits, y) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + + # Compute s = deltaL (output-layer-local gradient) + s = compute_deltaL(model, hL_det, y) + + # Train value net + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Terminal gradient matching + loss_tgrad = torch.tensor(0.0, device=device) + if args.term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + # a_L_exact is just s (deltaL) itself + a_L_exact = s + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_l_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(args.K): + noise = args.sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_l_next, s) + log_terms.append(-V_next / args.lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -args.lam * (torch.logsumexp(log_stack, dim=-1) - np.log(args.K)) + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + value_loss = loss_term + loss_bridge + args.term_grad_weight * loss_tgrad + value_opt.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, args.ema_momentum) + total_vloss += value_loss.item() * batch + + # Compute credits + cb_credits = [] + for l in range(L): + h_l_det = hiddens[l].detach().requires_grad_(True) + t_l = torch.full((batch,), l / L, device=device) + V_l = value_net(h_l_det, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_det, create_graph=False)[0] + cb_credits.append(a_l.detach()) + + dfa_credits = [(e_T @ Bs_fallback[l].T).detach() for l in range(L)] + + credits = [] + for l in range(L): + if credit_blend >= 1.0: + a = cb_credits[l] + elif credit_blend <= 0.0: + a = dfa_credits[l] + else: + cb_rms = (cb_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + dfa_rms = (dfa_credits[l] ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a = credit_blend * (cb_credits[l] / cb_rms) + (1 - credit_blend) * (dfa_credits[l] / dfa_rms) + credits.append(a) + + # Update output head + logits_out = model.out_head(model.out_ln(hL_det)) + loss_out = F.cross_entropy(logits_out, y) + head_opt.zero_grad() + loss_out.backward() + head_opt.step() + + # Update blocks + for l in range(L): + h_l = hiddens[l].detach() + a = credits[l] + rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_norm = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a_norm).sum(dim=-1).mean() + block_opts[l].zero_grad() + local_loss.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + # Update embedding + a_0 = credits[0] + rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6 + a_0_norm = a_0 / rms_0 + h0 = model.embed(x) + embed_loss = (h0 * a_0_norm).sum(dim=-1).mean() + embed_opt.zero_grad() + embed_loss.backward() + embed_opt.step() + + total_loss += loss_val.item() * batch + correct += (logits.argmax(1) == y).sum().item() + total += batch + + for sch in all_schedulers: + sch.step() + + log['train_loss'].append(total_loss / total) + log['train_acc'].append(correct / total) + log['test_acc'].append(evaluate(model, test_loader, device)) + log['value_loss'].append(total_vloss / total) + if epoch % 10 == 0 or epoch == 1: + phase = "warmup" if epoch <= warmup_epochs else f"blend={credit_blend:.2f}" + print(f" [CB-deltaL] Ep {epoch} ({phase}): loss={log['train_loss'][-1]:.4f} " + f"train={log['train_acc'][-1]:.4f} test={log['test_acc'][-1]:.4f} " + f"vloss={log['value_loss'][-1]:.6f}") + return log, value_net + + +def compute_diagnostics(model, value_net, test_loader, device, args): + model.eval() + value_net.eval() + d = model.d_hidden + L = model.num_blocks + + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + break + + batch = x.size(0) + + # BP gradients + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + + hL_det = hiddens[-1].detach() + s = compute_deltaL(model, hL_det, y) + + results = {'bp_cosine': [], 'perturbation_rho': [], 'nudging': {'0.01': []}} + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + h_l_req = h_l.clone().requires_grad_(True) + V_l = value_net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + results['bp_cosine'].append(bp_cos) + + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=16) + results['perturbation_rho'].append(rho) + + nud = nudging_test(h_l, a_l, fwd_fn, eta=0.01) + results['nudging']['0.01'].append(nud) + + return results + + +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('--d_hidden', type=int, default=512) + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=1) + parser.add_argument('--output_dir', type=str, default='results/cifar_deltaL') + args = parser.parse_args() + + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(args.batch_size) + input_dim = 32 * 32 * 3 + + model = ResidualMLP(input_dim, args.d_hidden, 10, args.num_blocks).to(device) + print(f"Model: d={args.d_hidden}, L={args.num_blocks}") + print(f"Conditioning: s=deltaL (dim={args.d_hidden})") + + t0 = time.time() + log, vnet = train_cb_deltaL(model, train_loader, test_loader, device, args) + elapsed = time.time() - t0 + + diag = compute_diagnostics(model, vnet, test_loader, device, args) + + mean_gamma = np.mean(diag['bp_cosine']) + mean_rho = np.mean(diag['perturbation_rho']) + mean_nudge = np.mean(diag['nudging']['0.01']) + + print(f"\nDone in {elapsed:.0f}s") + print(f"Test acc: {log['test_acc'][-1]:.4f}") + print(f"Mean Gamma: {mean_gamma:.4f}") + print(f"Mean rho: {mean_rho:.4f}") + print(f"Mean nudge: {mean_nudge:.6f}") + print(f"Gamma per layer: {[round(g, 4) for g in diag['bp_cosine']]}") + print(f"rho per layer: {[round(r, 4) for r in diag['perturbation_rho']]}") + + result = { + 'test_acc': log['test_acc'][-1], + 'mean_gamma': float(mean_gamma), + 'mean_rho': float(mean_rho), + 'mean_nudge': float(mean_nudge), + 'gamma_per_layer': [float(g) for g in diag['bp_cosine']], + 'rho_per_layer': [float(r) for r in diag['perturbation_rho']], + 'log': log, + } + + out_path = os.path.join(args.output_dir, f'cb_deltaL_d{args.d_hidden}_L{args.num_blocks}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(result, f, indent=2) + print(f"Saved to {out_path}") + + +if __name__ == '__main__': + main() diff --git a/report_explore/MEMO_pivot_vector_field.md b/report_explore/MEMO_pivot_vector_field.md new file mode 100644 index 0000000..f73ac3d --- /dev/null +++ b/report_explore/MEMO_pivot_vector_field.md @@ -0,0 +1,128 @@ +# Pivot Design Memo: Direct Vector Credit Field + +## Why Scalar V May Have Value-Correct but Curvature-Wrong Gradients + +The current credit bridge learns a scalar function V_phi(h_l, t_l, s) and defines credit as a_l = grad_h V_phi. The bridge consistency loss constrains V's **values** at successive layers: + + V(h_l, t_l, s) ≈ soft-min_noise V_bar(h_{l+1} + noise, t_{l+1}, s) + +This gives correct V values but provides only **indirect** constraints on grad_h V. The gradient of V depends on its curvature with respect to h, which is a second-order property that the value-matching loss doesn't directly optimize. + +Terminal gradient matching addresses this at the boundary (l=L), but the information must propagate backward through the bridge consistency, which is a value-level (zeroth-order) constraint. Each layer of bridge consistency loses gradient information. + +**Evidence from experiments:** +- Without terminal gradient matching: V values converge but gradients are uninformative (cosine → 0.03) +- With terminal gradient matching: gradients improve but degrade with distance from terminal layer +- On CIFAR (d=512), the gradient information from 10-dim terminal code is insufficient +- deltaL (d-dim conditioning) helps on synthetic but fundamental issue remains + +The core problem: **optimizing a scalar function's values does not efficiently constrain its d-dimensional gradient field**, especially in high dimensions. + +## Direct Vector Credit Field: The Alternative + +Instead of V_phi: R^d x R x R^s -> R, learn the credit directly: + + a_phi(h_l, t_l, s) in R^d + +This outputs the credit vector without going through a scalar intermediate. The gradient computation disappears — the network output IS the credit. + +### Architecture + +``` +Input: [LN(h_l), time_embed(t_l), s] +-> MLP (same as current ValueNet architecture) +-> Linear(hidden_dim, d_hidden) # Output d-dimensional credit +``` + +### Training Objective + +The bridge consistency becomes a **vector** consistency: + + a_phi(h_l, t_l, s) ≈ J_l^T a_phi(h_{l+1}, t_{l+1}, s) + +where J_l = I + dF_l/dh_l is the block Jacobian. But computing J_l^T v requires hidden BP, which violates the constraint! + +**Alternative 1: Forward-mode approximation** + +Use finite differences along the forward dynamics: + + a_phi(h_l, t_l, s) ≈ E_xi [ (a_phi(h_{l+1} + sigma*xi, t_{l+1}, s) - a_phi(h_{l+1}, t_{l+1}, s)) / sigma * xi + a_phi(h_{l+1}, t_{l+1}, s) ] + +Wait — this doesn't work either because it would need J_l^T, not J_l. + +**Alternative 2: Perturbation-based target** + +Train a_phi to predict local loss sensitivity directly: + + L_pert = E_v [ (<a_phi(h_l, t_l, s), v> - (loss(h_l + eps*v) - loss(h_l))/eps )^2 ] + +This is computationally expensive (need M forward passes per layer per sample) but provides a direct training signal for the credit vector. It doesn't require any Jacobian or hidden BP. + +**Alternative 3: Terminal matching + interpolation smoothness** + +- Terminal: a_phi(h_L, 1, s) = delta_L (exact output-layer gradient) +- Smoothness: ||a_phi(h_{l+0.5}, ...) - 0.5*a_phi(h_l, ...) - 0.5*a_phi(h_{l+1}, ...)||^2 + +This is similar to FM auxiliary but applied to the credit vector directly. + +**Alternative 4: Soft contrastive target** + + a_phi(h_l, t_l, s) should point in the direction that makes + V_target(h_l + eps*a_phi) < V_target(h_l - eps*a_phi) + +Using the EMA target network: + + L_contrastive = -log sigmoid( (V_bar(h_l - eps*a_norm, t_l, s) - V_bar(h_l + eps*a_norm, t_l, s)) / tau ) + +This trains a_phi to point "downhill" on the value landscape without needing the exact gradient. + +### Recommended Approach: Alternative 2 + Terminal Matching + +The perturbation-based target is the most principled because it directly measures what we want: local loss sensitivity. Combined with terminal matching: + + L_total = L_terminal + beta * L_perturbation + +Where: +- L_terminal = ||a_phi(h_L, 1, s) - delta_L||^2 +- L_perturbation = sum_l E_v [ (<a_phi(h_l, t_l, s), v> - (loss(h_l + eps*v) - loss(h_l))/eps)^2 ] + +With M=4 directions per layer, this needs 4*L extra forward-from-layer passes per batch. For L=4, that's 16 passes — expensive but tractable. + +## Does It Still Satisfy No Hidden BP Anchor? + +**Yes.** The perturbation-based target uses: +1. Forward-from-layer passes (no backprop through hidden layers) +2. Output-layer loss evaluation (no gradient extraction) +3. Terminal gradient matching (output-layer-local) + +No hidden-layer BP gradients are used as training targets at any point. + +## Minimal Test Setup + +**Task**: Synthetic teacher-student, alpha=1.0, L=4, d=128 (same as Phase 1 best regime) + +**Comparison**: +1. Current scalar credit bridge (V_phi -> grad_h V) — baseline +2. Direct vector credit field with perturbation target (M=4) +3. Direct vector credit field with perturbation target (M=8) + +**Metrics**: Same as Phase 1 (Gamma, rho, nudge) + +**Expected outcome**: +- Direct vector field should achieve higher rho than scalar V (it's directly trained to predict perturbation sensitivity) +- Gamma may or may not improve (depends on whether the perturbation target implicitly aligns with BP gradient) +- Training cost: ~4x per-step for M=4 due to extra forward passes + +**Implementation effort**: ~100 lines of new code. Reuse existing StudentNet and diagnostics. + +## Risk Assessment + +**Upside**: Direct vector field avoids the fundamental curvature problem. It's trained on exactly the quantity we care about (local loss sensitivity). + +**Downside**: The perturbation target is noisier than the scalar bridge consistency. With M=4 random directions, the variance of the gradient estimate is high in d=512 dimensions. + +**Mitigation**: Start with d=128 synthetic. If it works, gradually increase d. The perturbation target quality scales as sqrt(M/d), so d=512 with M=4 gives signal/noise ~ 0.09. May need M=32+ for CIFAR. + +## Bottom Line + +The scalar V approach has a fundamental curvature-vs-value disconnect. The direct vector field addresses this head-on. The recommended first step is a minimal test on the synthetic alpha=1.0, L=4 regime, comparing perturbation-trained vector field against the current scalar bridge. If it shows improved rho, scale up to CIFAR with higher M. |
