diff options
Diffstat (limited to 'experiments/checkpointed_handoff.py')
| -rw-r--r-- | experiments/checkpointed_handoff.py | 617 |
1 files changed, 617 insertions, 0 deletions
# diff --git a/experiments/checkpointed_handoff.py b/experiments/checkpointed_handoff.py
# new file mode 100644 index 0000000..3057825
"""
Phase 9A: Checkpointed Offline Handoff.

Core question: if we offline-train Vec on a DFA trajectory checkpoint,
can it take over and outperform continuing with DFA?

Steps:
1. Train DFA baseline, save checkpoints at t0={1,5,10}
2. At each checkpoint, freeze forward net and offline-train Vec_eT_M4
3. From each checkpoint, branch into: continue_DFA, handoff_to_Vec, blends
4. Compare trajectories
"""
import os
import sys
import json
import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

# Make the repository root importable so the project-local packages resolve
# when this file is run directly as a script.
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test


class VectorCreditNet(nn.Module):
    """MLP that predicts a per-sample credit vector for a hidden state.

    The network input is the concatenation of the layer-normalised hidden
    state, a time embedding of ``t`` and the signal ``s``; the output has
    ``d_hidden`` features.

    Args:
        d_hidden: Width of the hidden state ``h`` (and of the output).
        s_dim: Width of the signal ``s`` concatenated to the input.
        time_embed_dim: Size passed to ``SinusoidalTimeEmbed``
            (assumed to be its output width -- TODO confirm).
        hidden_dim: Width of the intermediate linear layers.
        num_layers: Number of ``Linear + GELU`` pairs before the output layer.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        # Track the running feature width so the output layer also matches
        # when num_layers == 0 (the input then feeds the output layer directly).
        width = d_hidden + time_embed_dim + s_dim
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(width, hidden_dim))
            layers.append(nn.GELU())
            width = hidden_dim
        layers.append(nn.Linear(width, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the credit vector for hidden state ``h`` at time ``t`` given ``s``."""
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


def get_cifar10(batch_size=128):
    """Build CIFAR-10 train/test ``DataLoader``s (downloads into ``./data``).

    The training set uses random crop (padding 4) and horizontal flip; both
    splits are normalised with the same per-channel mean/std.

    Returns:
        ``(train_loader, test_loader)``; only the train loader is shuffled.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True,
                                           transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=4, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=4, pin_memory=True)
    return train_loader, test_loader


def evaluate(model, test_loader, device):
    """Return top-1 accuracy of ``model`` on ``test_loader``.

    Inputs are flattened to ``(batch, -1)`` before the forward pass. The model
    is left in eval mode; callers switch back to train mode themselves.
    """
    model.eval()
    num_correct, num_seen = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            num_correct += (model(x).argmax(1) == y).sum().item()
            num_seen += x.size(0)
    return num_correct / num_seen


def compute_diagnostics(model, vector_net, dfa_Bs, test_loader, device, credit_mode):
    """Compute mean Gamma and rho for current credit source.

    Only the first batch of ``test_loader`` is used. Gamma is the mean over
    layers of ``cosine_similarity_batch(credit, backprop gradient)``; rho is
    the mean over layers of ``perturbation_correlation`` for the same credit.

    Args:
        model: Network exposing ``num_blocks``, ``blocks``, ``out_ln``,
            ``out_head`` and ``forward(x, return_hidden=True)``.
        vector_net: Credit network; may be ``None`` when ``credit_mode == 'dfa'``.
        dfa_Bs: One feedback matrix per block (used for the DFA credit).
        test_loader: Loader providing the evaluation batch.
        device: Device on which the batch is placed.
        credit_mode: ``'dfa'``, ``'vec'``, or a float blend weight ``alpha``
            (RMS-normalised ``alpha * vec + (1 - alpha) * dfa``).

    Returns:
        ``(mean_gamma, mean_rho)`` as Python floats.

    Raises:
        ValueError: If ``credit_mode`` is not one of the supported values.
    """
    is_blend = isinstance(credit_mode, float)
    if credit_mode not in ('dfa', 'vec') and not is_blend:
        raise ValueError(f"Unsupported credit_mode: {credit_mode!r}")

    model.eval()
    if vector_net is not None:
        vector_net.eval()
    L = model.num_blocks

    x, y = next(iter(test_loader))
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)

    # BP gradients (eval only) -- temporarily enable requires_grad so the
    # hidden activations are part of an autograd graph even on a frozen model.
    was_frozen = not next(model.parameters()).requires_grad
    if was_frozen:
        for p in model.parameters():
            p.requires_grad_(True)
    try:
        model.zero_grad()
        logits_bp, hbp = model(x, return_hidden=True)
        for l in range(L + 1):
            hbp[l].retain_grad()
        F.cross_entropy(logits_bp, y).backward()
        bp_grads = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
    finally:
        # Always restore the frozen state, even if the backward pass failed.
        if was_frozen:
            for p in model.parameters():
                p.requires_grad_(False)

    with torch.no_grad():
        logits, hiddens = model(x, return_hidden=True)
        e_T = logits.softmax(-1)
        e_T[torch.arange(batch), y] -= 1  # softmax minus one-hot target
        s = e_T.detach()

    def make_fwd(start):
        """Return a function mapping a hidden state at block ``start`` to per-sample CE loss."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(start, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    gammas, rhos = [], []
    for l in range(L):
        h_l = hiddens[l].detach()
        t_l = torch.full((batch,), l / L, device=device)

        if credit_mode == 'dfa':
            a_l = (s @ dfa_Bs[l].T).detach()
        elif credit_mode == 'vec':
            a_l = vector_net(h_l, t_l, s).detach()
        else:
            alpha = credit_mode
            a_dfa = (s @ dfa_Bs[l].T).detach()
            a_vec = vector_net(h_l, t_l, s).detach()
            rms_v = (a_vec ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            rms_d = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            a_l = alpha * a_vec / rms_v + (1 - alpha) * a_dfa / rms_d

        gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
        rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))

    return float(np.mean(gammas)), float(np.mean(rhos))


