summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-27 16:39:17 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-27 16:39:17 -0500
commit4d6e689fe6bfffef6db7a4650aec210cd3eeed5c (patch)
treefa8b6d123a51bab4b17a07f787cc89e74584397f /experiments
parent65d97ad1ef4b552103420e6501655df192c98d57 (diff)
Add Phase 10A.8: freeze-with-decay confirms stale aux is main freeze failure cause;
alpha sweep shows perlayer_vector at alpha=0.75 matches full network 10A.8A: freeze_decay_to_000 recovers to 28.5% (vs 14.6% fixed freeze) — stale high-weight aux is the primary cause of freeze crashes. But 28.5% < DFA 31.2% confirms continuous trainability adds ~2.7% independent value. 10A.8B: Both perlayer_vector and random_trainable optimal at alpha=0.75. perlayer_vector +1.1% vs random_trainable +0.8% — per-layer vector is the minimal sufficient scaffold, no network needed. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments')
-rw-r--r--experiments/alpha_sweep_scaffold.py625
-rw-r--r--experiments/freeze_with_decay.py653
2 files changed, 1278 insertions, 0 deletions
diff --git a/experiments/alpha_sweep_scaffold.py b/experiments/alpha_sweep_scaffold.py
new file mode 100644
index 0000000..07e6b6e
--- /dev/null
+++ b/experiments/alpha_sweep_scaffold.py
@@ -0,0 +1,625 @@
+"""
+Phase 10A.8B: Alpha Sweep Scaffold.
+
+Core question: What is the optimal blend weight alpha for each auxiliary type?
+
+9 branches from the same DFA checkpoint at t0=5:
+1. continue_DFA — pure DFA baseline
+2. blend_perlayer_vector_alpha025 — PerLayerVector, alpha=0.25
+3. blend_perlayer_vector_alpha050 — PerLayerVector, alpha=0.50
+4. blend_perlayer_vector_alpha075 — PerLayerVector, alpha=0.75
+5. blend_perlayer_vector_alpha090 — PerLayerVector, alpha=0.90
+6. blend_random_trainable_alpha025 — VectorCreditNet, alpha=0.25
+7. blend_random_trainable_alpha050 — VectorCreditNet, alpha=0.50
+8. blend_random_trainable_alpha075 — VectorCreditNet, alpha=0.75
+9. blend_random_trainable_alpha090 — VectorCreditNet, alpha=0.90
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+
+
+# ---------------------------------------------------------------------------
+# Auxiliary network architectures
+# ---------------------------------------------------------------------------
+
+class VectorCreditNet(nn.Module):
+ """Standard Vec: takes (h, t, s) -> d_hidden credit vector."""
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))
+
+
+class PerLayerVector(nn.Module):
+ """No network: each block l has a trainable nn.Parameter v_l of shape (d_hidden,).
+ All samples in a batch receive the same v_l (broadcast).
+ forward(h, t, s) ignores h, t, s and returns v_l expanded to (batch, d_hidden).
+ Must call set_block(l) before forward to select the right block vector.
+ """
+ def __init__(self, d_hidden, num_blocks):
+ super().__init__()
+ # Initialize with small random values (std=0.01)
+ self.vectors = nn.ParameterList(
+ [nn.Parameter(torch.randn(d_hidden) * 0.01) for _ in range(num_blocks)]
+ )
+ self._block_idx = 0
+
+ def set_block(self, l):
+ self._block_idx = l
+
+ def forward(self, h, t, s):
+ batch = h.size(0)
+ return self.vectors[self._block_idx].unsqueeze(0).expand(batch, -1)
+
+
+# ---------------------------------------------------------------------------
+# Data
+# ---------------------------------------------------------------------------
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
+ trainset = torchvision.datasets.CIFAR10(
+ root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(
+ root='./data', train=False, download=True, transform=transform_test)
+ return (DataLoader(trainset, batch_size=batch_size, shuffle=True,
+ num_workers=4, pin_memory=True),
+ DataLoader(testset, batch_size=batch_size, shuffle=False,
+ num_workers=4, pin_memory=True))
+
+
+# ---------------------------------------------------------------------------
+# Evaluation helpers
+# ---------------------------------------------------------------------------
+
+def evaluate(model, test_loader, device):
+ model.eval(); c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+def compute_diagnostics(model, aux_net, Bs, test_loader, device, credit_mode, alpha=0.75):
+ """Compute mean Gamma (BP cosine) and mean rho (perturbation correlation)."""
+ model.eval()
+ if aux_net is not None:
+ aux_net.eval()
+ L = model.num_blocks
+
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+
+ # BP pass for hidden gradients (offline eval only, not used for training)
+ was_frozen = not next(model.parameters()).requires_grad
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ lo, hbp = model(x, return_hidden=True)
+ for l in range(L + 1): hbp[l].retain_grad()
+ F.cross_entropy(lo, y).backward()
+ bp = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(False)
+
+ with torch.no_grad():
+ lo2, hi = model(x, return_hidden=True)
+ eT = lo2.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach()
+
+ gammas, rhos = [], []
+ for l in range(L):
+ h_l = hi[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if credit_mode == 'dfa':
+ a_l = (s @ Bs[l].T).detach()
+ elif credit_mode == 'blend' and aux_net is not None:
+ a_dfa = (s @ Bs[l].T).detach()
+ if isinstance(aux_net, PerLayerVector):
+ aux_net.set_block(l)
+ a_aux = aux_net(h_l, t_l, s).detach()
+ rd = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rv = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l = alpha * a_aux / rv + (1 - alpha) * a_dfa / rd
+ else:
+ a_l = (s @ Bs[l].T).detach()
+
+ gammas.append(cosine_similarity_batch(a_l, bp[l]))
+
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L):
+ c = c + model.blocks[i](c)
+ return F.cross_entropy(
+ model.out_head(model.out_ln(c)), y, reduction='none')
+ return f
+
+ rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))
+
+ return float(np.mean(gammas)), float(np.mean(rhos))
+
+
+# ---------------------------------------------------------------------------
+# DFA training + checkpoint
+# ---------------------------------------------------------------------------
+
+def train_dfa_get_checkpoint(model, train_loader, test_loader, device,
+ total_epochs, t0, lr, wd):
+ d = model.d_hidden; L = model.num_blocks
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd)
+ for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs)
+ for o in block_opts] +
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)])
+ ckpt = None
+ for epoch in range(1, total_epochs + 1):
+ model.train(); tl, c, t = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad():
+ lo, hi = model(x, return_hidden=True); lv = F.cross_entropy(lo, y)
+ eT = lo.softmax(-1); eT[torch.arange(b), y] -= 1
+ hL = hi[-1].detach()
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); lo2.backward(); head_opt.step()
+ for l in range(L):
+ a = (eT @ Bs[l].T).detach()
+ rm = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hi[l].detach())
+ ll = (f * (a / rm)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a0 = (eT @ Bs[0].T).detach()
+ r0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0 / r0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+ tl += lv.item() * b; c += (lo.argmax(1) == y).sum().item(); t += b
+ for s in scheds: s.step()
+ if epoch == t0:
+ acc = evaluate(model, test_loader, device)
+ ckpt = {'model': copy.deepcopy(model.state_dict()),
+ 'Bs': [B.clone() for B in Bs], 'acc': acc}
+ print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
+ if epoch % 10 == 0:
+ print(f" [DFA] Epoch {epoch}: acc={evaluate(model, test_loader, device):.4f}")
+ return Bs, ckpt
+
+
+# ---------------------------------------------------------------------------
+# Branch runner
+# ---------------------------------------------------------------------------
+
+def run_branch(model, aux_net, Bs, train_loader, test_loader, device,
+ t0, total_epochs, branch_type, alpha, lr, lr_fb, wd, M,
+ branch_name=''):
+ """
+ Run a training branch from a loaded checkpoint.
+
+ branch_type options:
+ 'dfa' — pure DFA baseline
+ 'blend_perlayer' — blend with PerLayerVector trained online (perturbation targets)
+ 'blend_trainable' — blend with VectorCreditNet trained online (perturbation targets)
+
+ alpha is fixed for the entire run (no warmup/decay).
+ Both aux types are trained online continuously after handoff.
+ """
+ d = model.d_hidden; L = model.num_blocks; eps_pert = 1e-3
+
+ trainable_types = {'blend_perlayer', 'blend_trainable'}
+ aux_trained = (branch_type in trainable_types) and (aux_net is not None)
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd)
+ for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+
+ if aux_trained:
+ aux_opt = optim.Adam(aux_net.parameters(), lr=lr_fb)
+ else:
+ aux_opt = None
+
+ scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs)
+ for o in block_opts] +
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)])
+ # Advance schedulers to match checkpoint epoch
+ for _ in range(t0):
+ for s in scheds: s.step()
+
+ log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': [], 'alpha_eff': []}
+ diag_epochs = set(
+ list(range(t0 + 1, min(t0 + 6, total_epochs + 1))) +
+ [t0 + 8, t0 + 10, t0 + 15, t0 + 20] +
+ list(range(t0 + 10, total_epochs + 1, 10)) +
+ [total_epochs])
+
+ for epoch in range(t0 + 1, total_epochs + 1):
+ model.train()
+ if aux_net is not None:
+ aux_net.train() if aux_opt is not None else aux_net.eval()
+
+ tl, c, t = 0, 0, 0
+ epoch_aux_norms, epoch_dfa_norms = [], []
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); batch = x.size(0)
+ with torch.no_grad():
+ lo, hi = model(x, return_hidden=True); lv = F.cross_entropy(lo, y)
+ eT = lo.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach()
+ hL = hi[-1].detach()
+
+ # ----------------------------------------------------------------
+ # Train auxiliary network (if applicable)
+ # ----------------------------------------------------------------
+ if aux_opt is not None:
+ if branch_type == 'blend_trainable':
+ # Standard VectorCreditNet: terminal matching + perturbation targets
+ t_L = torch.ones(batch, device=device)
+ a_term = aux_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ ce = F.cross_entropy(
+ model.out_head(model.out_ln(hL_req)), y, reduction='sum')
+ dL = torch.autograd.grad(ce, hL_req)[0].detach()
+ loss_term = ((a_term - dL) ** 2).sum(-1).mean()
+ lt = np.random.randint(0, L)
+ h_l = hi[lt].detach()
+ t_l = torch.full((batch,), lt / L, device=device)
+ a_l = aux_net(h_l, t_l, s)
+ lp2 = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(
+ model.forward_from_layer(h_l + eps_pert * v, lt),
+ y, reduction='none')
+ lm = F.cross_entropy(
+ model.forward_from_layer(h_l - eps_pert * v, lt),
+ y, reduction='none')
+ gj = (lp - lm) / (2 * eps_pert)
+ lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean()
+ lp2 /= M
+ vl = loss_term + lp2
+
+ elif branch_type == 'blend_perlayer':
+ # PerLayerVector: perturbation-based loss only (no terminal matching).
+ # v_l is the per-layer parameter (shared across all samples in batch).
+ # Also add terminal matching: a_L should match delta_L.
+ # Terminal matching: set_block(L-1) and match grad at last layer.
+ lt = np.random.randint(0, L)
+ h_l = hi[lt].detach()
+ t_l = torch.full((batch,), lt / L, device=device)
+ aux_net.set_block(lt)
+ a_l = aux_net(h_l, t_l, s) # (batch, d) — same v_lt broadcast
+ lp2 = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(
+ model.forward_from_layer(h_l + eps_pert * v, lt),
+ y, reduction='none')
+ lm = F.cross_entropy(
+ model.forward_from_layer(h_l - eps_pert * v, lt),
+ y, reduction='none')
+ gj = (lp - lm) / (2 * eps_pert)
+ # <v_l, v_dir> — v_l is shared across batch, v is per-sample
+ lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean()
+ lp2 /= M
+
+ # Terminal matching: v_{L-1} should approximate delta_L
+ aux_net.set_block(L - 1)
+ a_term = aux_net(hL, torch.ones(batch, device=device), s)
+ hL_req = hL.clone().requires_grad_(True)
+ ce = F.cross_entropy(
+ model.out_head(model.out_ln(hL_req)), y, reduction='sum')
+ dL = torch.autograd.grad(ce, hL_req)[0].detach()
+ loss_term = ((a_term - dL) ** 2).sum(-1).mean()
+
+ vl = lp2 + loss_term
+
+ else:
+ vl = None
+
+ if vl is not None:
+ aux_opt.zero_grad(); vl.backward()
+ torch.nn.utils.clip_grad_norm_(aux_net.parameters(), 1.0)
+ aux_opt.step()
+
+ # ----------------------------------------------------------------
+ # Compute credits for each block
+ # ----------------------------------------------------------------
+ dfa_credits = [(eT @ Bs[l].T).detach() for l in range(L)]
+ credits = []
+ for l in range(L):
+ a_dfa = dfa_credits[l]
+ rms_d = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+
+ if branch_type == 'dfa':
+ credits.append(a_dfa / rms_d)
+ else:
+ # All blend branches
+ h_l = hi[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ with torch.no_grad():
+ if isinstance(aux_net, PerLayerVector):
+ aux_net.set_block(l)
+ a_aux = aux_net(h_l, t_l, s).detach()
+ rms_v = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_blend = alpha * a_aux / rms_v + (1 - alpha) * a_dfa / rms_d
+ credits.append(a_blend)
+
+ # Track norms for alpha_eff
+ a_c = credits[-1]
+ if branch_type == 'dfa':
+ epoch_aux_norms.append(0.0)
+ epoch_dfa_norms.append(a_c.norm().item())
+ else:
+ a_dfa_n = a_dfa / rms_d
+ rms_v2 = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ epoch_aux_norms.append((alpha * a_aux / rms_v2).norm().item())
+ epoch_dfa_norms.append(((1 - alpha) * a_dfa_n).norm().item())
+
+ # ----------------------------------------------------------------
+ # Update output head (local exact gradient — allowed)
+ # ----------------------------------------------------------------
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); lo2.backward(); head_opt.step()
+
+ # ----------------------------------------------------------------
+ # Update blocks with local surrogate
+ # ----------------------------------------------------------------
+ for l in range(L):
+ a = credits[l]
+ rm = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hi[l].detach())
+ ll = (f * (a / rm)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embedding with block-0 credit
+ a0 = credits[0]
+ r0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0 / r0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+
+ tl += lv.item() * batch; c += (lo.argmax(1) == y).sum().item(); t += batch
+
+ for sch in scheds: sch.step()
+ ta = evaluate(model, test_loader, device)
+ log['test_acc'].append(ta); log['train_loss'].append(tl / t)
+
+ mean_aux = np.mean(epoch_aux_norms) if epoch_aux_norms else 0.0
+ mean_dfa = np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1.0
+ aeff = mean_aux / (mean_aux + mean_dfa + 1e-12)
+ log['alpha_eff'].append((epoch, aeff))
+
+ if epoch in diag_epochs:
+ cm = 'blend' if branch_type != 'dfa' else 'dfa'
+ gamma, rho = compute_diagnostics(
+ model, aux_net if branch_type != 'dfa' else None,
+ Bs, test_loader, device, cm, alpha)
+ log['gamma'].append((epoch, gamma)); log['rho'].append((epoch, rho))
+ if epoch <= t0 + 15 or epoch % 20 == 0 or epoch == total_epochs:
+ print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}, "
+ f"G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}, alpha={alpha:.2f}")
+ elif epoch % 10 == 0 or epoch == total_epochs:
+ print(f" [{branch_name}] Ep {epoch}: acc={ta:.4f}, alpha={alpha:.2f}")
+
+ return log
+
+
+# ---------------------------------------------------------------------------
+# Main experiment
+# ---------------------------------------------------------------------------
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3; L = args.num_blocks; d = args.d_hidden
+
+ # ----------------------------------------------------------------
+ # Step 1: Train DFA and capture checkpoint at t0
+ # ----------------------------------------------------------------
+ print(f"\n{'='*60}\nTraining DFA baseline (checkpoint at t0={args.t0})\n{'='*60}")
+ model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
+ Bs, ckpt = train_dfa_get_checkpoint(
+ model_dfa, train_loader, test_loader, device,
+ args.epochs, args.t0, args.lr, args.wd)
+ print(f" Checkpoint acc at t0={args.t0}: {ckpt['acc']:.4f}")
+
+ # ----------------------------------------------------------------
+ # Step 2: Define and run all 9 branches
+ # ----------------------------------------------------------------
+ VEC_SEED = args.seed + 7777
+
+ def make_vec():
+ torch.manual_seed(VEC_SEED)
+ return VectorCreditNet(d_hidden=d, s_dim=10).to(device)
+
+ def make_perlayer():
+ torch.manual_seed(VEC_SEED)
+ return PerLayerVector(d_hidden=d, num_blocks=L).to(device)
+
+ # (name, branch_type, aux_factory, alpha)
+ ALPHAS = [0.25, 0.50, 0.75, 0.90]
+ branches = [('continue_DFA', 'dfa', lambda: None, 0.0)]
+ for a in ALPHAS:
+ tag = f"{int(a*100):03d}"
+ branches.append((f'blend_perlayer_vector_alpha{tag}', 'blend_perlayer', make_perlayer, a))
+ for a in ALPHAS:
+ tag = f"{int(a*100):03d}"
+ branches.append((f'blend_random_trainable_alpha{tag}', 'blend_trainable', make_vec, a))
+
+ all_results = {}
+ for bname, btype, aux_factory, alpha in branches:
+ print(f"\n{'='*60}\n{bname}\n{'='*60}")
+ model_b = ResidualMLP(input_dim, d, 10, L).to(device)
+ model_b.load_state_dict(ckpt['model'])
+ aux_net_b = aux_factory()
+
+ log = run_branch(
+ model_b, aux_net_b, ckpt['Bs'],
+ train_loader, test_loader, device,
+ args.t0, args.epochs, btype,
+ alpha, args.lr, args.lr_fb, args.wd, args.M,
+ branch_name=bname)
+ all_results[bname] = log
+ all_results[bname]['alpha'] = alpha
+ print(f" {bname} final acc: {log['test_acc'][-1]:.4f}")
+
+ # ----------------------------------------------------------------
+ # Step 3: Summary table
+ # ----------------------------------------------------------------
+ dfa_final = all_results['continue_DFA']['test_acc'][-1]
+
+ print(f"\n{'='*95}")
+ print("SUMMARY — Phase 10A.8B: Alpha Sweep")
+ print(f"{'='*95}")
+ print(f"{'Branch':<40} {'alpha':>5} {'@20':>6} {'final':>7} {'diff vs DFA':>11}")
+ print("-" * 73)
+
+ for bname, log in all_results.items():
+ accs = log['test_acc']
+ alpha = log['alpha']
+ idx20 = max(0, 20 - args.t0 - 1)
+ acc20 = accs[idx20] if len(accs) > idx20 else accs[-1]
+ final = accs[-1]
+ diff = final - dfa_final
+ print(f"{bname:<40} {alpha:>5.2f} {acc20:>6.4f} {final:>7.4f} {diff:>+11.4f}")
+
+ # ----------------------------------------------------------------
+ # Step 4: Optimal alpha per method type
+ # ----------------------------------------------------------------
+ print(f"\n{'='*60}")
+ print("OPTIMAL ALPHA PER METHOD TYPE")
+ print(f"{'='*60}")
+
+ # PerLayerVector branches
+ perlayer_results = {
+ bname: log for bname, log in all_results.items()
+ if bname.startswith('blend_perlayer_vector_alpha')}
+ if perlayer_results:
+ best_plv = max(perlayer_results.items(), key=lambda kv: kv[1]['test_acc'][-1])
+ print(f" PerLayerVector best alpha: {best_plv[1]['alpha']:.2f} "
+ f"(branch={best_plv[0]}, final={best_plv[1]['test_acc'][-1]:.4f})")
+ for bname in sorted(perlayer_results.keys()):
+ log = perlayer_results[bname]
+ diff = log['test_acc'][-1] - dfa_final
+ print(f" alpha={log['alpha']:.2f}: final={log['test_acc'][-1]:.4f} "
+ f"({diff:+.4f} vs DFA)")
+
+ print()
+
+ # VectorCreditNet (random_trainable) branches
+ trainable_results = {
+ bname: log for bname, log in all_results.items()
+ if bname.startswith('blend_random_trainable_alpha')}
+ if trainable_results:
+ best_rt = max(trainable_results.items(), key=lambda kv: kv[1]['test_acc'][-1])
+ print(f" VectorCreditNet best alpha: {best_rt[1]['alpha']:.2f} "
+ f"(branch={best_rt[0]}, final={best_rt[1]['test_acc'][-1]:.4f})")
+ for bname in sorted(trainable_results.keys()):
+ log = trainable_results[bname]
+ diff = log['test_acc'][-1] - dfa_final
+ print(f" alpha={log['alpha']:.2f}: final={log['test_acc'][-1]:.4f} "
+ f"({diff:+.4f} vs DFA)")
+
+ # ----------------------------------------------------------------
+ # Step 5: Save results
+ # ----------------------------------------------------------------
+ save_data = {
+ 'args': vars(args),
+ 'dfa_ckpt_acc': float(ckpt['acc']),
+ 'dfa_final_acc': float(dfa_final),
+ }
+ for bname, log in all_results.items():
+ save_data[bname] = {
+ 'alpha': log['alpha'],
+ 'test_acc': log['test_acc'],
+ 'train_loss': log['train_loss'],
+ 'gamma': log['gamma'],
+ 'rho': log['rho'],
+ 'alpha_eff': log['alpha_eff'],
+ }
+ out_path = os.path.join(args.output_dir,
+ f'alpha_sweep_t{args.t0}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Phase 10A.8B: Alpha Sweep Scaffold')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--t0', type=int, default=5)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/alpha_sweep_scaffold')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/freeze_with_decay.py b/experiments/freeze_with_decay.py
new file mode 100644
index 0000000..b8d7ab2
--- /dev/null
+++ b/experiments/freeze_with_decay.py
@@ -0,0 +1,653 @@
+"""
+Phase 10A.8A: Freeze with Alpha Decay.
+
+Core question: After freezing Vec, can linearly decaying alpha (fading out the
+frozen Vec and returning to pure DFA) recover or improve over a fixed-alpha frozen blend?
+
+8 branches from the same DFA checkpoint at t0=5:
+1. continue_DFA — pure DFA baseline
+2. blend_random_trainable_alpha075 — standard reference (always trainable, alpha=0.75)
+3. freeze_after_1_fixed075 — train Vec 1 epoch, freeze, keep alpha=0.75
+4. freeze_after_5_fixed075 — train Vec 5 epochs, freeze, keep alpha=0.75
+5. freeze_after_1_decay_to_025 — train Vec 1 epoch, freeze, then decay alpha 0.75->0.25 over 5 epochs
+6. freeze_after_5_decay_to_025 — train Vec 5 epochs, freeze, then decay alpha 0.75->0.25 over 5 epochs
+7. freeze_after_1_decay_to_000 — train Vec 1 epoch, freeze, then decay alpha 0.75->0.0 over 5 epochs
+8. freeze_after_5_decay_to_000 — train Vec 5 epochs, freeze, then decay alpha 0.75->0.0 over 5 epochs
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+
+
+# ---------------------------------------------------------------------------
+# Auxiliary network
+# ---------------------------------------------------------------------------
+
+class VectorCreditNet(nn.Module):
+ """Standard Vec: takes (h, t, s) -> d_hidden credit vector."""
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))
+
+
+# ---------------------------------------------------------------------------
+# Alpha schedule helpers
+# ---------------------------------------------------------------------------
+
+def make_alpha_schedule(freeze_epoch, initial_alpha, target_alpha, decay_window):
+ """
+ Returns a function alpha_fn(epoch, t0) -> current alpha.
+
+ Before freeze_epoch training epochs have passed, alpha = initial_alpha.
+ After freeze_epoch training epochs, linearly decay from initial_alpha to
+ target_alpha over decay_window epochs, then stay at target_alpha.
+
+ epoch is the absolute epoch number; t0 is the DFA checkpoint epoch.
+ Training epochs elapsed since handoff = epoch - t0.
+ """
+ def alpha_fn(epoch, t0):
+ elapsed = epoch - t0 # epochs since handoff (1-indexed)
+ if elapsed <= freeze_epoch:
+ return initial_alpha
+ # epochs after freeze
+ after_freeze = elapsed - freeze_epoch
+ if decay_window <= 0 or target_alpha == initial_alpha:
+ return target_alpha
+ progress = min(after_freeze / decay_window, 1.0)
+ return initial_alpha + (target_alpha - initial_alpha) * progress
+ return alpha_fn
+
+
+# ---------------------------------------------------------------------------
+# Data
+# ---------------------------------------------------------------------------
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4), transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
+ trainset = torchvision.datasets.CIFAR10(
+ root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(
+ root='./data', train=False, download=True, transform=transform_test)
+ return (DataLoader(trainset, batch_size=batch_size, shuffle=True,
+ num_workers=4, pin_memory=True),
+ DataLoader(testset, batch_size=batch_size, shuffle=False,
+ num_workers=4, pin_memory=True))
+
+
+# ---------------------------------------------------------------------------
+# Evaluation helpers
+# ---------------------------------------------------------------------------
+
+def evaluate(model, test_loader, device):
+ model.eval(); c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+def compute_diagnostics(model, aux_net, Bs, test_loader, device, credit_mode, alpha=0.75):
+ """Compute mean Gamma (BP cosine) and mean rho (perturbation correlation)."""
+ model.eval()
+ if aux_net is not None:
+ aux_net.eval()
+ L = model.num_blocks
+
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+
+ # BP pass for hidden gradients (offline eval only, not used for training)
+ was_frozen = not next(model.parameters()).requires_grad
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ lo, hbp = model(x, return_hidden=True)
+ for l in range(L + 1): hbp[l].retain_grad()
+ F.cross_entropy(lo, y).backward()
+ bp = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(False)
+
+ with torch.no_grad():
+ lo2, hi = model(x, return_hidden=True)
+ eT = lo2.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach()
+
+ gammas, rhos = [], []
+ for l in range(L):
+ h_l = hi[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if credit_mode == 'dfa':
+ a_l = (s @ Bs[l].T).detach()
+ elif credit_mode == 'blend' and aux_net is not None:
+ a_dfa = (s @ Bs[l].T).detach()
+ a_aux = aux_net(h_l, t_l, s).detach()
+ rd = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rv = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l = alpha * a_aux / rv + (1 - alpha) * a_dfa / rd
+ else:
+ a_l = (s @ Bs[l].T).detach()
+
+ gammas.append(cosine_similarity_batch(a_l, bp[l]))
+
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L):
+ c = c + model.blocks[i](c)
+ return F.cross_entropy(
+ model.out_head(model.out_ln(c)), y, reduction='none')
+ return f
+
+ rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))
+
+ return float(np.mean(gammas)), float(np.mean(rhos))
+
+
+# ---------------------------------------------------------------------------
+# DFA training + checkpoint
+# ---------------------------------------------------------------------------
+
+def train_dfa_get_checkpoint(model, train_loader, test_loader, device,
+ total_epochs, t0, lr, wd):
+ d = model.d_hidden; L = model.num_blocks
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd)
+ for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs)
+ for o in block_opts] +
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)])
+ ckpt = None
+ for epoch in range(1, total_epochs + 1):
+ model.train(); tl, c, t = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); b = x.size(0)
+ with torch.no_grad():
+ lo, hi = model(x, return_hidden=True); lv = F.cross_entropy(lo, y)
+ eT = lo.softmax(-1); eT[torch.arange(b), y] -= 1
+ hL = hi[-1].detach()
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); lo2.backward(); head_opt.step()
+ for l in range(L):
+ a = (eT @ Bs[l].T).detach()
+ rm = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hi[l].detach())
+ ll = (f * (a / rm)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a0 = (eT @ Bs[0].T).detach()
+ r0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0 / r0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+ tl += lv.item() * b; c += (lo.argmax(1) == y).sum().item(); t += b
+ for s in scheds: s.step()
+ if epoch == t0:
+ acc = evaluate(model, test_loader, device)
+ ckpt = {'model': copy.deepcopy(model.state_dict()),
+ 'Bs': [B.clone() for B in Bs], 'acc': acc}
+ print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
+ if epoch % 10 == 0:
+ print(f" [DFA] Epoch {epoch}: acc={evaluate(model, test_loader, device):.4f}")
+ return Bs, ckpt
+
+
+# ---------------------------------------------------------------------------
+# Branch runner
+# ---------------------------------------------------------------------------
+
+def run_branch(model, aux_net, Bs, train_loader, test_loader, device,
+ t0, total_epochs, branch_type, alpha_schedule_fn,
+ lr, lr_fb, wd, M, branch_name='', freeze_epoch=None):
+ """
+ Run a training branch from a loaded checkpoint.
+
+ branch_type:
+ 'dfa' — pure DFA, no aux
+ 'blend' — blend DFA + Vec; aux_net trained online if vec_opt active
+ 'blend_frozen' — blend DFA + frozen Vec; Vec trained for freeze_epoch epochs then frozen
+
+ alpha_schedule_fn(epoch, t0) -> float: returns alpha at each absolute epoch.
+ freeze_epoch: int — for 'blend_frozen', number of epochs to train Vec before freezing.
+ """
+ d = model.d_hidden; L = model.num_blocks; eps_pert = 1e-3
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd)
+ for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(
+ list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+
+ if branch_type != 'dfa' and aux_net is not None:
+ vec_opt = optim.Adam(aux_net.parameters(), lr=lr_fb)
+ else:
+ vec_opt = None
+
+ scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs)
+ for o in block_opts] +
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)])
+ # Advance schedulers to match checkpoint epoch
+ for _ in range(t0):
+ for s in scheds: s.step()
+
+ log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': [], 'alpha_eff': []}
+ diag_epochs = set(
+ list(range(t0 + 1, min(t0 + 6, total_epochs + 1))) +
+ [t0 + 8, t0 + 10, t0 + 15, t0 + 20] +
+ list(range(t0 + 10, total_epochs + 1, 10)) +
+ [total_epochs])
+
+ vec_frozen = False # whether Vec has been frozen
+
+ for epoch in range(t0 + 1, total_epochs + 1):
+ # Handle freeze: freeze Vec after freeze_epoch training epochs
+ if (branch_type == 'blend_frozen' and freeze_epoch is not None
+ and not vec_frozen):
+ elapsed = epoch - t0 # training epochs since handoff (1-indexed)
+ if elapsed > freeze_epoch:
+ if aux_net is not None:
+ aux_net.requires_grad_(False)
+ aux_net.eval()
+ vec_opt = None
+ vec_frozen = True
+ print(f" [{branch_name}] Freezing Vec at epoch {epoch} "
+ f"(after {freeze_epoch} training epochs)")
+
+ # Compute alpha for this epoch
+ cur_alpha = alpha_schedule_fn(epoch, t0)
+
+ model.train()
+ if aux_net is not None:
+ if vec_opt is not None:
+ aux_net.train()
+ else:
+ aux_net.eval()
+
+ tl, c, t = 0, 0, 0
+ epoch_aux_norms, epoch_dfa_norms = [], []
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); batch = x.size(0)
+ with torch.no_grad():
+ lo, hi = model(x, return_hidden=True); lv = F.cross_entropy(lo, y)
+ eT = lo.softmax(-1); eT[torch.arange(batch), y] -= 1; s = eT.detach()
+ hL = hi[-1].detach()
+
+ # ----------------------------------------------------------------
+ # Train Vec with standard perturbation targets (if applicable)
+ # ----------------------------------------------------------------
+ if vec_opt is not None and aux_net is not None:
+ t_L = torch.ones(batch, device=device)
+ a_term = aux_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ ce = F.cross_entropy(
+ model.out_head(model.out_ln(hL_req)), y, reduction='sum')
+ dL = torch.autograd.grad(ce, hL_req)[0].detach()
+ loss_term = ((a_term - dL) ** 2).sum(-1).mean()
+ lt = np.random.randint(0, L)
+ h_l = hi[lt].detach()
+ t_l = torch.full((batch,), lt / L, device=device)
+ a_l = aux_net(h_l, t_l, s)
+ lp2 = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(
+ model.forward_from_layer(h_l + eps_pert * v, lt),
+ y, reduction='none')
+ lm = F.cross_entropy(
+ model.forward_from_layer(h_l - eps_pert * v, lt),
+ y, reduction='none')
+ gj = (lp - lm) / (2 * eps_pert)
+ lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean()
+ lp2 /= M
+ vl = loss_term + lp2
+ vec_opt.zero_grad(); vl.backward()
+ torch.nn.utils.clip_grad_norm_(aux_net.parameters(), 1.0)
+ vec_opt.step()
+
+ # ----------------------------------------------------------------
+ # Compute credits for each block
+ # ----------------------------------------------------------------
+ credits = []
+ for l in range(L):
+ a_dfa = (eT @ Bs[l].T).detach()
+ rms_d = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+
+ if branch_type == 'dfa' or aux_net is None or cur_alpha == 0.0:
+ credits.append(a_dfa / rms_d)
+ epoch_aux_norms.append(0.0)
+ epoch_dfa_norms.append((a_dfa / rms_d).norm().item())
+ else:
+ h_l = hi[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ with torch.no_grad():
+ a_aux = aux_net(h_l, t_l, s).detach()
+ rms_v = (a_aux ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_blend = cur_alpha * a_aux / rms_v + (1 - cur_alpha) * a_dfa / rms_d
+ credits.append(a_blend)
+ epoch_aux_norms.append((cur_alpha * a_aux / rms_v).norm().item())
+ epoch_dfa_norms.append(((1 - cur_alpha) * a_dfa / rms_d).norm().item())
+
+ # ----------------------------------------------------------------
+ # Update output head
+ # ----------------------------------------------------------------
+ lo2 = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); lo2.backward(); head_opt.step()
+
+ # ----------------------------------------------------------------
+ # Update blocks with local surrogate
+ # ----------------------------------------------------------------
+ for l in range(L):
+ a = credits[l]
+ rm = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hi[l].detach())
+ ll = (f * (a / rm)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embedding with block-0 credit
+ a0 = credits[0]
+ r0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0 / r0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+
+ tl += lv.item() * batch; c += (lo.argmax(1) == y).sum().item(); t += batch
+
+ for sch in scheds: sch.step()
+ ta = evaluate(model, test_loader, device)
+ log['test_acc'].append(ta); log['train_loss'].append(tl / t)
+
+ mean_aux = np.mean(epoch_aux_norms) if epoch_aux_norms else 0.0
+ mean_dfa = np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1.0
+ aeff = mean_aux / (mean_aux + mean_dfa + 1e-12)
+ log['alpha_eff'].append((epoch, aeff))
+
+ if epoch in diag_epochs:
+ cm = 'blend' if (branch_type != 'dfa' and aux_net is not None
+ and cur_alpha > 0.0) else 'dfa'
+ diag_aux = aux_net if cm == 'blend' else None
+ gamma, rho = compute_diagnostics(
+ model, diag_aux, Bs, test_loader, device, cm, cur_alpha)
+ log['gamma'].append((epoch, gamma)); log['rho'].append((epoch, rho))
+ if epoch <= t0 + 15 or epoch % 20 == 0 or epoch == total_epochs:
+ frozen_str = ' [FROZEN]' if vec_frozen else ''
+ print(f" [{branch_name}]{frozen_str} Ep {epoch}: acc={ta:.4f}, "
+ f"G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}, alpha={cur_alpha:.3f}")
+ elif epoch % 10 == 0 or epoch == total_epochs:
+ frozen_str = ' [FROZEN]' if vec_frozen else ''
+ print(f" [{branch_name}]{frozen_str} Ep {epoch}: acc={ta:.4f}, "
+ f"alpha={cur_alpha:.3f}")
+
+ return log
+
+
+# ---------------------------------------------------------------------------
+# Main experiment
+# ---------------------------------------------------------------------------
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+ torch.manual_seed(args.seed); np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3; L = args.num_blocks; d = args.d_hidden
+
+ # ----------------------------------------------------------------
+ # Step 1: Train DFA and capture checkpoint at t0
+ # ----------------------------------------------------------------
+ print(f"\n{'='*60}\nTraining DFA baseline (checkpoint at t0={args.t0})\n{'='*60}")
+ model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
+ Bs, ckpt = train_dfa_get_checkpoint(
+ model_dfa, train_loader, test_loader, device,
+ args.epochs, args.t0, args.lr, args.wd)
+ print(f" Checkpoint acc at t0={args.t0}: {ckpt['acc']:.4f}")
+
+ # ----------------------------------------------------------------
+ # Step 2: Define branches
+ # ----------------------------------------------------------------
+ VEC_SEED = args.seed + 7777
+ DECAY_WINDOW = 5
+
+ def make_vec():
+ torch.manual_seed(VEC_SEED)
+ return VectorCreditNet(d_hidden=d, s_dim=10).to(device)
+
+ # constant alpha
+ def fixed_alpha(a):
+ return lambda epoch, t0: a
+
+ # (name, branch_type, aux_factory, freeze_epoch, alpha_schedule_fn)
+ branches = [
+ ('continue_DFA',
+ 'dfa', lambda: None, None,
+ fixed_alpha(0.0)),
+
+ ('blend_random_trainable_alpha075',
+ 'blend', make_vec, None,
+ fixed_alpha(0.75)),
+
+ ('freeze_after_1_fixed075',
+ 'blend_frozen', make_vec, 1,
+ fixed_alpha(0.75)),
+
+ ('freeze_after_5_fixed075',
+ 'blend_frozen', make_vec, 5,
+ fixed_alpha(0.75)),
+
+ ('freeze_after_1_decay_to_025',
+ 'blend_frozen', make_vec, 1,
+ make_alpha_schedule(freeze_epoch=1, initial_alpha=0.75,
+ target_alpha=0.25, decay_window=DECAY_WINDOW)),
+
+ ('freeze_after_5_decay_to_025',
+ 'blend_frozen', make_vec, 5,
+ make_alpha_schedule(freeze_epoch=5, initial_alpha=0.75,
+ target_alpha=0.25, decay_window=DECAY_WINDOW)),
+
+ ('freeze_after_1_decay_to_000',
+ 'blend_frozen', make_vec, 1,
+ make_alpha_schedule(freeze_epoch=1, initial_alpha=0.75,
+ target_alpha=0.0, decay_window=DECAY_WINDOW)),
+
+ ('freeze_after_5_decay_to_000',
+ 'blend_frozen', make_vec, 5,
+ make_alpha_schedule(freeze_epoch=5, initial_alpha=0.75,
+ target_alpha=0.0, decay_window=DECAY_WINDOW)),
+ ]
+
+ # ----------------------------------------------------------------
+ # Step 3: Run all branches
+ # ----------------------------------------------------------------
+ all_results = {}
+ for bname, btype, aux_factory, freeze_ep, alpha_fn in branches:
+ print(f"\n{'='*60}\n{bname}\n{'='*60}")
+ model_b = ResidualMLP(input_dim, d, 10, L).to(device)
+ model_b.load_state_dict(ckpt['model'])
+ aux_net_b = aux_factory()
+
+ log = run_branch(
+ model_b, aux_net_b, ckpt['Bs'],
+ train_loader, test_loader, device,
+ args.t0, args.epochs, btype,
+ alpha_fn, args.lr, args.lr_fb, args.wd, args.M,
+ branch_name=bname,
+ freeze_epoch=freeze_ep)
+ all_results[bname] = log
+ print(f" {bname} final acc: {log['test_acc'][-1]:.4f}")
+
+ # ----------------------------------------------------------------
+ # Step 4: Summary table
+ # ----------------------------------------------------------------
+ dfa_final = all_results['continue_DFA']['test_acc'][-1]
+
+ print(f"\n{'='*95}")
+ print("SUMMARY — Phase 10A.8A: Freeze with Alpha Decay")
+ print(f"{'='*95}")
+ print(f"{'Branch':<38} {'@20':>6} {'final':>7} {'diff':>7} "
+ f"{'mG_5:15':>9} {'mr_5:15':>9} {'aeff':>7}")
+ print("-" * 85)
+
+ for bname, log in all_results.items():
+ accs = log['test_acc']
+ idx20 = max(0, 20 - args.t0 - 1)
+ acc20 = accs[idx20] if len(accs) > idx20 else accs[-1]
+ final = accs[-1]
+ diff = final - dfa_final
+ gammas_e = [g for e, g in log['gamma'] if args.t0 < e <= args.t0 + 15]
+ rhos_e = [r for e, r in log['rho'] if args.t0 < e <= args.t0 + 15]
+ aeffs_e = [a for e, a in log['alpha_eff'] if args.t0 < e <= args.t0 + 15]
+ mg = float(np.mean(gammas_e)) if gammas_e else float('nan')
+ mr = float(np.mean(rhos_e)) if rhos_e else float('nan')
+ mae = float(np.mean(aeffs_e)) if aeffs_e else float('nan')
+ print(f"{bname:<38} {acc20:>6.4f} {final:>7.4f} {diff:>+7.4f} "
+ f"{mg:>9.4f} {mr:>9.4f} {mae:>7.3f}")
+
+ # ----------------------------------------------------------------
+ # Step 5: Save results
+ # ----------------------------------------------------------------
+ save_data = {'args': vars(args), 'dfa_ckpt_acc': float(ckpt['acc'])}
+ for bname, log in all_results.items():
+ save_data[bname] = {
+ 'test_acc': log['test_acc'],
+ 'train_loss': log['train_loss'],
+ 'gamma': log['gamma'],
+ 'rho': log['rho'],
+ 'alpha_eff': log['alpha_eff'],
+ }
+ out_path = os.path.join(args.output_dir,
+ f'freeze_with_decay_t{args.t0}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # ----------------------------------------------------------------
+ # Step 6: Judgment
+ # ----------------------------------------------------------------
+ print(f"\n{'='*60}\nJUDGMENT\n{'='*60}")
+ r = {bname: log['test_acc'][-1] for bname, log in all_results.items()}
+ dfa = r['continue_DFA']
+ ref = r.get('blend_random_trainable_alpha075', float('nan'))
+ f1 = r.get('freeze_after_1_fixed075', float('nan'))
+ f5 = r.get('freeze_after_5_fixed075', float('nan'))
+ f1d25 = r.get('freeze_after_1_decay_to_025', float('nan'))
+ f5d25 = r.get('freeze_after_5_decay_to_025', float('nan'))
+ f1d00 = r.get('freeze_after_1_decay_to_000', float('nan'))
+ f5d00 = r.get('freeze_after_5_decay_to_000', float('nan'))
+
+ print(f" DFA={dfa:.4f} ref={ref:.4f}")
+ print(f" freeze1_fixed={f1:.4f} freeze5_fixed={f5:.4f}")
+ print(f" freeze1_to025={f1d25:.4f} freeze5_to025={f5d25:.4f}")
+ print(f" freeze1_to000={f1d00:.4f} freeze5_to000={f5d00:.4f}")
+
+ thr = 0.003
+
+ # Fixed vs trainable reference
+ best_fixed = max(f1, f5)
+ if best_fixed > ref - thr:
+ print(f"\n -> Best frozen-fixed ({best_fixed:.4f}) ≈ trainable reference: "
+ "freezing early is sufficient; ongoing Vec training adds no value")
+ elif ref > best_fixed + thr:
+ print(f"\n -> Trainable reference ({ref:.4f}) > best frozen-fixed ({best_fixed:.4f}): "
+ "continuous Vec adaptation helps")
+
+ # Effect of more training before freeze
+ if f5 > f1 + thr:
+ print(f" -> More Vec training before freeze helps: "
+ f"5ep ({f5:.4f}) > 1ep ({f1:.4f})")
+ else:
+ print(f" -> Freeze timing (1 vs 5 epochs) makes little difference: "
+ f"f1={f1:.4f} f5={f5:.4f}")
+
+ # Effect of decay on fixed-freeze branches
+ print(f"\n Decay effect (vs fixed075):")
+ for label, fixed_v, d25, d00 in [
+ ('freeze_after_1', f1, f1d25, f1d00),
+ ('freeze_after_5', f5, f5d25, f5d00)]:
+ print(f" {label}: fixed={fixed_v:.4f} ->0.25={d25:.4f} ->0.00={d00:.4f}")
+ if d25 > fixed_v + thr:
+ print(f" -> decay to 0.25 helps vs fixed ({d25-fixed_v:+.4f})")
+ if d00 > fixed_v + thr:
+ print(f" -> decay to 0.0 (full DFA) helps vs fixed ({d00-fixed_v:+.4f})")
+ if d00 > d25 + thr:
+ print(f" -> faster decay (to 0) better than partial ({d00-d25:+.4f})")
+ elif d25 > d00 + thr:
+ print(f" -> partial decay (to 0.25) better than full decay ({d25-d00:+.4f})")
+
+ # Overall winner
+ best_name = max(r, key=r.get)
+ print(f"\n Best branch: {best_name} = {r[best_name]:.4f} "
+ f"(+{r[best_name]-dfa:+.4f} vs DFA)")
+
+
+def main():
+ parser = argparse.ArgumentParser(
+ description='Phase 10A.8A: Freeze with Alpha Decay')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--t0', type=int, default=5)
+ parser.add_argument('--alpha', type=float, default=0.75)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=2)
+ parser.add_argument('--output_dir', type=str, default='results/freeze_with_decay')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()