diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-26 22:07:35 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-03-26 22:07:35 -0500 |
| commit | b4e3cbeae6cb4cf4a4b69b84a475afcd7d7e9dbe (patch) | |
| tree | fca5a27504471091eba74a8f7efe2cf48eb85826 /experiments | |
| parent | 610e1169e19378cccd2d9b92a588c24dca7f3df7 (diff) | |
Add Phase 10A.6: gain requires trainable depth-aware aux, not semantic credit
9-branch dissection results:
- zero_target crashes (-9.1%): aux must output non-zero
- constant_input neutral (+0.0%): needs at least depth info
- time_only works (+1.0%): h_l not needed, just depth index
- shuffled/fresh_random work (+1.3-1.4%): no semantic content needed
- prefit60_trainable ≈ random_trainable: prefit adds nothing
- All frozen branches crash: trainability is essential
Mechanism: depth-aware trainable auxiliary perturbation that diversifies
block-local updates. Not semantic credit, not pure trainability.
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/structured_vs_semantic_aux.py | 853 |
1 files changed, 853 insertions, 0 deletions
diff --git a/experiments/structured_vs_semantic_aux.py b/experiments/structured_vs_semantic_aux.py new file mode 100644 index 0000000..15aca16 --- /dev/null +++ b/experiments/structured_vs_semantic_aux.py @@ -0,0 +1,853 @@ +""" +Phase 10A.6: Structured vs Semantic Auxiliary Dissection. + +Core question: What kind of structure in the auxiliary signal actually drives +the blend gain observed in Phase 10A? + +9 branches from the same DFA checkpoint at t0=5: +1. continue_DFA — pure DFA baseline +2. blend_random_trainable — random Vec, trained online (same as 10A.5) +3. blend_shuffled_trainable — Vec trained but targets shuffled within batch +4. blend_zero_target — Vec trained with ||a_aux||^2 loss (learns zero) +5. blend_fresh_random_target — Vec trained with fresh i.i.d. random targets each step +6. blend_time_only — auxiliary only sees t=l/L (no h_l) +7. blend_constant_input — auxiliary learns a fixed per-block template (no input) +8. blend_prefit60_frozen — Vec prefit for 60 epochs, then FROZEN +9. blend_prefit60_trainable — Vec prefit for 60 epochs, then continue training +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import SinusoidalTimeEmbed +from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation + + +# --------------------------------------------------------------------------- +# Auxiliary network architectures +# --------------------------------------------------------------------------- + +class VectorCreditNet(nn.Module): + """Standard Vec: takes (h, t, s) and outputs d_hidden credit vector.""" + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1)) + + +class TimeOnlyNet(nn.Module): + """Auxiliary that only takes t=l/L as input (no h_l, no s). + Shared MLP on sinusoidal time embedding -> d_hidden. + """ + def __init__(self, d_hidden, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + layers = [] + for i in range(num_layers): + in_d = time_embed_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + # h and s are ignored — only t is used + return self.net(self.time_embed(t)) + + +class ConstantNet(nn.Module): + """Auxiliary with NO input at all: learns a per-block fixed output template. + Each block has an nn.Parameter of shape (d_hidden,). + forward(h, t, s) just expands the per-block param to match batch size. + Block index is passed as an integer via a separate call to set_block(). + """ + def __init__(self, d_hidden, num_blocks): + super().__init__() + # One template per block + self.templates = nn.ParameterList( + [nn.Parameter(torch.zeros(d_hidden)) for _ in range(num_blocks)] + ) + self._block_idx = 0 + + def set_block(self, l): + self._block_idx = l + + def forward(self, h, t, s): + batch = h.size(0) + return self.templates[self._block_idx].unsqueeze(0).expand(batch, -1) + + +# --------------------------------------------------------------------------- +# Data +# --------------------------------------------------------------------------- + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))]) + trainset = torchvision.datasets.CIFAR10( + root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10( + root='./data', train=False, download=True, transform=transform_test) + return (DataLoader(trainset, batch_size=batch_size, shuffle=True, + num_workers=4, pin_memory=True), + DataLoader(testset, batch_size=batch_size, shuffle=False, + num_workers=4, pin_memory=True)) + + +# --------------------------------------------------------------------------- +# Evaluation helpers +# --------------------------------------------------------------------------- + +def evaluate(model, test_loader, device): + model.eval(); c, t = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + c += (model(x).argmax(1) == y).sum().item(); t += x.size(0) + return c / t + + +def compute_diagnostics(model, aux_net, Bs, test_loader, device, + credit_mode, alpha=0.75, const_net_L=None): + """Compute mean Gamma (BP cosine) and mean rho (perturbation correlation).""" + model.eval() + if aux_net is not None: + aux_net.eval() + L = model.num_blocks + + # Get one batch + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device); break + batch = x.size(0) + + # BP pass to get hidden gradients + was_frozen = not next(model.parameters()).requires_grad + if was_frozen: + for p in model.parameters(): p.requires_grad_(True) + model.zero_grad() + lo, hbp = model(x, return_hidden=True) + for l in range(L + 1): hbp[l].retain_grad() + F.cross_entropy(lo, y).backward() + bp = {l: hbp[l].grad.detach().clone() for l in range(L + 1)} + if was_frozen: + for p in model.parameters(): p.requires_grad_(False) + + with torch.no_grad(): + lo2, hi = model(x, return_hidden=True) + eT = lo2.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach() + + gammas, rhos = [], [] + for l in range(L): + h_l = hi[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + if credit_mode == 'dfa': + a_l = (s @ Bs[l].T).detach() + elif credit_mode == 'blend' and aux_net is not None: + a_dfa = (s @ Bs[l].T).detach() + if isinstance(aux_net, ConstantNet): + aux_net.set_block(l) + a_aux = aux_net(h_l, t_l, s).detach() + rd = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + rv = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_l = alpha * a_aux / rv + (1 - alpha) * a_dfa / rd + else: + a_l = (s @ Bs[l].T).detach() + + gammas.append(cosine_similarity_batch(a_l, bp[l])) + + def make_fwd(sl): + def f(h): + with torch.no_grad(): + c = h + for i in range(sl, L): + c = c + model.blocks[i](c) + return F.cross_entropy( + model.out_head(model.out_ln(c)), y, reduction='none') + return f + + rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), + epsilon=1e-3, M=16)) + + return float(np.mean(gammas)), float(np.mean(rhos)) + + +# --------------------------------------------------------------------------- +# DFA training + checkpoint +# --------------------------------------------------------------------------- + +def train_dfa_get_checkpoint(model, train_loader, test_loader, device, + total_epochs, t0, lr, wd): + d = model.d_hidden; L = model.num_blocks + Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) + for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs) + for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)]) + ckpt = None + for epoch in range(1, total_epochs + 1): + model.train(); tl, c, t = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device); b = x.size(0) + with torch.no_grad(): + lo, hi = model(x, return_hidden=True); lv = F.cross_entropy(lo, y) + eT = lo.softmax(-1); eT[torch.arange(b), y] -= 1 + hL = hi[-1].detach() + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y) + head_opt.zero_grad(); lo2.backward(); head_opt.step() + for l in range(L): + a = (eT @ Bs[l].T).detach() + rm = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + f = model.blocks[l](hi[l].detach()) + ll = (f * (a / rm)).sum(-1).mean() + block_opts[l].zero_grad(); ll.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + a0 = (eT @ Bs[0].T).detach() + r0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + el = (model.embed(x) * (a0 / r0)).sum(-1).mean() + embed_opt.zero_grad(); el.backward(); embed_opt.step() + tl += lv.item() * b; c += (lo.argmax(1) == y).sum().item(); t += b + for s in scheds: s.step() + if epoch == t0: + acc = evaluate(model, test_loader, device) + ckpt = {'model': copy.deepcopy(model.state_dict()), + 'Bs': [B.clone() for B in Bs], 'acc': acc} + print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}") + if epoch % 10 == 0: + print(f" [DFA] Epoch {epoch}: acc={evaluate(model, test_loader, device):.4f}") + return Bs, ckpt + + +# --------------------------------------------------------------------------- +# Offline Vec prefit +# --------------------------------------------------------------------------- + +def offline_fit_vec(vec_net, model, Bs, train_loader, device, epochs, lr_fb, M, eps_pert=1e-3): + """Pre-train VectorCreditNet using perturbation targets (no forward net training).""" + L = model.num_blocks + vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb) + model.eval() + print(f" [prefit] Starting Vec prefit for {epochs} epochs...") + for epoch in range(1, epochs + 1): + vec_net.train(); total_loss = 0.0; nb = 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device); batch = x.size(0) + with torch.no_grad(): + lo, hi = model(x, return_hidden=True) + eT = lo.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach() + hL = hi[-1].detach() + + # Terminal anchor + t_L = torch.ones(batch, device=device) + a_term = vec_net(hL, t_L, s) + hL_req = hL.clone().requires_grad_(True) + ce = F.cross_entropy(model.out_head(model.out_ln(hL_req)), y, + reduction='sum') + dL = torch.autograd.grad(ce, hL_req)[0].detach() + loss_term = ((a_term - dL) ** 2).sum(-1).mean() + + # Perturbation loss at a random layer + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + t_l = torch.full((batch,), lt / L, device=device) + a_l = vec_net(h_l, t_l, s) + lp2 = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy( + model.forward_from_layer(h_l + eps_pert * v, lt), y, reduction='none') + lm = F.cross_entropy( + model.forward_from_layer(h_l - eps_pert * v, lt), y, reduction='none') + gj = (lp - lm) / (2 * eps_pert) + lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean() + lp2 /= M + vl = loss_term + lp2 + vec_opt.zero_grad(); vl.backward() + torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0) + vec_opt.step() + total_loss += vl.item(); nb += 1 + if epoch % 10 == 0 or epoch == epochs: + print(f" [prefit] Epoch {epoch}/{epochs}: loss={total_loss/nb:.4f}") + vec_net.eval() + return vec_net + + +# --------------------------------------------------------------------------- +# Estimate shuffled-target RMS (for fresh_random sigma matching) +# --------------------------------------------------------------------------- + +def estimate_shuffled_rms(model, Bs, train_loader, device, M=4, eps_pert=1e-3, n_batches=5): + """Estimate RMS of shuffled perturbation targets to match for fresh_random branch.""" + model.eval(); L = model.num_blocks + rms_vals = [] + for i, (x, y) in enumerate(train_loader): + if i >= n_batches: break + x = x.view(x.size(0), -1).to(device); y = y.to(device); batch = x.size(0) + with torch.no_grad(): + lo, hi = model(x, return_hidden=True) + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy( + model.forward_from_layer(h_l + eps_pert * v, lt), y, reduction='none') + lm = F.cross_entropy( + model.forward_from_layer(h_l - eps_pert * v, lt), y, reduction='none') + gj = (lp - lm) / (2 * eps_pert) + # Shuffled + gj_sh = gj[torch.randperm(batch, device=device)] + rms_vals.append(gj_sh.pow(2).mean().sqrt().item()) + return float(np.mean(rms_vals)) if rms_vals else 0.01 + + +# --------------------------------------------------------------------------- +# Branch runner +# --------------------------------------------------------------------------- + +def run_branch(model, aux_net, Bs, train_loader, test_loader, device, + t0, total_epochs, branch_type, alpha, lr, lr_fb, wd, M, + branch_name='', fresh_sigma=0.01): + """ + Run a training branch from a loaded checkpoint. + + branch_type options: + 'dfa' — pure DFA + 'blend_trainable' — blend with Vec trained online (perturbation targets) + 'blend_shuffled' — blend with Vec trained on shuffled targets + 'blend_zero_target' — blend with Vec trained toward zero (||a_aux||^2) + 'blend_fresh_random' — blend with Vec trained on fresh random targets each step + 'blend_time_only' — blend with TimeOnlyNet trained online + 'blend_constant' — blend with ConstantNet trained online + 'blend_frozen' — blend with pre-fit Vec, frozen + 'blend_prefit_trainable'— blend with pre-fit Vec, continue training + """ + d = model.d_hidden; L = model.num_blocks; eps_pert = 1e-3 + + # Determine which aux nets need training + trainable_types = {'blend_trainable', 'blend_shuffled', 'blend_zero_target', + 'blend_fresh_random', 'blend_time_only', 'blend_constant', + 'blend_prefit_trainable'} + aux_trained = (branch_type in trainable_types) and (aux_net is not None) + + block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) + for b in model.blocks] + embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd) + head_opt = optim.AdamW( + list(model.out_head.parameters()) + list(model.out_ln.parameters()), + lr=lr, weight_decay=wd) + aux_opt = optim.Adam(aux_net.parameters(), lr=lr_fb) if aux_trained else None + + scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs) + for o in block_opts] + + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs), + optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)]) + # Advance schedulers to match checkpoint epoch + for _ in range(t0): + for s in scheds: s.step() + + log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': [], 'alpha_eff': []} + diag_epochs = set( + list(range(t0 + 1, min(t0 + 6, total_epochs + 1))) + + [t0 + 8, t0 + 10, t0 + 15, t0 + 20] + + list(range(t0 + 10, total_epochs + 1, 10)) + + [total_epochs]) + + for epoch in range(t0 + 1, total_epochs + 1): + model.train() + if aux_net is not None and aux_trained: + aux_net.train() + elif aux_net is not None: + aux_net.eval() + + tl, c, t = 0, 0, 0 + epoch_aux_norms, epoch_dfa_norms = [], [] + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device); batch = x.size(0) + with torch.no_grad(): + lo, hi = model(x, return_hidden=True); lv = F.cross_entropy(lo, y) + eT = lo.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach() + hL = hi[-1].detach() + + # ---------------------------------------------------------------- + # Train auxiliary network (if applicable) + # ---------------------------------------------------------------- + if aux_opt is not None: + if branch_type in ('blend_trainable', 'blend_prefit_trainable', + 'blend_time_only'): + # Standard perturbation targets + t_L = torch.ones(batch, device=device) + a_term = aux_net(hL, t_L, s) + hL_req = hL.clone().requires_grad_(True) + ce = F.cross_entropy( + model.out_head(model.out_ln(hL_req)), y, reduction='sum') + dL = torch.autograd.grad(ce, hL_req)[0].detach() + loss_term = ((a_term - dL) ** 2).sum(-1).mean() + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + t_l = torch.full((batch,), lt / L, device=device) + a_l = aux_net(h_l, t_l, s) + lp2 = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy( + model.forward_from_layer(h_l + eps_pert * v, lt), + y, reduction='none') + lm = F.cross_entropy( + model.forward_from_layer(h_l - eps_pert * v, lt), + y, reduction='none') + gj = (lp - lm) / (2 * eps_pert) + lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean() + lp2 /= M + vl = loss_term + lp2 + + elif branch_type == 'blend_shuffled': + # Perturbation targets shuffled within batch + t_L = torch.ones(batch, device=device) + a_term = aux_net(hL, t_L, s) + hL_req = hL.clone().requires_grad_(True) + ce = F.cross_entropy( + model.out_head(model.out_ln(hL_req)), y, reduction='sum') + dL = torch.autograd.grad(ce, hL_req)[0].detach() + dL_sh = dL[torch.randperm(batch, device=device)] + loss_term = ((a_term - dL_sh) ** 2).sum(-1).mean() + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + t_l = torch.full((batch,), lt / L, device=device) + a_l = aux_net(h_l, t_l, s) + lp2 = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy( + model.forward_from_layer(h_l + eps_pert * v, lt), + y, reduction='none') + lm = F.cross_entropy( + model.forward_from_layer(h_l - eps_pert * v, lt), + y, reduction='none') + gj = (lp - lm) / (2 * eps_pert) + # Shuffle targets within batch to break semantic correlation + gj_sh = gj[torch.randperm(batch, device=device)] + lp2 = lp2 + (((a_l * v).sum(-1) - gj_sh.detach()) ** 2).mean() + lp2 /= M + vl = loss_term + lp2 + + elif branch_type == 'blend_zero_target': + # Minimize ||a_aux||^2 — teaches the network to output zero + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + t_l = torch.full((batch,), lt / L, device=device) + if isinstance(aux_net, ConstantNet): + aux_net.set_block(lt) + a_l = aux_net(h_l, t_l, s) + vl = (a_l ** 2).sum(-1).mean() + + elif branch_type == 'blend_fresh_random': + # Fresh i.i.d. random targets sampled every step + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + t_l = torch.full((batch,), lt / L, device=device) + a_l = aux_net(h_l, t_l, s) + # Sample random target with RMS matched to shuffled scale + rand_target = torch.randn(batch, device=device) * fresh_sigma + lp2 = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(-1, keepdim=True) + 1e-8) + # Fresh random target per direction per step + gj_rand = torch.randn(batch, device=device) * fresh_sigma + lp2 = lp2 + (((a_l * v).sum(-1) - gj_rand) ** 2).mean() + lp2 /= M + vl = lp2 + + elif branch_type == 'blend_constant': + # ConstantNet: train per-block templates + # Train all blocks this step (cheap since it's just parameters) + vl = torch.tensor(0.0, device=device) + lt = np.random.randint(0, L) + h_l = hi[lt].detach() + t_l = torch.full((batch,), lt / L, device=device) + aux_net.set_block(lt) + a_l = aux_net(h_l, t_l, s) + lp2_inner = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy( + model.forward_from_layer(h_l + eps_pert * v, lt), + y, reduction='none') + lm = F.cross_entropy( + model.forward_from_layer(h_l - eps_pert * v, lt), + y, reduction='none') + gj = (lp - lm) / (2 * eps_pert) + lp2_inner = lp2_inner + ( + ((a_l * v).sum(-1) - gj.detach()) ** 2).mean() + vl = lp2_inner / M + else: + vl = None + + if vl is not None: + aux_opt.zero_grad(); vl.backward() + torch.nn.utils.clip_grad_norm_(aux_net.parameters(), 1.0) + aux_opt.step() + + # ---------------------------------------------------------------- + # Compute credits for each block + # ---------------------------------------------------------------- + dfa_credits = [(eT @ Bs[l].T).detach() for l in range(L)] + credits = [] + for l in range(L): + a_dfa = dfa_credits[l] + rms_d = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + + if branch_type == 'dfa': + credits.append(a_dfa / rms_d) + else: + # All blend branches + h_l = hi[l].detach() + t_l = torch.full((batch,), l / L, device=device) + with torch.no_grad(): + if isinstance(aux_net, ConstantNet): + aux_net.set_block(l) + a_aux = aux_net(h_l, t_l, s).detach() + rms_v = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a_blend = alpha * a_aux / rms_v + (1 - alpha) * a_dfa / rms_d + credits.append(a_blend) + + # Track norms for alpha_eff + a_c = credits[-1] + if branch_type == 'dfa': + epoch_aux_norms.append(0.0) + epoch_dfa_norms.append(a_c.norm().item()) + else: + a_dfa_n = a_dfa / rms_d + rms_v2 = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + epoch_aux_norms.append((alpha * a_aux / rms_v2).norm().item()) + epoch_dfa_norms.append(((1 - alpha) * a_dfa_n).norm().item()) + + # ---------------------------------------------------------------- + # Update output head (local exact gradient — allowed) + # ---------------------------------------------------------------- + lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y) + head_opt.zero_grad(); lo2.backward(); head_opt.step() + + # ---------------------------------------------------------------- + # Update blocks with local surrogate + # a_total is already blended; normalize before inner product + # ---------------------------------------------------------------- + for l in range(L): + a = credits[l] + rm = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + f = model.blocks[l](hi[l].detach()) + ll = (f * (a / rm)).sum(-1).mean() + block_opts[l].zero_grad(); ll.backward() + torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0) + block_opts[l].step() + + # Update embedding with block-0 credit + a0 = credits[0] + r0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + el = (model.embed(x) * (a0 / r0)).sum(-1).mean() + embed_opt.zero_grad(); el.backward(); embed_opt.step() + + tl += lv.item() * batch; c += (lo.argmax(1) == y).sum().item(); t += batch + + for sch in scheds: sch.step() + ta = evaluate(model, test_loader, device) + log['test_acc'].append(ta); log['train_loss'].append(tl / t) + + mean_aux = np.mean(epoch_aux_norms) if epoch_aux_norms else 0.0 + mean_dfa = np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1.0 + aeff = mean_aux / (mean_aux + mean_dfa + 1e-12) + log['alpha_eff'].append((epoch, aeff)) + + if epoch in diag_epochs: + cm = 'blend' if branch_type != 'dfa' else 'dfa' + gamma, rho = compute_diagnostics( + model, aux_net if branch_type != 'dfa' else None, + Bs, test_loader, device, cm, alpha) + log['gamma'].append((epoch, gamma)); log['rho'].append((epoch, rho)) + if epoch <= t0 + 15 or epoch % 20 == 0 or epoch == total_epochs: + print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}, " + f"G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}") + elif epoch % 10 == 0 or epoch == total_epochs: + print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}") + + return log + + +# --------------------------------------------------------------------------- +# Main experiment +# --------------------------------------------------------------------------- + +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + torch.manual_seed(args.seed); np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(args.batch_size) + input_dim = 32 * 32 * 3; L = args.num_blocks; d = args.d_hidden + + # ---------------------------------------------------------------- + # Step 1: Train DFA and capture checkpoint at t0 + # ---------------------------------------------------------------- + print(f"\n{'='*60}\nTraining DFA baseline (checkpoint at t0={args.t0})\n{'='*60}") + model_dfa = ResidualMLP(input_dim, d, 10, L).to(device) + Bs, ckpt = train_dfa_get_checkpoint( + model_dfa, train_loader, test_loader, device, + args.epochs, args.t0, args.lr, args.wd) + print(f" Checkpoint acc at t0={args.t0}: {ckpt['acc']:.4f}") + + # ---------------------------------------------------------------- + # Step 2: Offline prefit of Vec (for prefit60 branches) + # ---------------------------------------------------------------- + print(f"\n{'='*60}\nOffline prefit of Vec (60 epochs)\n{'='*60}") + torch.manual_seed(args.seed + 7777) + vec_prefit = VectorCreditNet(d_hidden=d, s_dim=10).to(device) + # Load checkpoint model for prefit (weights at t0, forward net frozen) + model_prefit = ResidualMLP(input_dim, d, 10, L).to(device) + model_prefit.load_state_dict(ckpt['model']) + model_prefit.eval() + for p in model_prefit.parameters(): p.requires_grad_(False) + vec_prefit = offline_fit_vec( + vec_prefit, model_prefit, ckpt['Bs'], + train_loader, device, epochs=60, lr_fb=args.lr_fb, M=args.M) + del model_prefit + + # ---------------------------------------------------------------- + # Step 3: Estimate sigma for fresh_random (match shuffled-target RMS) + # ---------------------------------------------------------------- + model_ref = ResidualMLP(input_dim, d, 10, L).to(device) + model_ref.load_state_dict(ckpt['model']); model_ref.eval() + for p in model_ref.parameters(): p.requires_grad_(False) + fresh_sigma = estimate_shuffled_rms(model_ref, ckpt['Bs'], train_loader, device) + print(f" fresh_random sigma (matched to shuffled RMS): {fresh_sigma:.5f}") + del model_ref + + # ---------------------------------------------------------------- + # Step 4: Define and run all 9 branches + # ---------------------------------------------------------------- + # Each entry: (branch_name, branch_type, aux_factory_fn) + # aux_factory_fn returns a fresh aux_net (or None for dfa) + VEC_SEED = args.seed + 7777 + + def make_vec(): + torch.manual_seed(VEC_SEED) + return VectorCreditNet(d_hidden=d, s_dim=10).to(device) + + def make_time_only(): + torch.manual_seed(VEC_SEED) + return TimeOnlyNet(d_hidden=d).to(device) + + def make_constant(): + torch.manual_seed(VEC_SEED) + return ConstantNet(d_hidden=d, num_blocks=L).to(device) + + def make_prefit_frozen(): + net = copy.deepcopy(vec_prefit) + for p in net.parameters(): p.requires_grad_(False) + net.eval() + return net + + def make_prefit_trainable(): + return copy.deepcopy(vec_prefit) + + branches = [ + # (name, branch_type, aux_factory) + ('continue_DFA', 'dfa', lambda: None), + ('blend_random_trainable', 'blend_trainable', make_vec), + ('blend_shuffled_trainable', 'blend_shuffled', make_vec), + ('blend_zero_target_trainable', 'blend_zero_target', make_vec), + ('blend_fresh_random_target', 'blend_fresh_random', make_vec), + ('blend_time_only_trainable', 'blend_time_only', make_time_only), + ('blend_constant_input', 'blend_constant', make_constant), + ('blend_prefit60_frozen', 'blend_frozen', make_prefit_frozen), + ('blend_prefit60_trainable', 'blend_prefit_trainable', make_prefit_trainable), + ] + + all_results = {} + for bname, btype, aux_factory in branches: + print(f"\n{'='*60}\n{bname}\n{'='*60}") + # Fresh model from checkpoint + model_b = ResidualMLP(input_dim, d, 10, L).to(device) + model_b.load_state_dict(ckpt['model']) + aux_net_b = aux_factory() + + log = run_branch( + model_b, aux_net_b, ckpt['Bs'], + train_loader, test_loader, device, + args.t0, args.epochs, btype, + args.alpha, args.lr, args.lr_fb, args.wd, args.M, + branch_name=bname, fresh_sigma=fresh_sigma) + all_results[bname] = log + print(f" {bname} final acc: {log['test_acc'][-1]:.4f}") + + # ---------------------------------------------------------------- + # Step 5: Summary table + # ---------------------------------------------------------------- + dfa_final = all_results['continue_DFA']['test_acc'][-1] + + print(f"\n{'='*95}") + print("SUMMARY — Phase 10A.6: Structured vs Semantic Auxiliary") + print(f"{'='*95}") + print(f"{'Branch':<34} {'@20':>6} {'final':>7} {'diff':>7} " + f"{'mG_5:15':>9} {'mr_5:15':>9} {'aeff':>7}") + print("-" * 83) + + for bname, log in all_results.items(): + accs = log['test_acc'] + idx20 = max(0, 20 - args.t0 - 1) + acc20 = accs[idx20] if len(accs) > idx20 else accs[-1] + final = accs[-1] + diff = final - dfa_final + gammas_e = [g for e, g in log['gamma'] if args.t0 < e <= args.t0 + 15] + rhos_e = [r for e, r in log['rho'] if args.t0 < e <= args.t0 + 15] + aeffs_e = [a for e, a in log['alpha_eff'] if args.t0 < e <= args.t0 + 15] + mg = float(np.mean(gammas_e)) if gammas_e else float('nan') + mr = float(np.mean(rhos_e)) if rhos_e else float('nan') + mae = float(np.mean(aeffs_e)) if aeffs_e else float('nan') + print(f"{bname:<34} {acc20:>6.4f} {final:>7.4f} {diff:>+7.4f} " + f"{mg:>9.4f} {mr:>9.4f} {mae:>7.3f}") + + # ---------------------------------------------------------------- + # Step 6: Save results + # ---------------------------------------------------------------- + save_data = { + 'args': vars(args), + 'dfa_ckpt_acc': float(ckpt['acc']), + 'fresh_sigma': float(fresh_sigma), + } + for bname, log in all_results.items(): + save_data[bname] = { + 'test_acc': log['test_acc'], + 'train_loss': log['train_loss'], + 'gamma': log['gamma'], + 'rho': log['rho'], + 'alpha_eff': log['alpha_eff'], + } + out_path = os.path.join(args.output_dir, + f'structured_aux_t{args.t0}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2, default=float) + print(f"\nSaved to {out_path}") + + # ---------------------------------------------------------------- + # Step 7: Judgment + # ---------------------------------------------------------------- + print(f"\n{'='*60}\nJUDGMENT\n{'='*60}") + r = {bname: log['test_acc'][-1] for bname, log in all_results.items()} + dfa = r['continue_DFA'] + rt = r.get('blend_random_trainable', float('nan')) + sh = r.get('blend_shuffled_trainable', float('nan')) + zt = r.get('blend_zero_target_trainable', float('nan')) + fr = r.get('blend_fresh_random_target', float('nan')) + to = r.get('blend_time_only_trainable', float('nan')) + cn = r.get('blend_constant_input', float('nan')) + pf = r.get('blend_prefit60_frozen', float('nan')) + pt = r.get('blend_prefit60_trainable', float('nan')) + + print(f" DFA={dfa:.4f} rt={rt:.4f} sh={sh:.4f} zt={zt:.4f} " + f"fr={fr:.4f} to={to:.4f} cn={cn:.4f} pf={pf:.4f} pt={pt:.4f}") + + # Key comparisons + thr = 0.003 + if abs(rt - sh) < thr and abs(rt - fr) < thr: + print(" -> random_trainable ≈ shuffled ≈ fresh_random: " + "gain is NOT from semantic content, likely norm/regularization") + elif rt > sh + thr and rt > fr + thr: + print(" -> random_trainable > shuffled AND fresh_random: " + "some semantic structure in Vec helps") + + if abs(rt - zt) < thr: + print(" -> zero_target ≈ random_trainable: " + "the implicit regularization is from the blend ratio, not signal direction") + elif zt < dfa - thr: + print(" -> zero_target HURTS vs DFA: driving aux toward zero is not just neutral") + + if abs(rt - to) < thr: + print(" -> time_only ≈ random_trainable: " + "depth schedule alone captures most of the benefit") + if abs(rt - cn) < thr: + print(" -> constant_input ≈ random_trainable: " + "fixed per-block template sufficient (no input processing needed)") + + if pt > pf + thr: + print(" -> prefit60_trainable > prefit60_frozen: " + "continued training after prefit improves further") + elif abs(pt - pf) < thr: + print(" -> prefit60_trainable ≈ prefit60_frozen: " + "prefit quality saturates, additional online training neutral") + + if pt > rt + thr: + print(" -> prefit60_trainable > random_trainable: " + "warm-start from prefit helps over training from scratch") + + +def main(): + parser = argparse.ArgumentParser( + description='Phase 10A.6: Structured vs Semantic Auxiliary Dissection') + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--epochs', type=int, default=100) + parser.add_argument('--t0', type=int, default=5) + parser.add_argument('--alpha', type=float, default=0.75) + parser.add_argument('--lr', type=float, default=1e-3) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--M', type=int, default=4) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=2) + parser.add_argument('--output_dir', type=str, default='results/structured_aux') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() |
