"""
Phase 10A.8A: Freeze with Alpha Decay.

Core question: After freezing Vec, can linearly decaying alpha (fading out the
frozen Vec and returning to pure DFA) recover or improve over a fixed-alpha
frozen blend?

8 branches from the same DFA checkpoint at t0=5:
  1. continue_DFA                    — pure DFA baseline
  2. blend_random_trainable_alpha075 — standard reference (always trainable, alpha=0.75)
  3. freeze_after_1_fixed075         — train Vec 1 epoch, freeze, keep alpha=0.75
  4. freeze_after_5_fixed075         — train Vec 5 epochs, freeze, keep alpha=0.75
  5. freeze_after_1_decay_to_025     — train Vec 1 epoch, freeze, then decay alpha 0.75->0.25 over 5 epochs
  6. freeze_after_5_decay_to_025     — train Vec 5 epochs, freeze, then decay alpha 0.75->0.25 over 5 epochs
  7. freeze_after_1_decay_to_000     — train Vec 1 epoch, freeze, then decay alpha 0.75->0.0 over 5 epochs
  8. freeze_after_5_decay_to_000     — train Vec 5 epochs, freeze, then decay alpha 0.75->0.0 over 5 epochs
"""
import argparse
import copy
import json
import os
import sys

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP  # noqa: E402
from models.value_net import SinusoidalTimeEmbed  # noqa: E402
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation  # noqa: E402


# ---------------------------------------------------------------------------
# Shared numeric / optimisation helpers
# ---------------------------------------------------------------------------

def _rms(a):
    """Return the root-mean-square of ``a`` over its last dim (kept), plus 1e-6 to avoid /0."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6


def _make_optimizers(model, lr, wd, total_epochs):
    """Build one AdamW per block, one for the embedding, one for head (out_head + out_ln).

    Returns:
        (block_opts, embed_opt, head_opt, scheds) where ``scheds`` holds a
        CosineAnnealingLR (T_max=total_epochs) per optimizer: blocks first,
        then embedding, then head.
    """
    block_opts = [optim.AdamW(b.parameters(), lr=lr, weight_decay=wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=lr, weight_decay=wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=lr, weight_decay=wd)
    scheds = ([optim.lr_scheduler.CosineAnnealingLR(o, T_max=total_epochs) for o in block_opts]
              + [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=total_epochs),
                 optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=total_epochs)])
    return block_opts, embed_opt, head_opt, scheds


def _update_head(model, hL, y, head_opt):
    """One cross-entropy step on the output head, fed the detached final hidden state ``hL``."""
    loss = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
    head_opt.zero_grad()
    loss.backward()
    head_opt.step()


def _update_blocks_and_embed(model, x, hi, credits, block_opts, embed_opt):
    """Apply the local surrogate update to every block and to the embedding.

    Block ``l`` minimises ``<block_l(hi[l]), credits[l] / rms(credits[l])>`` on
    its detached input (with grad-norm clipping at 1.0). The embedding uses the
    block-0 credit the same way (no clipping). ``credits`` must already be
    detached from any graph.
    """
    for l, a in enumerate(credits):
        f = model.blocks[l](hi[l].detach())
        ll = (f * (a / _rms(a))).sum(-1).mean()
        block_opts[l].zero_grad()
        ll.backward()
        torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
        block_opts[l].step()
    a0 = credits[0]
    el = (model.embed(x) * (a0 / _rms(a0))).sum(-1).mean()
    embed_opt.zero_grad()
    el.backward()
    embed_opt.step()


# ---------------------------------------------------------------------------
# Auxiliary network
# ---------------------------------------------------------------------------

class VectorCreditNet(nn.Module):
    """Standard Vec: takes (h, t, s) -> d_hidden credit vector.

    Args:
        d_hidden: width of the hidden state ``h`` and of the returned credit.
        s_dim: width of the error signal ``s`` concatenated to the input.
        time_embed_dim: size passed to ``SinusoidalTimeEmbed`` (its output
            width is assumed to equal this — it is used that way below).
        hidden_dim: width of each hidden Linear layer.
        num_layers: number of Linear+GELU layers before the output projection.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        # Only h is LayerNorm-ed; the time embedding and error signal enter as-is.
        return self.net(torch.cat([self.ln(h), self.time_embed(t), s], dim=-1))


