summaryrefslogtreecommitdiff
path: root/experiments/periodic_refit.py
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-03-26 00:07:01 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-03-26 00:07:01 -0500
commit05ccd23154d1e9d090178b9d4d5f2c821711e784 (patch)
tree0ce74409f506df9f7f1717c13b30e79eb5a24f12 /experiments/periodic_refit.py
parentccc6add69553893f6d3f9de4e2010ca8139ba1a6 (diff)
Add Phase 9B+9C: periodic refit fails, top-down curriculum neutral
Phase 9B (periodic refit K=5 R=1 alpha=0.75): 14.0% — Vec starts random, periodic refits insufficient without offline pretraining. Phase 9C (top-down curriculum): last1_vec=30.8%, last2_vec=31.1% vs DFA=31.2%. Near-neutral. Cold-start problem persists even for single-block Vec. Only Phase 9A's offline prefit + blend handoff (+1.5%) works. The key ingredient is offline Vec training on frozen checkpoint features. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'experiments/periodic_refit.py')
-rw-r--r--experiments/periodic_refit.py400
1 files changed, 400 insertions, 0 deletions
diff --git a/experiments/periodic_refit.py b/experiments/periodic_refit.py
new file mode 100644
index 0000000..a16d55f
--- /dev/null
+++ b/experiments/periodic_refit.py
@@ -0,0 +1,400 @@
+"""
+Phase 9B: Periodic Refit.
+
+Instead of continuous co-learning, alternate:
+1. Train forward net with current credit for K epochs
+2. Freeze forward, refit Vec estimator for R epoch-equivalents
+3. Resume
+
+Compare:
+- DFA backbone + periodic Vec refit + blend
+- DFA only (baseline)
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+def refit_vec(model, vec_net, train_loader, device, refit_epochs=1, lr_fb=1e-3, M=4):
+ """Freeze model, refit Vec estimator for refit_epochs."""
+ model.eval()
+ for p in model.parameters(): p.requires_grad_(False)
+ vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
+ eps = 1e-3
+ L = model.num_blocks
+ for ep in range(refit_epochs):
+ vec_net.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1].detach()
+ t_L = torch.ones(batch, device=device)
+ a_term = vec_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L)**2).sum(-1).mean()
+ l = np.random.randint(0, L)
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vec_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps*v, l), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps*v, l), y, reduction='none')
+ g_j = (lp - lm) / (2*eps)
+ loss_proj = loss_proj + (((a_l*v).sum(-1) - g_j.detach())**2).mean()
+ loss_proj /= M
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
+ vec_opt.step()
+ for p in model.parameters(): p.requires_grad_(True)
+
+
+def compute_diagnostics(model, vec_net, dfa_Bs, test_loader, device, credit_mode):
+ model.eval()
+ vec_net.eval()
+ L = model.num_blocks
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+ # BP grads
+ was_frozen = not next(model.parameters()).requires_grad
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hbp = model(x, return_hidden=True)
+ for l in range(L+1): hbp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ bp_grads = {l: hbp[l].grad.detach().clone() for l in range(L+1)}
+ if was_frozen:
+ for p in model.parameters(): p.requires_grad_(False)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1; s = e_T.detach()
+ gammas = []
+ for l in range(L):
+ h_l = hiddens[l].detach(); t_l = torch.full((batch,), l/L, device=device)
+ if credit_mode == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ else:
+ alpha = credit_mode
+ a_dfa = (s @ dfa_Bs[l].T).detach()
+ a_vec = vec_net(h_l, t_l, s).detach()
+ rms_v = (a_vec**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rms_d = (a_dfa**2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l = alpha * a_vec/rms_v + (1-alpha) * a_dfa/rms_d
+ gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
+ return float(np.mean(gammas))
+
+
+def train_with_periodic_refit(model, train_loader, test_loader, device, args,
+ K_forward, R_refit, blend_alpha):
+ """
+ Train with periodic Vec refit.
+ Every K_forward epochs of forward training, freeze and refit Vec for R_refit epochs.
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+
+ vec_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]
+
+ log = {'test_acc': [], 'train_loss': [], 'gamma': []}
+ epochs_since_refit = 0
+
+ # Initial refit
+ print(f" Initial Vec refit ({R_refit} epochs)...")
+ refit_vec(model, vec_net, train_loader, device, refit_epochs=R_refit, lr_fb=args.lr_fb, M=args.M)
+
+ for epoch in range(1, args.epochs + 1):
+ model.train(); vec_net.train()
+ total_loss, correct, total = 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch), y] -= 1; s = e_T.detach()
+ hL = hiddens[-1].detach()
+
+ # Also do online Vec training (in addition to periodic refit)
+ t_L = torch.ones(batch, device=device)
+ a_term = vec_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L)**2).sum(-1).mean()
+ l_train = np.random.randint(0, L)
+ h_l = hiddens[l_train].detach()
+ t_l = torch.full((batch,), l_train/L, device=device)
+ a_l = vec_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ eps = 1e-3
+ for _ in range(args.M):
+ v = torch.randn_like(h_l); v = v/(v.norm(-1,keepdim=True)+1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l+eps*v,l_train),y,reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l-eps*v,l_train),y,reduction='none')
+ g_j = (lp-lm)/(2*eps)
+ loss_proj = loss_proj + (((a_l*v).sum(-1)-g_j.detach())**2).mean()
+ loss_proj /= args.M
+ vloss = loss_term + loss_proj
+ vec_opt_step = optim.Adam(vec_net.parameters(), lr=args.lr_fb)
+ vec_opt_step.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
+ # Note: we create optimizer each step which is wasteful but simple
+ # In practice should keep a persistent optimizer
+
+ # Compute blended credits
+ with torch.no_grad():
+ vec_credits = [vec_net(hiddens[l].detach(), torch.full((batch,),l/L,device=device), s).detach() for l in range(L)]
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ credits = []
+ for l in range(L):
+ rms_v = (vec_credits[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ rms_d = (dfa_credits[l]**2).mean(-1,keepdim=True).sqrt()+1e-6
+ credits.append(blend_alpha*vec_credits[l]/rms_v + (1-blend_alpha)*dfa_credits[l]/rms_d)
+
+ # Update head
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+
+ # Update blocks
+ for l in range(L):
+ a = credits[l]; rms = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ f = model.blocks[l](hiddens[l].detach())
+ ll = (f*(a/rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ a0 = credits[0]; rms0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x)*(a0/rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+
+ total_loss += loss_val.item()*batch; correct += (logits.argmax(1)==y).sum().item(); total += batch
+
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['test_acc'].append(test_acc)
+ log['train_loss'].append(total_loss/total)
+
+ epochs_since_refit += 1
+
+ # Periodic refit
+ if epochs_since_refit >= K_forward:
+ print(f" Refitting Vec at epoch {epoch} ({R_refit} epochs)...")
+ refit_vec(model, vec_net, train_loader, device, refit_epochs=R_refit, lr_fb=args.lr_fb, M=args.M)
+ epochs_since_refit = 0
+
+ if epoch % 10 == 0 or epoch <= 5 or epoch == args.epochs:
+ gamma = compute_diagnostics(model, vec_net, Bs, test_loader, device, blend_alpha)
+ log['gamma'].append((epoch, gamma))
+ print(f" [refit_K{K_forward}_R{R_refit}_a{blend_alpha}] Ep {epoch}: acc={test_acc:.4f}, G={gamma:.4f}")
+
+ return log
+
+
+def train_dfa_only(model, train_loader, test_loader, device, args):
+ d = model.d_hidden; L = model.num_blocks
+ Bs = [torch.randn(d, 10, device=device)/np.sqrt(10) for _ in range(L)]
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters())+list(model.out_ln.parameters()), lr=args.lr, weight_decay=args.wd)
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=args.epochs), optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=args.epochs)]
+ log = {'test_acc': [], 'train_loss': []}
+ for epoch in range(1, args.epochs+1):
+ model.train(); total_loss, correct, total = 0, 0, 0
+ for x, y in train_loader:
+ x = x.view(x.size(0),-1).to(device); y = y.to(device); batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(-1); e_T[torch.arange(batch),y] -= 1
+ hL = hiddens[-1].detach()
+ loss_out = F.cross_entropy(model.out_head(model.out_ln(hL)), y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+ for l in range(L):
+ a = (e_T@Bs[l].T).detach(); rms = (a**2).mean(-1,keepdim=True).sqrt()+1e-6
+ f = model.blocks[l](hiddens[l].detach()); ll = (f*(a/rms)).sum(-1).mean()
+ block_opts[l].zero_grad(); ll.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0); block_opts[l].step()
+ a0 = (e_T@Bs[0].T).detach(); rms0 = (a0**2).mean(-1,keepdim=True).sqrt()+1e-6
+ el = (model.embed(x)*(a0/rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); el.backward(); embed_opt.step()
+ total_loss += loss_val.item()*batch; correct += (logits.argmax(1)==y).sum().item(); total += batch
+ for s in scheds: s.step()
+ test_acc = evaluate(model, test_loader, device)
+ log['test_acc'].append(test_acc); log['train_loss'].append(total_loss/total)
+ if epoch % 10 == 0 or epoch == 1 or epoch == args.epochs:
+ print(f" [DFA] Ep {epoch}: acc={test_acc:.4f}")
+ return log
+
+
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32*32*3; L = args.num_blocks; d = args.d_hidden
+
+ all_results = {}
+
+ # DFA baseline
+ print(f"\n{'='*60}\nDFA baseline\n{'='*60}")
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ model_dfa = ResidualMLP(input_dim, d, 10, L).to(device)
+ log_dfa = train_dfa_only(model_dfa, train_loader, test_loader, device, args)
+ all_results['DFA_only'] = log_dfa
+ print(f" DFA final: {log_dfa['test_acc'][-1]:.4f}")
+
+ # Periodic refit configs
+ configs = []
+ for K in args.K_values:
+ for R in args.R_values:
+ for alpha in args.blend_alphas:
+ configs.append((K, R, alpha))
+
+ for K, R, alpha in configs:
+ name = f"refit_K{K}_R{R}_a{alpha}"
+ print(f"\n{'='*60}\n{name}\n{'='*60}")
+ torch.manual_seed(args.seed); np.random.seed(args.seed); torch.cuda.manual_seed_all(args.seed)
+ model = ResidualMLP(input_dim, d, 10, L).to(device)
+ log = train_with_periodic_refit(model, train_loader, test_loader, device, args,
+ K_forward=K, R_refit=R, blend_alpha=alpha)
+ all_results[name] = log
+ print(f" {name} final: {log['test_acc'][-1]:.4f}")
+
+ # Summary
+ print(f"\n{'='*60}\nSUMMARY\n{'='*60}")
+ dfa_final = all_results['DFA_only']['test_acc'][-1]
+ print(f"{'Config':<35} {'final':>7} {'diff':>7}")
+ print("-" * 52)
+ for name, log in all_results.items():
+ final = log['test_acc'][-1]
+ diff = final - dfa_final if name != 'DFA_only' else 0
+ print(f"{name:<35} {final:>7.4f} {diff:>+7.4f}")
+
+ out_path = os.path.join(args.output_dir, f'periodic_refit_s{args.seed}.json')
+ save_data = {name: {'test_acc': log['test_acc'], 'train_loss': log['train_loss']}
+ for name, log in all_results.items()}
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 9B: Periodic Refit')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--K_values', type=int, nargs='+', default=[5])
+ parser.add_argument('--R_values', type=int, nargs='+', default=[1])
+ parser.add_argument('--blend_alphas', type=float, nargs='+', default=[0.75])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/periodic_refit')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()