# =============================================================================
# Step 1: Train DFA with checkpoints
# =============================================================================
def train_dfa_with_checkpoints(model, train_loader, test_loader, device,
                               epochs, save_epochs, ckpt_dir, lr=1e-3, wd=0.01):
    """Train ``model`` with fixed random feedback matrices and save checkpoints.

    Each block (and the embedding) is updated with a local loss
    ``<output, RMS-normalised feedback>`` where the feedback is the output
    error projected through a fixed random matrix; the head is trained with
    cross-entropy on the detached last hidden state.

    Args:
        model: Network exposing ``d_hidden``, ``num_blocks``, ``blocks``,
            ``embed``, ``out_ln``, ``out_head``.
        train_loader: Training batches ``(x, y)``.
        test_loader: Batches used for the reported accuracy.
        device: Device for tensors and feedback matrices.
        epochs: Total number of epochs.
        save_epochs: Epochs after which a checkpoint is written.
        ckpt_dir: Output directory (created if missing).
        lr: AdamW learning rate.
        wd: AdamW weight decay.

    Returns:
        ``(Bs, final_acc)``: the feedback matrices and final test accuracy.
        A checkpoint for epoch ``epochs`` is always written.
    """
    os.makedirs(ckpt_dir, exist_ok=True)
    d = model.d_hidden
    L = model.num_blocks
    Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]

    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
                           lr=lr, weight_decay=wd)
    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
              for o in block_opts + [embed_opt, head_opt]]

    save_epochs = set(save_epochs)

    def save_checkpoint(epoch, acc):
        ckpt = {
            'model': model.state_dict(),
            'Bs': [B.cpu() for B in Bs],
            'epoch': epoch, 'acc': acc,
        }
        torch.save(ckpt, os.path.join(ckpt_dir, f'dfa_epoch_{epoch}.pt'))

    final_acc = None
    final_saved = False

    for epoch in range(1, epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                e_T = logits.softmax(-1)
                e_T[torch.arange(batch), y] -= 1  # softmax minus one-hot target

            # Head: plain cross-entropy on the detached last hidden state.
            hL = hiddens[-1].detach()
            loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
            head_opt.zero_grad()
            loss_out.backward()
            head_opt.step()

            # Blocks: local loss against the RMS-normalised random feedback.
            for l in range(L):
                a = (e_T @ Bs[l].T).detach()
                rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                f = model.blocks[l](hiddens[l].detach())
                ll = (f * (a / rms)).sum(-1).mean()
                block_opts[l].zero_grad()
                ll.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # Embedding: reuses the first block's feedback matrix.
            a0 = (e_T @ Bs[0].T).detach()
            rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            el = (model.embed(x) * (a0 / rms0)).sum(-1).mean()
            embed_opt.zero_grad()
            el.backward()
            embed_opt.step()

        for sched in scheds:
            sched.step()

        if epoch in save_epochs:
            acc = evaluate(model, test_loader, device)
            save_checkpoint(epoch, acc)
            print(f" [DFA] Saved epoch {epoch} (acc={acc:.4f})")
            if epoch == epochs:
                final_acc, final_saved = acc, True
        elif epoch % 10 == 0:
            acc = evaluate(model, test_loader, device)
            print(f" [DFA] Epoch {epoch}: acc={acc:.4f}")
            if epoch == epochs:
                final_acc = acc

    # Save final (skipping work already done for the last epoch above).
    if final_acc is None:
        final_acc = evaluate(model, test_loader, device)
    if not final_saved:
        save_checkpoint(epochs, final_acc)
    return Bs, final_acc
# =============================================================================
# Step 2: Offline-fit Vec on frozen checkpoint
# =============================================================================
def softmax_error(logits, y):
    """Return ``softmax(logits) - one_hot(y)`` as a detached tensor."""
    with torch.no_grad():
        err = logits.softmax(-1)
        err[torch.arange(logits.size(0)), y] -= 1
    return err.detach()


def random_unit_directions(h):
    """Sample Gaussian directions shaped like ``h`` with unit L2 norm per row.

    The norm must be taken with ``dim=-1``: the first positional argument of
    ``Tensor.norm`` is the order ``p``, so ``v.norm(-1, keepdim=True)`` would
    compute a single p=-1 norm over the whole tensor instead.
    """
    v = torch.randn_like(h)
    return v / (v.norm(dim=-1, keepdim=True) + 1e-8)


def vec_fit_loss(model, vector_net, hiddens, s, y, eps, M):
    """Loss used to fit the credit network against the current forward model.

    Two terms are summed:
      * terminal term: squared error between the credit at the last hidden
        state (t = 1) and the exact gradient of the summed cross-entropy with
        respect to that hidden state;
      * projection term: at one uniformly sampled layer, squared error between
        ``<credit, v>`` and a central finite-difference estimate of the loss
        change along ``M`` random unit directions ``v``.

    Args:
        model: Forward network exposing ``num_blocks``, ``out_ln``,
            ``out_head`` and ``forward_from_layer(h, l)``.
        vector_net: Credit network being trained.
        hiddens: Hidden states returned by ``model(x, return_hidden=True)``.
        s: Output error signal passed to ``vector_net``.
        y: Integer class targets.
        eps: Finite-difference step size.
        M: Number of random directions (must be >= 1).

    Raises:
        ValueError: If ``M < 1`` (the projection average would be undefined).
    """
    if M < 1:
        raise ValueError(f"M must be >= 1, got {M}")
    L = model.num_blocks
    batch = y.size(0)
    device = s.device

    # Terminal term: exact gradient at the last hidden state.
    hL = hiddens[-1].detach()
    a_term = vector_net(hL, torch.ones(batch, device=device), s)
    hL_req = hL.clone().requires_grad_(True)
    ce = F.cross_entropy(model.out_head(model.out_ln(hL_req)), y, reduction='sum')
    delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
    loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()

    # Projection term at a randomly chosen layer.
    l = np.random.randint(0, L)
    h_l = hiddens[l].detach()
    t_l = torch.full((batch,), l / L, device=device)
    a_l = vector_net(h_l, t_l, s)
    loss_proj = torch.tensor(0.0, device=device)
    for _ in range(M):
        v = random_unit_directions(h_l)
        with torch.no_grad():
            lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none')
            lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none')
            g_j = (lp - lm) / (2 * eps)
        loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
    loss_proj = loss_proj / M
    return loss_term + loss_proj


def offline_fit_vec(model, train_loader, device, epochs=60, lr_fb=1e-3, M=4):
    """Train a fresh ``VectorCreditNet`` against a fixed forward ``model``.

    ``model`` is put in eval mode and never updated here; only the credit
    network's parameters are optimised (Adam, gradient-norm clipped to 1.0).

    Args:
        model: Forward network (see ``vec_fit_loss`` for required attributes).
        train_loader: Training batches ``(x, y)``.
        device: Device for the credit network and batches.
        epochs: Number of passes over ``train_loader``.
        lr_fb: Adam learning rate for the credit network.
        M: Number of random probe directions per batch.

    Returns:
        The trained ``VectorCreditNet``.
    """
    vec_net = VectorCreditNet(d_hidden=model.d_hidden, s_dim=10, time_embed_dim=32,
                              hidden_dim=256, num_layers=3).to(device)
    vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
    eps = 1e-3
    model.eval()
    for ep in range(1, epochs + 1):
        vec_net.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
            s = softmax_error(logits, y)

            vloss = vec_fit_loss(model, vec_net, hiddens, s, y, eps, M)
            vec_opt.zero_grad()
            vloss.backward()
            torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
            vec_opt.step()
        if ep % 20 == 0 or ep == 1:
            print(f" [Vec fit] Ep {ep}")
    return vec_net


# =============================================================================
# Step 3: Continue training from checkpoint with a given credit schedule
# =============================================================================
def resolve_blend_alpha(credit_mode):
    """Map a credit mode to the Vec blend weight.

    ``'dfa'`` -> 0.0, ``'vec'`` -> 1.0, a real number -> itself as float.

    Raises:
        ValueError: For any other string, a non-numeric value, or NaN.
    """
    if isinstance(credit_mode, str):
        if credit_mode == 'dfa':
            return 0.0
        if credit_mode == 'vec':
            return 1.0
        raise ValueError(f"Unknown credit_mode {credit_mode!r}; expected 'dfa', 'vec' or a float")
    if isinstance(credit_mode, bool) or not isinstance(credit_mode, (int, float)):
        raise ValueError(f"Unknown credit_mode {credit_mode!r}; expected 'dfa', 'vec' or a float")
    alpha = float(credit_mode)
    if np.isnan(alpha):
        raise ValueError("credit_mode blend weight must not be NaN")
    return alpha


def blend_credits(vec_credit, dfa_credit, alpha):
    """Combine Vec and DFA credits.

    ``alpha >= 1`` returns ``vec_credit`` and ``alpha <= 0`` returns
    ``dfa_credit`` unchanged; otherwise both are RMS-normalised per sample and
    mixed as ``alpha * vec + (1 - alpha) * dfa``.
    """
    if alpha <= 0.0:
        return dfa_credit
    if alpha >= 1.0:
        return vec_credit
    rms_v = (vec_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
    rms_d = (dfa_credit ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
    return alpha * vec_credit / rms_v + (1 - alpha) * dfa_credit / rms_d


def continue_training(model, vector_net, Bs, train_loader, test_loader, device,
                      start_epoch, total_epochs, credit_mode, lr=1e-3, lr_fb=1e-3,
                      wd=0.01, M=4, branch_name=''):
    """
    Continue training from a checkpoint.
    credit_mode: 'dfa', 'vec', or float (blend alpha for Vec)

    Fresh AdamW optimisers and cosine schedules are created and the schedules
    are advanced by ``start_epoch`` steps so the learning rate matches the
    position in a ``total_epochs``-long run. Unless ``credit_mode == 'dfa'``
    the credit network keeps being trained online on every batch.

    Returns:
        dict with ``'test_acc'`` and ``'train_loss'`` (one entry per epoch) and
        ``'gamma'`` / ``'rho'`` (lists of ``(epoch, value)``), where the
        diagnostics are computed for the credit actually used by this branch.

    Raises:
        ValueError: If ``credit_mode`` is not recognised.
    """
    L = model.num_blocks
    eps_pert = 1e-3

    blend_alpha = resolve_blend_alpha(credit_mode)
    use_vec = credit_mode != 'dfa'
    needs_vec_credit = blend_alpha > 0.0
    # Diagnostics must describe the credit that drives the updates: pure modes
    # at the extremes, the actual blend weight in between.
    if blend_alpha >= 1.0:
        diag_mode = 'vec'
    elif blend_alpha <= 0.0:
        diag_mode = 'dfa'
    else:
        diag_mode = blend_alpha

    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
                           lr=lr, weight_decay=wd)
    vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb) if use_vec else None
    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs)
              for o in block_opts + [embed_opt, head_opt]]
    # Step schedulers to current position
    for _ in range(start_epoch):
        for sched in scheds:
            sched.step()

    log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': []}

    for epoch in range(start_epoch + 1, total_epochs + 1):
        model.train()
        if use_vec:
            vector_net.train()
        total_loss, total = 0, 0

        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)
            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
            s = softmax_error(logits, y)
            hL = hiddens[-1].detach()

            # Train Vec online (keep it fresh)
            if use_vec:
                vloss = vec_fit_loss(model, vector_net, hiddens, s, y, eps_pert, M)
                vec_opt.zero_grad()
                vloss.backward()
                torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
                vec_opt.step()

            # Compute credits (Vec credits only when they contribute).
            with torch.no_grad():
                credits = []
                for l in range(L):
                    dfa_credit = (s @ Bs[l].T).detach()
                    vec_credit = None
                    if needs_vec_credit:
                        t_l = torch.full((batch,), l / L, device=device)
                        vec_credit = vector_net(hiddens[l].detach(), t_l, s).detach()
                    credits.append(blend_credits(vec_credit, dfa_credit, blend_alpha))

            # Update head
            loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
            head_opt.zero_grad()
            loss_out.backward()
            head_opt.step()

            # Update blocks
            for l in range(L):
                a = credits[l]
                rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
                f = model.blocks[l](hiddens[l].detach())
                ll = (f * (a / rms)).sum(-1).mean()
                block_opts[l].zero_grad()
                ll.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # Update embed
            a0 = credits[0]
            rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
            el = (model.embed(x) * (a0 / rms0)).sum(-1).mean()
            embed_opt.zero_grad()
            el.backward()
            embed_opt.step()

            total_loss += loss_val.item() * batch
            total += batch

        for sched in scheds:
            sched.step()
        test_acc = evaluate(model, test_loader, device)
        log['test_acc'].append(test_acc)
        log['train_loss'].append(total_loss / total)

        # Diagnostics every 5 epochs or near handoff
        near_handoff = abs(epoch - start_epoch) <= 5
        if epoch % 5 == 0 or near_handoff or epoch == total_epochs:
            gamma, rho = compute_diagnostics(model, vector_net, Bs, test_loader, device,
                                             diag_mode)
            log['gamma'].append((epoch, gamma))
            log['rho'].append((epoch, rho))
        else:
            gamma, rho = None, None

        if epoch % 10 == 0 or near_handoff or epoch == total_epochs:
            g_str = f", G={gamma:.4f}, r={rho:.4f}" if gamma is not None else ""
            print(f" [{branch_name}] Ep {epoch}: acc={test_acc:.4f}{g_str}")

    return log