# ---------------------------------------------------------------------------
# Alpha schedule helpers
# ---------------------------------------------------------------------------

def make_alpha_schedule(freeze_epoch, initial_alpha, target_alpha, decay_window):
    """
    Returns a function alpha_fn(epoch, t0) -> current alpha.

    Before freeze_epoch training epochs have passed, alpha = initial_alpha.
    After freeze_epoch training epochs, linearly decay from initial_alpha to
    target_alpha over decay_window epochs (target reached on the
    decay_window-th epoch after freezing), then stay at target_alpha.
    A non-positive decay_window jumps straight to target_alpha.

    epoch is the absolute epoch number; t0 is the DFA checkpoint epoch.
    Training epochs elapsed since handoff = epoch - t0 (1-indexed).
    """
    def alpha_fn(epoch, t0):
        elapsed = epoch - t0  # epochs since handoff (1-indexed)
        if elapsed <= freeze_epoch:
            return initial_alpha
        after_freeze = elapsed - freeze_epoch
        if decay_window <= 0 or target_alpha == initial_alpha:
            return target_alpha
        progress = min(after_freeze / decay_window, 1.0)
        return initial_alpha + (target_alpha - initial_alpha) * progress
    return alpha_fn


# ---------------------------------------------------------------------------
# Data
# ---------------------------------------------------------------------------

def get_cifar10(batch_size=128):
    """Return (train_loader, test_loader) for CIFAR-10 under ./data (downloaded if missing).

    Training data uses random crop (pad 4) + horizontal flip; both splits are
    normalised with the usual CIFAR-10 channel statistics. Only the train
    loader shuffles.
    """
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    return (DataLoader(trainset, batch_size=batch_size, shuffle=True,
                       num_workers=4, pin_memory=True),
            DataLoader(testset, batch_size=batch_size, shuffle=False,
                       num_workers=4, pin_memory=True))


# ---------------------------------------------------------------------------
# Evaluation helpers
# ---------------------------------------------------------------------------

def evaluate(model, test_loader, device):
    """Return top-1 accuracy of ``model`` on ``test_loader`` (inputs flattened per sample).

    Leaves ``model`` in eval mode.
    """
    model.eval()
    correct, total = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            correct += (model(x).argmax(1) == y).sum().item()
            total += x.size(0)
    return correct / total


