summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'experiments')
-rw-r--r--experiments/checkpointed_handoff.py617
1 files changed, 617 insertions, 0 deletions
diff --git a/experiments/checkpointed_handoff.py b/experiments/checkpointed_handoff.py
new file mode 100644
index 0000000..3057825
--- /dev/null
+++ b/experiments/checkpointed_handoff.py
@@ -0,0 +1,617 @@
+"""
+Phase 9A: Checkpointed Offline Handoff.
+
+Core question: if we offline-train Vec on a DFA trajectory checkpoint,
+can it take over and outperform continuing with DFA?
+
+Steps:
+1. Train DFA baseline, save checkpoints at t0={1,5,10}
+2. At each checkpoint, freeze forward net and offline-train Vec_eT_M4
+3. From each checkpoint, branch into: continue_DFA, handoff_to_Vec, blends
+4. Compare trajectories
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+def compute_diagnostics(model, vector_net, dfa_Bs, test_loader, device, credit_mode):
+ """Compute mean Gamma and rho for current credit source."""
+ model.eval()
+ if vector_net is not None:
+ vector_net.eval()
+ L = model.num_blocks
+
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+
+ # BP gradients (eval only) — temporarily enable requires_grad
+ was_frozen = not next(model.parameters()).requires_grad
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hbp = model(x, return_hidden=True)
+ for l in range(L + 1): hbp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ bp_grads = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(False)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ gammas, rhos = [], []
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if credit_mode == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ elif credit_mode == 'vec':
+ a_l = vector_net(h_l, t_l, s).detach()
+ elif isinstance(credit_mode, float):
+ alpha = credit_mode
+ a_dfa = (s @ dfa_Bs[l].T).detach()
+ a_vec = vector_net(h_l, t_l, s).detach()
+ rms_v = (a_vec**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rms_d = (a_dfa**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l = alpha * a_vec / rms_v + (1 - alpha) * a_dfa / rms_d
+
+ gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L): c = c + model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
+ return f
+ rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))
+
+ return float(np.mean(gammas)), float(np.mean(rhos))
+
+
+# =============================================================================
+# Step 1: Train DFA with checkpoints
+# =============================================================================
+def train_dfa_with_checkpoints(model, train_loader, test_loader, device,
+ epochs, save_epochs, ckpt_dir, lr=1e-3, wd=0.01):
+ os.makedirs(ckpt_dir, exist_ok=True)
+ d = model.d_hidden
+ L = model.num_blocks
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(batch), y] -= 1
+ hL = hiddens[-1].detach()
+ loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ a = (e_T @ Bs[l].T).detach()
+ rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f * (a / rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+ a0 = (e_T @ Bs[0].T).detach()
+ rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0 / rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+ for s in scheds: s.step()
+
+ if epoch in save_epochs:
+ acc = evaluate(model, test_loader, device)
+ ckpt = {
+ 'model': model.state_dict(),
+ 'Bs': [B.cpu() for B in Bs],
+ 'epoch': epoch, 'acc': acc,
+ }
+ torch.save(ckpt, os.path.join(ckpt_dir, f'dfa_epoch_{epoch}.pt'))
+ print(f" [DFA] Saved epoch {epoch} (acc={acc:.4f})")
+ elif epoch % 10 == 0:
+ acc = evaluate(model, test_loader, device)
+ print(f" [DFA] Epoch {epoch}: acc={acc:.4f}")
+
+ # Save final
+ final_acc = evaluate(model, test_loader, device)
+ ckpt = {'model': model.state_dict(), 'Bs': [B.cpu() for B in Bs],
+ 'epoch': epochs, 'acc': final_acc}
+ torch.save(ckpt, os.path.join(ckpt_dir, f'dfa_epoch_{epochs}.pt'))
+ return Bs, final_acc
+
+
+# =============================================================================
+# Step 2: Offline-fit Vec on frozen checkpoint
+# =============================================================================
+def offline_fit_vec(model, train_loader, device, epochs=60, lr_fb=1e-3, M=4):
+ d = model.d_hidden
+ L = model.num_blocks
+ vec_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
+ eps = 1e-3
+ model.eval()
+ for ep in range(1, epochs + 1):
+ vec_net.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1].detach()
+ t_L = torch.ones(batch, device=device)
+ a_term = vec_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L)**2).sum(-1).mean()
+
+ l = np.random.randint(0, L)
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vec_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps*v, l), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps*v, l), y, reduction='none')
+ g_j = (lp - lm) / (2*eps)
+ loss_proj = loss_proj + (((a_l*v).sum(-1) - g_j.detach())**2).mean()
+ loss_proj /= M
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
+ vec_opt.step()
+ if ep % 20 == 0 or ep == 1:
+ print(f" [Vec fit] Ep {ep}")
+ return vec_net
+
+
+# =============================================================================
+# Step 3: Continue training from checkpoint with a given credit schedule
+# =============================================================================
+def continue_training(model, vector_net, Bs, train_loader, test_loader, device,
+ start_epoch, total_epochs, credit_mode, lr=1e-3, lr_fb=1e-3,
+ wd=0.01, M=4, branch_name=''):
+ """
+ Continue training from a checkpoint.
+ credit_mode: 'dfa', 'vec', or float (blend alpha for Vec)
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+ eps_pert = 1e-3
+
+ block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=lr, weight_decay=wd)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) if credit_mode != 'dfa' else None
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)]
+ # Step schedulers to current position
+ for _ in range(start_epoch):
+ for s in scheds: s.step()
+
+ use_vec = credit_mode != 'dfa'
+ blend_alpha = credit_mode if isinstance(credit_mode, float) else (1.0 if credit_mode == 'vec' else 0.0)
+
+ log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': []}
+
+ for epoch in range(start_epoch + 1, total_epochs + 1):
+ model.train()
+ if use_vec: vector_net.train()
+ total_loss, correct, total = 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1].detach()
+
+ # Train Vec online (keep it fresh)
+ if use_vec and vec_opt is not None:
+ t_L = torch.ones(batch, device=device)
+ a_term = vector_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L)**2).sum(-1).mean()
+
+ l_train = np.random.randint(0, L)
+ h_l = hiddens[l_train].detach()
+ t_l = torch.full((batch,), l_train / L, device=device)
+ a_l = vector_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps_pert*v, l_train), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps_pert*v, l_train), y, reduction='none')
+ g_j = (lp - lm) / (2*eps_pert)
+ loss_proj = loss_proj + (((a_l*v).sum(-1) - g_j.detach())**2).mean()
+ loss_proj /= M
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+
+ # Compute credits
+ with torch.no_grad():
+ vec_credits = [vector_net(hiddens[l].detach(),
+ torch.full((batch,), l/L, device=device), s).detach() for l in range(L)]
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ if blend_alpha >= 1.0:
+ credits.append(vec_credits[l])
+ elif blend_alpha <= 0.0:
+ credits.append(dfa_credits[l])
+ else:
+ rms_v = (vec_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rms_d = (dfa_credits[l]**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ credits.append(blend_alpha * vec_credits[l] / rms_v +
+ (1 - blend_alpha) * dfa_credits[l] / rms_d)
+
+ # Update head
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ a = credits[l]
+ rms = (a**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f * (a / rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # Update embed
+ a0 = credits[0]
+ rms0 = (a0**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ el = (model.embed(x) * (a0 / rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['test_acc'].append(test_acc)
+ log['train_loss'].append(total_loss / total)
+
+ # Diagnostics every 5 epochs or near handoff
+ near_handoff = abs(epoch - start_epoch) <= 5
+ if epoch % 5 == 0 or near_handoff or epoch == total_epochs:
+ cm = credit_mode if isinstance(credit_mode, float) else credit_mode
+ gamma, rho = compute_diagnostics(model, vector_net, Bs, test_loader, device,
+ 'vec' if blend_alpha >= 0.5 else 'dfa')
+ log['gamma'].append((epoch, gamma))
+ log['rho'].append((epoch, rho))
+ else:
+ gamma, rho = None, None
+
+ if epoch % 10 == 0 or near_handoff or epoch == total_epochs:
+ g_str = f", G={gamma:.4f}, r={rho:.4f}" if gamma is not None else ""
+ print(f" [{branch_name}] Ep {epoch}: acc={test_acc:.4f}{g_str}")
+
+ return log
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ ckpt_dir = os.path.join(args.output_dir, f'dfa_ckpts_s{args.seed}')
+
+ # =========================================================
+ # Step 1: Train DFA baseline with checkpoints
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Step 1: Train DFA baseline with checkpoints")
+ print(f"{'='*60}")
+
+ all_exist = all(os.path.exists(os.path.join(ckpt_dir, f'dfa_epoch_{e}.pt'))
+ for e in args.checkpoint_epochs)
+ final_exist = os.path.exists(os.path.join(ckpt_dir, f'dfa_epoch_{args.epochs}.pt'))
+
+ if not all_exist or not final_exist:
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
+ Bs, dfa_final_acc = train_dfa_with_checkpoints(
+ model_dfa, train_loader, test_loader, device,
+ epochs=args.epochs, save_epochs=args.checkpoint_epochs + [args.epochs],
+ ckpt_dir=ckpt_dir, lr=args.lr, wd=args.wd)
+ print(f" DFA final acc: {dfa_final_acc:.4f}")
+ else:
+ print(f" All DFA checkpoints exist in {ckpt_dir}")
+ final_ckpt = torch.load(os.path.join(ckpt_dir, f'dfa_epoch_{args.epochs}.pt'), map_location=device)
+ dfa_final_acc = final_ckpt['acc']
+ Bs = [B.to(device) for B in final_ckpt['Bs']]
+ print(f" DFA final acc: {dfa_final_acc:.4f}")
+
+ # =========================================================
+ # Step 2 & 3: For each checkpoint, offline-fit Vec then branch
+ # =========================================================
+ all_results = {}
+
+ for t0 in args.checkpoint_epochs:
+ print(f"\n{'='*60}")
+ print(f"Checkpoint t0={t0}")
+ print(f"{'='*60}")
+
+ # Load checkpoint
+ ckpt = torch.load(os.path.join(ckpt_dir, f'dfa_epoch_{t0}.pt'), map_location=device)
+ ckpt_Bs = [B.to(device) for B in ckpt['Bs']]
+ print(f" DFA acc at t0={t0}: {ckpt['acc']:.4f}")
+
+ # Offline-fit Vec on this checkpoint
+ print(f" Offline-fitting Vec on t0={t0}...")
+ model_frozen = ResidualMLP(input_dim, d, 10, L).to(device)
+ model_frozen.load_state_dict(ckpt['model'])
+ model_frozen.eval()
+ for p in model_frozen.parameters(): p.requires_grad_(False)
+
+ torch.manual_seed(args.seed + t0 * 1000 + 4000)
+ vec_net = offline_fit_vec(model_frozen, train_loader, device,
+ epochs=args.vec_fit_epochs, lr_fb=args.lr_fb, M=args.M)
+
+ # Evaluate Vec quality on this checkpoint
+ gamma_frozen, rho_frozen = compute_diagnostics(
+ model_frozen, vec_net, ckpt_Bs, test_loader, device, 'vec')
+ print(f" Vec quality at t0={t0}: Gamma={gamma_frozen:.4f}, rho={rho_frozen:.4f}")
+ for p in model_frozen.parameters(): p.requires_grad_(True)
+
+ # Branch training
+ for branch_name, credit_mode in args.branches:
+ print(f"\n --- Branch: {branch_name} (from t0={t0}) ---")
+
+ # Fresh copy of model at checkpoint
+ model_branch = ResidualMLP(input_dim, d, 10, L).to(device)
+ model_branch.load_state_dict(ckpt['model'])
+
+ # Fresh copy of Vec (from offline-fitted state)
+ vec_branch = copy.deepcopy(vec_net)
+
+ log = continue_training(
+ model_branch, vec_branch, ckpt_Bs, train_loader, test_loader, device,
+ start_epoch=t0, total_epochs=args.epochs,
+ credit_mode=credit_mode, lr=args.lr, lr_fb=args.lr_fb, wd=args.wd,
+ M=args.M, branch_name=branch_name)
+
+ key = f"t0={t0}_{branch_name}"
+ all_results[key] = {
+ 't0': t0, 'branch': branch_name, 'credit_mode': str(credit_mode),
+ 'vec_gamma_frozen': gamma_frozen, 'vec_rho_frozen': rho_frozen,
+ 'test_acc': log['test_acc'],
+ 'train_loss': log['train_loss'],
+ 'gamma': log['gamma'],
+ 'rho': log['rho'],
+ }
+
+ # =========================================================
+ # Summary
+ # =========================================================
+ print(f"\n{'='*100}")
+ print("SUMMARY")
+ print(f"{'='*100}")
+ print(f"{'Key':<35} {'acc@t0':>7} {'acc@20':>7} {'acc@50':>7} {'final':>7} "
+ f"{'mGamma':>8} {'mRho':>7}")
+ print("-" * 85)
+
+ # Add DFA baseline
+ dfa_full = torch.load(os.path.join(ckpt_dir, f'dfa_epoch_{args.epochs}.pt'), map_location=device)
+ print(f"{'DFA_full_baseline':<35} {'':>7} {'':>7} {'':>7} {dfa_full['acc']:>7.4f} {'':>8} {'':>7}")
+
+ for key, r in all_results.items():
+ accs = r['test_acc']
+ t0 = r['t0']
+ # Index relative to start_epoch
+ def get_acc_at(target_epoch):
+ idx = target_epoch - t0 - 1
+ if 0 <= idx < len(accs):
+ return accs[idx]
+ return float('nan')
+
+ acc_20 = get_acc_at(20)
+ acc_50 = get_acc_at(50)
+ final = accs[-1] if accs else float('nan')
+ acc_t0 = r['vec_gamma_frozen'] # placeholder for checkpoint info
+
+ gammas = [g for _, g in r['gamma']]
+ rhos = [rh for _, rh in r['rho']]
+ mg = np.mean(gammas) if gammas else float('nan')
+ mr = np.mean(rhos) if rhos else float('nan')
+
+ print(f"{key:<35} {'':>7} {acc_20:>7.4f} {acc_50:>7.4f} {final:>7.4f} {mg:>8.4f} {mr:>7.4f}")
+
+ # Save
+ save_data = {}
+ for key, r in all_results.items():
+ save_data[key] = {k: v for k, v in r.items()}
+ save_data['dfa_final_acc'] = float(dfa_final_acc)
+
+ out_path = os.path.join(args.output_dir, f'handoff_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ for t0 in args.checkpoint_epochs:
+ dfa_key = f"t0={t0}_continue_DFA"
+ if dfa_key not in all_results:
+ continue
+ dfa_final = all_results[dfa_key]['test_acc'][-1]
+
+ for key, r in all_results.items():
+ if r['t0'] != t0 or r['branch'] == 'continue_DFA':
+ continue
+ branch_final = r['test_acc'][-1]
+ diff = branch_final - dfa_final
+ print(f" t0={t0}: {r['branch']} final={branch_final:.4f} vs continue_DFA={dfa_final:.4f} "
+ f"(diff={diff:+.4f})")
+ if diff > 0.01:
+ print(f" -> {r['branch']} OUTPERFORMS continue_DFA!")
+ elif diff > -0.01:
+ print(f" -> Similar to continue_DFA")
+ else:
+ print(f" -> Worse than continue_DFA")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 9A: Checkpointed Offline Handoff')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--vec_fit_epochs', type=int, default=60)
+ parser.add_argument('--checkpoint_epochs', type=int, nargs='+', default=[5])
+ parser.add_argument('--branch_spec', type=str, nargs='+',
+ default=['continue_DFA:dfa', 'handoff_to_Vec:vec', 'handoff_blend_05:0.5'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/checkpointed_handoff')
+ args = parser.parse_args()
+
+ # Parse branch specs
+ args.branches = []
+ for spec in args.branch_spec:
+ name, mode = spec.split(':')
+ try:
+ mode = float(mode)
+ except ValueError:
+ pass
+ args.branches.append((name, mode))
+
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()