# =============================================================================
# Main
# =============================================================================
def acc_at_epoch(accs, t0, target_epoch):
    """Return the accuracy logged for ``target_epoch``, or NaN if not logged.

    ``accs[0]`` corresponds to epoch ``t0 + 1`` (the first epoch after the
    checkpoint the branch started from).
    """
    idx = target_epoch - t0 - 1
    if 0 <= idx < len(accs):
        return accs[idx]
    return float('nan')


def run_experiment(args):
    """Run the full handoff experiment described by ``args``.

    Trains (or reuses) the DFA baseline checkpoints, fits a credit network on
    each requested checkpoint, continues training along every branch in
    ``args.branches``, prints a summary and writes the per-branch logs to
    ``<output_dir>/handoff_s<seed>.json``.
    """
    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    ckpt_dir = os.path.join(args.output_dir, f'dfa_ckpts_s{args.seed}')

    def ckpt_path(epoch):
        return os.path.join(ckpt_dir, f'dfa_epoch_{epoch}.pt')

    # =========================================================
    # Step 1: Train DFA baseline with checkpoints
    # =========================================================
    print(f"\n{'='*60}")
    print(f"Step 1: Train DFA baseline with checkpoints")
    print(f"{'='*60}")

    needed_epochs = list(args.checkpoint_epochs) + [args.epochs]
    if not all(os.path.exists(ckpt_path(e)) for e in needed_epochs):
        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)
        model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
        Bs, dfa_final_acc = train_dfa_with_checkpoints(
            model_dfa, train_loader, test_loader, device,
            epochs=args.epochs, save_epochs=needed_epochs,
            ckpt_dir=ckpt_dir, lr=args.lr, wd=args.wd)
        print(f" DFA final acc: {dfa_final_acc:.4f}")
    else:
        print(f" All DFA checkpoints exist in {ckpt_dir}")
        final_ckpt = torch.load(ckpt_path(args.epochs), map_location=device)
        dfa_final_acc = final_ckpt['acc']
        Bs = [B.to(device) for B in final_ckpt['Bs']]
        print(f" DFA final acc: {dfa_final_acc:.4f}")

    # =========================================================
    # Step 2 & 3: For each checkpoint, offline-fit Vec then branch
    # =========================================================
    all_results = {}

    for t0 in args.checkpoint_epochs:
        print(f"\n{'='*60}")
        print(f"Checkpoint t0={t0}")
        print(f"{'='*60}")

        # Load checkpoint
        ckpt = torch.load(ckpt_path(t0), map_location=device)
        ckpt_Bs = [B.to(device) for B in ckpt['Bs']]
        print(f" DFA acc at t0={t0}: {ckpt['acc']:.4f}")

        # Offline-fit Vec on this checkpoint
        print(f" Offline-fitting Vec on t0={t0}...")
        model_frozen = ResidualMLP(input_dim, d, 10, L).to(device)
        model_frozen.load_state_dict(ckpt['model'])
        model_frozen.eval()
        for p in model_frozen.parameters():
            p.requires_grad_(False)

        torch.manual_seed(args.seed + t0 * 1000 + 4000)
        vec_net = offline_fit_vec(model_frozen, train_loader, device,
                                  epochs=args.vec_fit_epochs, lr_fb=args.lr_fb, M=args.M)

        # Evaluate Vec quality on this checkpoint
        gamma_frozen, rho_frozen = compute_diagnostics(
            model_frozen, vec_net, ckpt_Bs, test_loader, device, 'vec')
        print(f" Vec quality at t0={t0}: Gamma={gamma_frozen:.4f}, rho={rho_frozen:.4f}")
        for p in model_frozen.parameters():
            p.requires_grad_(True)

        # Branch training
        for branch_name, credit_mode in args.branches:
            print(f"\n --- Branch: {branch_name} (from t0={t0}) ---")

            # Fresh copy of model at checkpoint
            model_branch = ResidualMLP(input_dim, d, 10, L).to(device)
            model_branch.load_state_dict(ckpt['model'])

            # Fresh copy of Vec (from offline-fitted state)
            vec_branch = copy.deepcopy(vec_net)

            log = continue_training(
                model_branch, vec_branch, ckpt_Bs, train_loader, test_loader, device,
                start_epoch=t0, total_epochs=args.epochs,
                credit_mode=credit_mode, lr=args.lr, lr_fb=args.lr_fb, wd=args.wd,
                M=args.M, branch_name=branch_name)

            key = f"t0={t0}_{branch_name}"
            all_results[key] = {
                't0': t0, 'branch': branch_name, 'credit_mode': str(credit_mode),
                'acc_t0': float(ckpt['acc']),
                'vec_gamma_frozen': gamma_frozen, 'vec_rho_frozen': rho_frozen,
                'test_acc': log['test_acc'],
                'train_loss': log['train_loss'],
                'gamma': log['gamma'],
                'rho': log['rho'],
            }

    # =========================================================
    # Summary
    # =========================================================
    print(f"\n{'='*100}")
    print("SUMMARY")
    print(f"{'='*100}")
    print(f"{'Key':<35} {'acc@t0':>7} {'acc@20':>7} {'acc@50':>7} {'final':>7} "
          f"{'mGamma':>8} {'mRho':>7}")
    print("-" * 85)

    # Add DFA baseline (its final accuracy is already known from Step 1).
    print(f"{'DFA_full_baseline':<35} {'':>7} {'':>7} {'':>7} {dfa_final_acc:>7.4f} {'':>8} {'':>7}")

    for key, r in all_results.items():
        accs = r['test_acc']
        acc_t0 = r['acc_t0']
        acc_20 = acc_at_epoch(accs, r['t0'], 20)
        acc_50 = acc_at_epoch(accs, r['t0'], 50)
        final = accs[-1] if accs else float('nan')

        gammas = [g for _, g in r['gamma']]
        rhos = [rh for _, rh in r['rho']]
        mg = np.mean(gammas) if gammas else float('nan')
        mr = np.mean(rhos) if rhos else float('nan')

        print(f"{key:<35} {acc_t0:>7.4f} {acc_20:>7.4f} {acc_50:>7.4f} {final:>7.4f} "
              f"{mg:>8.4f} {mr:>7.4f}")

    # Save
    save_data = {key: dict(r) for key, r in all_results.items()}
    save_data['dfa_final_acc'] = float(dfa_final_acc)

    out_path = os.path.join(args.output_dir, f'handoff_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # Judgment
    print(f"\n{'='*60}")
    print("JUDGMENT")
    print(f"{'='*60}")

    for t0 in args.checkpoint_epochs:
        dfa_key = f"t0={t0}_continue_DFA"
        # Nothing to compare against if the reference branch is missing or
        # logged no epochs (e.g. t0 >= total epochs).
        if dfa_key not in all_results or not all_results[dfa_key]['test_acc']:
            continue
        dfa_final = all_results[dfa_key]['test_acc'][-1]

        for key, r in all_results.items():
            if r['t0'] != t0 or r['branch'] == 'continue_DFA' or not r['test_acc']:
                continue
            branch_final = r['test_acc'][-1]
            diff = branch_final - dfa_final
            print(f" t0={t0}: {r['branch']} final={branch_final:.4f} vs continue_DFA={dfa_final:.4f} "
                  f"(diff={diff:+.4f})")
            if diff > 0.01:
                print(f" -> {r['branch']} OUTPERFORMS continue_DFA!")
            elif diff > -0.01:
                print(f" -> Similar to continue_DFA")
            else:
                print(f" -> Worse than continue_DFA")


def parse_branch_specs(specs):
    """Parse ``'name:mode'`` strings into ``(name, mode)`` tuples.

    ``mode`` stays the string ``'dfa'`` or ``'vec'``; anything else must parse
    as a float blend weight.

    Raises:
        ValueError: If a spec has no ``name:mode`` form or an unknown mode.
    """
    branches = []
    for spec in specs:
        name, sep, mode = spec.partition(':')
        if not sep or not name or not mode:
            raise ValueError(f"Invalid branch spec {spec!r}; expected 'name:mode'")
        if mode not in ('dfa', 'vec'):
            try:
                mode = float(mode)
            except ValueError:
                raise ValueError(
                    f"Invalid mode in branch spec {spec!r}; expected 'dfa', 'vec' or a float"
                ) from None
        branches.append((name, mode))
    return branches


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser(description='Phase 9A: Checkpointed Offline Handoff')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--M', type=int, default=4)
    parser.add_argument('--vec_fit_epochs', type=int, default=60)
    parser.add_argument('--checkpoint_epochs', type=int, nargs='+', default=[5])
    parser.add_argument('--branch_spec', type=str, nargs='+',
                        default=['continue_DFA:dfa', 'handoff_to_Vec:vec', 'handoff_blend_05:0.5'])
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/checkpointed_handoff')
    args = parser.parse_args()

    # Parse branch specs
    try:
        args.branches = parse_branch_specs(args.branch_spec)
    except ValueError as err:
        parser.error(str(err))

    run_experiment(args)


if __name__ == '__main__':
    main()