def compute_diagnostics(model, aux_net, Bs, test_loader, device, credit_mode, alpha=0.75):
    """Compute mean Gamma (BP cosine) and mean rho (perturbation correlation).

    Uses only the first batch of ``test_loader``.

    Args:
        model: network with ``num_blocks``, ``blocks``, ``out_ln``, ``out_head``
            and ``forward(x, return_hidden=True) -> (logits, hidden_list)``
            where ``hidden_list`` has ``num_blocks + 1`` entries.
        aux_net: Vec used when ``credit_mode == 'blend'``; may be None.
        Bs: per-block DFA feedback matrices.
        credit_mode: 'dfa' or 'blend'; anything else falls back to DFA credit.
        alpha: weight of the RMS-normalised Vec credit in the blend.

    Returns:
        (mean_gamma, mean_rho) as Python floats, averaged over blocks.

    Raises:
        ValueError: if ``test_loader`` yields no batches.

    Side effects: puts ``model`` (and ``aux_net``) in eval mode and leaves the
    BP gradients in the model parameters' ``.grad``.
    """
    model.eval()
    if aux_net is not None:
        aux_net.eval()
    L = model.num_blocks

    first = next(iter(test_loader), None)
    if first is None:
        raise ValueError("compute_diagnostics: test_loader yielded no batches")
    x, y = first
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)

    # BP pass for hidden gradients (offline eval only, not used for training).
    # Temporarily re-enable grads if the model parameters were frozen.
    was_frozen = not next(model.parameters()).requires_grad
    if was_frozen:
        for p in model.parameters():
            p.requires_grad_(True)
    model.zero_grad()
    lo, hbp = model(x, return_hidden=True)
    for l in range(L + 1):
        hbp[l].retain_grad()
    F.cross_entropy(lo, y).backward()
    bp = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
    if was_frozen:
        for p in model.parameters():
            p.requires_grad_(False)

    with torch.no_grad():
        lo2, hi = model(x, return_hidden=True)
        eT = lo2.softmax(-1)
        eT[torch.arange(batch), y] -= 1  # softmax - onehot: dCE/dlogits per sample
        s = eT.detach()

    def make_fwd(sl):
        """Per-sample CE loss as a function of the hidden state entering block ``sl``."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(sl, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(
                    model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    gammas, rhos = [], []
    for l in range(L):
        h_l = hi[l].detach()
        t_l = torch.full((batch,), l / L, device=device)
        a_dfa = (s @ Bs[l].T).detach()
        if credit_mode == 'blend' and aux_net is not None:
            a_aux = aux_net(h_l, t_l, s).detach()
            a_l = alpha * a_aux / _rms(a_aux) + (1 - alpha) * a_dfa / _rms(a_dfa)
        else:
            a_l = a_dfa
        gammas.append(cosine_similarity_batch(a_l, bp[l]))
        rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))

    return float(np.mean(gammas)), float(np.mean(rhos))


# ---------------------------------------------------------------------------
# DFA training + checkpoint
# ---------------------------------------------------------------------------

def train_dfa_get_checkpoint(model, train_loader, test_loader, device, total_epochs, t0, lr, wd):
    """Train ``model`` with pure DFA for ``total_epochs`` epochs and snapshot it at epoch ``t0``.

    Feedback matrices are fixed random ``(d_hidden, 10)`` Gaussians scaled by
    1/sqrt(10) (10 output classes are assumed here).

    Returns:
        (Bs, ckpt) where ``ckpt`` is ``{'model': state_dict copy, 'Bs': cloned
        feedback matrices, 'acc': test accuracy}`` taken at the end of epoch
        ``t0``, or None if ``t0`` is not in ``[1, total_epochs]``.
    """
    d = model.d_hidden
    L = model.num_blocks
    Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
    block_opts, embed_opt, head_opt, scheds = _make_optimizers(model, lr, wd, total_epochs)

    ckpt = None
    for epoch in range(1, total_epochs + 1):
        model.train()
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                eT = lo.softmax(-1)
                eT[torch.arange(x.size(0)), y] -= 1
            _update_head(model, hi[-1].detach(), y, head_opt)
            credits = [(eT @ B.T).detach() for B in Bs]
            _update_blocks_and_embed(model, x, hi, credits, block_opts, embed_opt)
        for sch in scheds:
            sch.step()

        if epoch == t0:
            acc = evaluate(model, test_loader, device)
            ckpt = {'model': copy.deepcopy(model.state_dict()),
                    'Bs': [B.clone() for B in Bs], 'acc': acc}
            print(f" [DFA] Checkpoint at epoch {t0}: acc={acc:.4f}")
        if epoch % 10 == 0:
            print(f" [DFA] Epoch {epoch}: acc={evaluate(model, test_loader, device):.4f}")
    return Bs, ckpt


# ---------------------------------------------------------------------------
# Branch runner
# ---------------------------------------------------------------------------

def _train_vec_step(model, aux_net, vec_opt, hi, hL, s, y, L, M, eps_pert, device):
    """One Vec update from the standard targets.

    Loss = squared error between Vec's output at the final layer and the exact
    dCE/dh_L, plus squared error between Vec's directional derivative at a
    uniformly sampled layer and a central finite difference along ``M`` random
    unit directions (step ``eps_pert``).
    """
    batch = hL.size(0)
    t_L = torch.ones(batch, device=device)
    a_term = aux_net(hL, t_L, s)
    hL_req = hL.clone().requires_grad_(True)
    ce = F.cross_entropy(model.out_head(model.out_ln(hL_req)), y, reduction='sum')
    dL = torch.autograd.grad(ce, hL_req)[0].detach()
    loss_term = ((a_term - dL) ** 2).sum(-1).mean()

    lt = np.random.randint(0, L)
    h_l = hi[lt].detach()
    t_l = torch.full((batch,), lt / L, device=device)
    a_l = aux_net(h_l, t_l, s)
    lp2 = torch.tensor(0.0, device=device)
    for _ in range(M):
        v = torch.randn_like(h_l)
        v = v / (v.norm(-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            lp = F.cross_entropy(
                model.forward_from_layer(h_l + eps_pert * v, lt), y, reduction='none')
            lm = F.cross_entropy(
                model.forward_from_layer(h_l - eps_pert * v, lt), y, reduction='none')
        gj = (lp - lm) / (2 * eps_pert)
        lp2 = lp2 + (((a_l * v).sum(-1) - gj.detach()) ** 2).mean()
    lp2 = lp2 / M

    vl = loss_term + lp2
    vec_opt.zero_grad()
    vl.backward()
    torch.nn.utils.clip_grad_norm_(aux_net.parameters(), 1.0)
    vec_opt.step()


def _compute_credits(aux_net, Bs, s, hi, alpha, use_aux, L, device):
    """Build per-block credits and their per-source norms for alpha_eff logging.

    With ``use_aux`` False the credit is the RMS-normalised DFA signal.
    Otherwise it is ``alpha * vec / rms(vec) + (1 - alpha) * dfa / rms(dfa)``.

    Returns:
        (credits, aux_norms, dfa_norms) — one entry per block.
    """
    batch = s.size(0)
    credits, aux_norms, dfa_norms = [], [], []
    for l in range(L):
        a_dfa = (s @ Bs[l].T).detach()
        rms_d = _rms(a_dfa)
        if not use_aux:
            dfa_unit = a_dfa / rms_d
            credits.append(dfa_unit)
            aux_norms.append(0.0)
            dfa_norms.append(dfa_unit.norm().item())
            continue
        t_l = torch.full((batch,), l / L, device=device)
        with torch.no_grad():
            a_aux = aux_net(hi[l].detach(), t_l, s).detach()
        aux_part = alpha * a_aux / _rms(a_aux)
        dfa_part = (1 - alpha) * a_dfa / rms_d
        credits.append(aux_part + dfa_part)
        aux_norms.append(aux_part.norm().item())
        dfa_norms.append(dfa_part.norm().item())
    return credits, aux_norms, dfa_norms


def run_branch(model, aux_net, Bs, train_loader, test_loader, device,
               t0, total_epochs, branch_type, alpha_schedule_fn,
               lr, lr_fb, wd, M, branch_name='', freeze_epoch=None):
    """
    Run a training branch from a loaded checkpoint.

    branch_type:
      'dfa'          — pure DFA, no aux
      'blend'        — blend DFA + Vec; aux_net trained online if vec_opt active
      'blend_frozen' — blend DFA + frozen Vec; Vec trained for freeze_epoch epochs then frozen

    alpha_schedule_fn(epoch, t0) -> float: returns alpha at each absolute epoch.
    freeze_epoch: int — for 'blend_frozen', number of epochs to train Vec before freezing.

    Fresh optimizers are built; the cosine schedulers are fast-forwarded t0
    steps so the learning rate matches the DFA run at the handoff epoch.

    Returns:
        dict with 'test_acc' and 'train_loss' (one entry per epoch),
        'alpha_eff' as (epoch, value) pairs per epoch, and 'gamma'/'rho' as
        (epoch, value) pairs on diagnostic epochs only.
    """
    L = model.num_blocks
    eps_pert = 1e-3
    block_opts, embed_opt, head_opt, scheds = _make_optimizers(model, lr, wd, total_epochs)
    if branch_type != 'dfa' and aux_net is not None:
        vec_opt = optim.Adam(aux_net.parameters(), lr=lr_fb)
    else:
        vec_opt = None

    # Advance schedulers to match checkpoint epoch.
    for _ in range(t0):
        for sch in scheds:
            sch.step()

    log = {'test_acc': [], 'train_loss': [], 'gamma': [], 'rho': [], 'alpha_eff': []}
    diag_epochs = set(
        list(range(t0 + 1, min(t0 + 6, total_epochs + 1)))
        + [t0 + 8, t0 + 10, t0 + 15, t0 + 20]
        + list(range(t0 + 10, total_epochs + 1, 10))
        + [total_epochs])

    vec_frozen = False
    for epoch in range(t0 + 1, total_epochs + 1):
        # Freeze Vec once freeze_epoch training epochs since handoff are done.
        if (branch_type == 'blend_frozen' and freeze_epoch is not None
                and not vec_frozen and epoch - t0 > freeze_epoch):
            if aux_net is not None:
                aux_net.requires_grad_(False)
                aux_net.eval()
            vec_opt = None
            vec_frozen = True
            print(f" [{branch_name}] Freezing Vec at epoch {epoch} "
                  f"(after {freeze_epoch} training epochs)")

        cur_alpha = alpha_schedule_fn(epoch, t0)
        use_aux = branch_type != 'dfa' and aux_net is not None and cur_alpha != 0.0

        model.train()
        if aux_net is not None:
            if vec_opt is not None:
                aux_net.train()
            else:
                aux_net.eval()

        loss_sum, seen = 0, 0
        epoch_aux_norms, epoch_dfa_norms = [], []
        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)
            with torch.no_grad():
                lo, hi = model(x, return_hidden=True)
                lv = F.cross_entropy(lo, y)
                eT = lo.softmax(-1)
                eT[torch.arange(batch), y] -= 1
                s = eT.detach()
            hL = hi[-1].detach()

            if vec_opt is not None and aux_net is not None:
                _train_vec_step(model, aux_net, vec_opt, hi, hL, s, y, L, M, eps_pert, device)

            credits, aux_n, dfa_n = _compute_credits(
                aux_net, Bs, s, hi, cur_alpha, use_aux, L, device)
            epoch_aux_norms.extend(aux_n)
            epoch_dfa_norms.extend(dfa_n)

            _update_head(model, hL, y, head_opt)
            _update_blocks_and_embed(model, x, hi, credits, block_opts, embed_opt)

            loss_sum += lv.item() * batch
            seen += batch
        for sch in scheds:
            sch.step()

        ta = evaluate(model, test_loader, device)
        log['test_acc'].append(ta)
        log['train_loss'].append(loss_sum / seen)
        mean_aux = np.mean(epoch_aux_norms) if epoch_aux_norms else 0.0
        mean_dfa = np.mean(epoch_dfa_norms) if epoch_dfa_norms else 1.0
        # Effective share of the Vec term in the applied credit magnitude.
        aeff = mean_aux / (mean_aux + mean_dfa + 1e-12)
        log['alpha_eff'].append((epoch, aeff))

        frozen_str = ' [FROZEN]' if vec_frozen else ''
        if epoch in diag_epochs:
            cm = 'blend' if (branch_type != 'dfa' and aux_net is not None
                             and cur_alpha > 0.0) else 'dfa'
            diag_aux = aux_net if cm == 'blend' else None
            gamma, rho = compute_diagnostics(
                model, diag_aux, Bs, test_loader, device, cm, cur_alpha)
            log['gamma'].append((epoch, gamma))
            log['rho'].append((epoch, rho))
            if epoch <= t0 + 15 or epoch % 20 == 0 or epoch == total_epochs:
                print(f" [{branch_name}]{frozen_str} Ep {epoch}: acc={ta:.4f}, "
                      f"G={gamma:.4f}, r={rho:.4f}, aeff={aeff:.3f}, alpha={cur_alpha:.3f}")
        elif epoch % 10 == 0 or epoch == total_epochs:
            print(f" [{branch_name}]{frozen_str} Ep {epoch}: acc={ta:.4f}, "
                  f"alpha={cur_alpha:.3f}")
    return log


# ---------------------------------------------------------------------------
# Reporting helpers
# ---------------------------------------------------------------------------

def _print_summary(all_results, t0):
    """Print the per-branch table: acc at epoch 20, final acc, diff vs DFA, early-window means."""
    dfa_final = all_results['continue_DFA']['test_acc'][-1]
    print(f"\n{'='*95}")
    print("SUMMARY — Phase 10A.8A: Freeze with Alpha Decay")
    print(f"{'='*95}")
    print(f"{'Branch':<38} {'@20':>6} {'final':>7} {'diff':>7} "
          f"{'mG_5:15':>9} {'mr_5:15':>9} {'aeff':>7}")
    print("-" * 85)

    def window_mean(pairs):
        # Mean over the first 15 post-handoff epochs; NaN if none were logged.
        vals = [v for e, v in pairs if t0 < e <= t0 + 15]
        return float(np.mean(vals)) if vals else float('nan')

    for bname, log in all_results.items():
        accs = log['test_acc']
        idx20 = max(0, 20 - t0 - 1)  # accs[0] is epoch t0 + 1
        acc20 = accs[idx20] if len(accs) > idx20 else accs[-1]
        final = accs[-1]
        diff = final - dfa_final
        mg = window_mean(log['gamma'])
        mr = window_mean(log['rho'])
        mae = window_mean(log['alpha_eff'])
        print(f"{bname:<38} {acc20:>6.4f} {final:>7.4f} {diff:>+7.4f} "
              f"{mg:>9.4f} {mr:>9.4f} {mae:>7.3f}")


def _save_results(args, ckpt, all_results):
    """Write args, checkpoint accuracy and every branch log to a JSON file in args.output_dir."""
    save_data = {'args': vars(args), 'dfa_ckpt_acc': float(ckpt['acc'])}
    for bname, log in all_results.items():
        save_data[bname] = {key: log[key] for key in
                            ('test_acc', 'train_loss', 'gamma', 'rho', 'alpha_eff')}
    out_path = os.path.join(args.output_dir,
                            f'freeze_with_decay_t{args.t0}_s{args.seed}.json')
    with open(out_path, 'w') as f:
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")


def _print_judgment(all_results):
    """Print textual comparisons of final accuracies between branches (threshold 0.003)."""
    print(f"\n{'='*60}\nJUDGMENT\n{'='*60}")
    r = {bname: log['test_acc'][-1] for bname, log in all_results.items()}
    nan = float('nan')
    dfa = r['continue_DFA']
    ref = r.get('blend_random_trainable_alpha075', nan)
    f1 = r.get('freeze_after_1_fixed075', nan)
    f5 = r.get('freeze_after_5_fixed075', nan)
    f1d25 = r.get('freeze_after_1_decay_to_025', nan)
    f5d25 = r.get('freeze_after_5_decay_to_025', nan)
    f1d00 = r.get('freeze_after_1_decay_to_000', nan)
    f5d00 = r.get('freeze_after_5_decay_to_000', nan)

    print(f" DFA={dfa:.4f} ref={ref:.4f}")
    print(f" freeze1_fixed={f1:.4f} freeze5_fixed={f5:.4f}")
    print(f" freeze1_to025={f1d25:.4f} freeze5_to025={f5d25:.4f}")
    print(f" freeze1_to000={f1d00:.4f} freeze5_to000={f5d00:.4f}")

    thr = 0.003

    # Fixed vs trainable reference
    best_fixed = max(f1, f5)
    if best_fixed > ref - thr:
        print(f"\n -> Best frozen-fixed ({best_fixed:.4f}) ≈ trainable reference: "
              "freezing early is sufficient; ongoing Vec training adds no value")
    elif ref > best_fixed + thr:
        print(f"\n -> Trainable reference ({ref:.4f}) > best frozen-fixed ({best_fixed:.4f}): "
              "continuous Vec adaptation helps")

    # Effect of more training before freeze
    if f5 > f1 + thr:
        print(f" -> More Vec training before freeze helps: "
              f"5ep ({f5:.4f}) > 1ep ({f1:.4f})")
    else:
        print(f" -> Freeze timing (1 vs 5 epochs) makes little difference: "
              f"f1={f1:.4f} f5={f5:.4f}")

    # Effect of decay on fixed-freeze branches
    print(f"\n Decay effect (vs fixed075):")
    for label, fixed_v, d25, d00 in [
            ('freeze_after_1', f1, f1d25, f1d00),
            ('freeze_after_5', f5, f5d25, f5d00)]:
        print(f" {label}: fixed={fixed_v:.4f} ->0.25={d25:.4f} ->0.00={d00:.4f}")
        if d25 > fixed_v + thr:
            print(f" -> decay to 0.25 helps vs fixed ({d25-fixed_v:+.4f})")
        if d00 > fixed_v + thr:
            print(f" -> decay to 0.0 (full DFA) helps vs fixed ({d00-fixed_v:+.4f})")
        if d00 > d25 + thr:
            print(f" -> faster decay (to 0) better than partial ({d00-d25:+.4f})")
        elif d25 > d00 + thr:
            print(f" -> partial decay (to 0.25) better than full decay ({d25-d00:+.4f})")

    # Overall winner. The ':+' spec already prints the sign, so no literal '+'.
    best_name = max(r, key=r.get)
    print(f"\n Best branch: {best_name} = {r[best_name]:.4f} "
          f"({r[best_name]-dfa:+.4f} vs DFA)")


# ---------------------------------------------------------------------------
# Main experiment
# ---------------------------------------------------------------------------

def run_experiment(args):
    """Train the DFA baseline, fork every branch from its t0 checkpoint, then report and save.

    Raises:
        ValueError: if ``args.t0`` is not in ``[1, args.epochs)``. Checked
            before any training so a bad value fails fast rather than after
            the full DFA run (no checkpoint / empty branch logs).
    """
    if not 1 <= args.t0 < args.epochs:
        raise ValueError(
            f"--t0 must satisfy 1 <= t0 < epochs (got t0={args.t0}, epochs={args.epochs})")

    device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Step 1: Train DFA and capture checkpoint at t0
    print(f"\n{'='*60}\nTraining DFA baseline (checkpoint at t0={args.t0})\n{'='*60}")
    model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
    Bs, ckpt = train_dfa_get_checkpoint(
        model_dfa, train_loader, test_loader, device,
        args.epochs, args.t0, args.lr, args.wd)
    print(f" Checkpoint acc at t0={args.t0}: {ckpt['acc']:.4f}")

    # Step 2: Define branches
    VEC_SEED = args.seed + 7777
    DECAY_WINDOW = 5

    def make_vec():
        # Re-seeding makes every Vec branch start from the same random Vec.
        torch.manual_seed(VEC_SEED)
        return VectorCreditNet(d_hidden=d, s_dim=10).to(device)

    def fixed_alpha(a):
        return lambda epoch, t0: a

    def decay(freeze_ep, target):
        return make_alpha_schedule(freeze_epoch=freeze_ep, initial_alpha=0.75,
                                   target_alpha=target, decay_window=DECAY_WINDOW)

    # (name, branch_type, aux_factory, freeze_epoch, alpha_schedule_fn)
    branches = [
        ('continue_DFA', 'dfa', lambda: None, None, fixed_alpha(0.0)),
        ('blend_random_trainable_alpha075', 'blend', make_vec, None, fixed_alpha(0.75)),
        ('freeze_after_1_fixed075', 'blend_frozen', make_vec, 1, fixed_alpha(0.75)),
        ('freeze_after_5_fixed075', 'blend_frozen', make_vec, 5, fixed_alpha(0.75)),
        ('freeze_after_1_decay_to_025', 'blend_frozen', make_vec, 1, decay(1, 0.25)),
        ('freeze_after_5_decay_to_025', 'blend_frozen', make_vec, 5, decay(5, 0.25)),
        ('freeze_after_1_decay_to_000', 'blend_frozen', make_vec, 1, decay(1, 0.0)),
        ('freeze_after_5_decay_to_000', 'blend_frozen', make_vec, 5, decay(5, 0.0)),
    ]

    # Step 3: Run all branches
    all_results = {}
    for bname, btype, aux_factory, freeze_ep, alpha_fn in branches:
        print(f"\n{'='*60}\n{bname}\n{'='*60}")
        model_b = ResidualMLP(input_dim, d, 10, L).to(device)
        model_b.load_state_dict(ckpt['model'])
        aux_net_b = aux_factory()
        log = run_branch(
            model_b, aux_net_b, ckpt['Bs'], train_loader, test_loader, device,
            args.t0, args.epochs, btype, alpha_fn,
            args.lr, args.lr_fb, args.wd, args.M,
            branch_name=bname, freeze_epoch=freeze_ep)
        all_results[bname] = log
        print(f" {bname} final acc: {log['test_acc'][-1]:.4f}")

    # Steps 4-6: Summary table, save, judgment
    _print_summary(all_results, args.t0)
    _save_results(args, ckpt, all_results)
    _print_judgment(all_results)


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser(
        description='Phase 10A.8A: Freeze with Alpha Decay')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--t0', type=int, default=5)
    # NOTE(review): --alpha is parsed and saved in the results JSON, but the
    # branch table in run_experiment hard-codes 0.75.
    parser.add_argument('--alpha', type=float, default=0.75,
                        help='currently unused: branch alphas are hard-coded to 0.75')
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--M', type=int, default=4)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=2)
    parser.add_argument('--output_dir', type=str, default='results/freeze_with_decay')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()