# Provenance: reconstructed from a cgit-rendered patch of
# experiments/topdown_curriculum.py (new file, originally 286 lines).
#
#   Commit:  05ccd23154d1e9d090178b9d4d5f2c821711e784
#   From:    YurenHao0426
#   Date:    Thu, 26 Mar 2026 00:07:01 -0500
#   Subject: Add Phase 9B+9C: periodic refit fails, top-down curriculum neutral
#
#   Phase 9B (periodic refit K=5 R=1 alpha=0.75): 14.0% — Vec starts random,
#   periodic refits insufficient without offline pretraining.
#
#   Phase 9C (top-down curriculum): last1_vec=30.8%, last2_vec=31.1% vs
#   DFA=31.2%. Near-neutral. Cold-start problem persists even for
#   single-block Vec.
#
#   Only Phase 9A's offline prefit + blend handoff (+1.5%) works. The key
#   ingredient is offline Vec training on frozen checkpoint features.
#
#   Co-Authored-By: Claude Opus 4.6 (1M context)
#
# NOTE(review): the accuracy figures quoted above were reported for the code
# as committed, i.e. before the perturbation-direction normalisation fix made
# in `_train_vec_step` below. They should be re-measured.
"""
Phase 9C: Top-Down Curriculum.

DFA as default backbone, but Vec takes over only the last k blocks.
Bottom blocks continue using DFA credit.

Compare:
- DFA_only (baseline)
- last1_vec_rest_dfa: Vec for last block only, DFA for blocks 0-2
- last2_vec_rest_dfa: Vec for last 2 blocks, DFA for blocks 0-1

On the "Vec" blocks the credit actually applied is a blend of the
RMS-normalised Vec prediction and the RMS-normalised DFA signal, weighted by
``--blend_alpha`` (1.0 would be pure Vec).
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation

# CIFAR-10 has ten classes; this is also the width of the output-error signal.
NUM_CLASSES = 10


class VectorCreditNet(nn.Module):
    """MLP that predicts a per-sample credit vector for a hidden state.

    The input is the concatenation of the layer-normalised hidden state, a
    time embedding of ``t`` and the signal ``s``; the output has the same
    width as the hidden state.

    Args:
        d_hidden: Width of the hidden state (and of the predicted credit).
        s_dim: Width of the conditioning signal ``s``.
        time_embed_dim: Width passed to ``SinusoidalTimeEmbed``.
        hidden_dim: Width of the internal MLP layers.
        num_layers: Number of ``Linear`` + ``GELU`` pairs before the output
            projection.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        input_dim = d_hidden + time_embed_dim + s_dim
        layers = []
        for i in range(num_layers):
            in_d = input_dim if i == 0 else hidden_dim
            layers.append(nn.Linear(in_d, hidden_dim))
            layers.append(nn.GELU())
        layers.append(nn.Linear(hidden_dim, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the predicted credit for hidden state ``h`` at time ``t`` given ``s``."""
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


def get_cifar10(batch_size=128):
    """Build the CIFAR-10 train and test loaders.

    The training split uses random crop (padding 4) and horizontal flip; both
    splits are normalised with the same per-channel statistics. The dataset is
    downloaded to ``./data`` if it is not already there.

    Args:
        batch_size: Batch size for both loaders.

    Returns:
        ``(train_loader, test_loader)``; only the train loader shuffles.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(
        root='./data', train=True, download=True, transform=transform_train)
    testset = torchvision.datasets.CIFAR10(
        root='./data', train=False, download=True, transform=transform_test)
    return (
        DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True),
        DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True),
    )


def evaluate(model, test_loader, device):
    """Return the top-1 accuracy of ``model`` over ``test_loader``.

    Inputs are flattened to ``(batch, -1)`` before being fed to the model.
    The model is left in eval mode; the training loops switch it back at the
    start of every epoch.
    """
    model.eval()
    c, t = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            c += (model(x).argmax(1) == y).sum().item()
            t += x.size(0)
    return c / t


def _rms(a, floor=1e-6):
    """Per-sample root-mean-square over the last dim, offset by ``floor``."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + floor


def _output_error(logits, y):
    """Return ``softmax(logits) - onehot(y)``.

    The subtraction is done in place on the softmax output, so call this
    inside ``torch.no_grad()`` (both trainers do).
    """
    err = logits.softmax(-1)
    # Build the row index on the same device as the logits.
    err[torch.arange(logits.size(0), device=logits.device), y] -= 1
    return err


def _make_feedback_matrices(d, L, device):
    """Draw one fixed random DFA feedback matrix of shape (d, NUM_CLASSES) per block."""
    return [torch.randn(d, NUM_CLASSES, device=device) / np.sqrt(NUM_CLASSES) for _ in range(L)]


def _build_backbone_optimizers(model, args):
    """Create AdamW optimizers for blocks, embed and head, plus cosine schedulers.

    Returns:
        ``(block_opts, embed_opt, head_opt, scheds)`` where ``scheds`` holds
        one ``CosineAnnealingLR`` (``T_max=args.epochs``) per optimizer.
    """
    block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd)
                  for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
    head_opt = optim.AdamW(
        list(model.out_head.parameters()) + list(model.out_ln.parameters()),
        lr=args.lr, weight_decay=args.wd)
    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=args.epochs)
              for o in block_opts + [embed_opt, head_opt]]
    return block_opts, embed_opt, head_opt, scheds


def _update_head(model, head_opt, h_last, y):
    """Train the output LayerNorm + head with cross-entropy on the detached last hidden state."""
    loss_out = F.cross_entropy(model.out_head(model.out_ln(h_last)), y)
    head_opt.zero_grad()
    loss_out.backward()
    head_opt.step()


def _update_block(block, opt, h_in, credit):
    """Apply one local update to ``block``.

    The surrogate loss is the inner product of the block output with the
    RMS-normalised credit, so its gradient w.r.t. the block output equals the
    (scaled) credit. Gradients are clipped to norm 1.0 before the step.
    """
    out = block(h_in.detach())
    surrogate = (out * (credit / _rms(credit))).sum(-1).mean()
    opt.zero_grad()
    surrogate.backward()
    torch.nn.utils.clip_grad_norm_(block.parameters(), 1.0)
    opt.step()


def _update_embed(model, embed_opt, x, credit):
    """Apply one local update to the embedding using the given credit (no clipping)."""
    surrogate = (model.embed(x) * (credit / _rms(credit))).sum(-1).mean()
    embed_opt.zero_grad()
    surrogate.backward()
    embed_opt.step()


def _train_vec_step(model, vec_net, vec_opt, hiddens, hL, s, y, vec_layers, L, M, eps, device):
    """Run one optimisation step of the Vec credit network.

    Two losses are summed:
      * terminal matching: Vec at ``t = 1`` on the last hidden state should
        reproduce the exact gradient of the summed cross-entropy w.r.t. that
        state (through ``out_ln`` and ``out_head``);
      * projection matching: at one randomly chosen Vec layer, the projection
        of Vec's output on ``M`` random unit directions should match the
        central finite-difference directional derivative of the per-sample
        loss.
    """
    batch = hL.size(0)

    # Terminal matching (always)
    t_L = torch.ones(batch, device=device)
    a_term = vec_net(hL, t_L, s)
    hL_req = hL.clone().requires_grad_(True)
    logits_tgt = model.out_head(model.out_ln(hL_req))
    ce = F.cross_entropy(logits_tgt, y, reduction='sum')
    # autograd.grad only returns d(ce)/d(hL_req); it does not write .grad on the model.
    delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
    loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()

    # Perturbation target (only on Vec layers)
    l_train = vec_layers[np.random.randint(0, len(vec_layers))]
    h_l = hiddens[l_train].detach()
    t_l = torch.full((batch,), l_train / L, device=device)
    a_l = vec_net(h_l, t_l, s)
    loss_proj = torch.tensor(0.0, device=device)
    for _ in range(M):
        v = torch.randn_like(h_l)
        # FIX: the original called v.norm(-1, keepdim=True), which passes -1
        # as the norm order `p` (first positional parameter) rather than as
        # `dim`. That yields a single global p=-1 "norm" instead of a
        # per-sample L2 norm, so the directions were not unit vectors.
        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            # NOTE(review): forward_from_layer is assumed to resume the
            # forward pass from block `l_train` and return logits — confirm
            # against models/residual_mlp.py.
            lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l_train), y, reduction='none')
            lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l_train), y, reduction='none')
        g_j = (lp - lm) / (2 * eps)
        loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
    if M > 0:  # avoid 0/0 -> NaN when no probes are requested
        loss_proj = loss_proj / M

    vloss = loss_term + loss_proj
    vec_opt.zero_grad()
    vloss.backward()
    torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
    vec_opt.step()


def train_topdown(model, train_loader, test_loader, device, args, vec_layers, name):
    """
    Train with Vec credit for specified layers, DFA for the rest.
    vec_layers: list of layer indices to use Vec credit (e.g., [3] for last1)

    Per batch: one no-grad forward pass, one Vec-network step (if
    ``vec_layers`` is non-empty), then local updates of the head, every block
    and the embedding. The embedding always uses the DFA credit of block 0.

    Args:
        model: Network exposing ``d_hidden``, ``num_blocks``, ``blocks``,
            ``embed``, ``out_ln``, ``out_head``, ``forward_from_layer`` and
            ``model(x, return_hidden=True) -> (logits, hiddens)``.
        train_loader, test_loader: Data loaders yielding ``(x, y)``.
        device: Device on which tensors are created.
        args: Namespace with ``lr``, ``wd``, ``lr_fb``, ``epochs``, ``M`` and
            ``blend_alpha``.
        vec_layers: Block indices whose credit is the Vec/DFA blend.
        name: Label used in progress output.

    Returns:
        Dict with per-epoch lists ``test_acc``, ``train_loss`` and
        ``train_acc``; ``gamma`` is kept for compatibility and stays empty.
    """
    d = model.d_hidden
    L = model.num_blocks
    # Vec is created before the feedback matrices so RNG consumption order is unchanged.
    vec_net = VectorCreditNet(d_hidden=d, s_dim=NUM_CLASSES, time_embed_dim=32,
                              hidden_dim=256, num_layers=3).to(device)
    Bs = _make_feedback_matrices(d, L, device)

    block_opts, embed_opt, head_opt, scheds = _build_backbone_optimizers(model, args)
    vec_opt = optim.Adam(vec_net.parameters(), lr=args.lr_fb)

    log = {'test_acc': [], 'train_loss': [], 'train_acc': [], 'gamma': []}
    eps = 1e-3  # finite-difference step for the directional-derivative targets
    vec_layer_set = set(vec_layers)

    for epoch in range(1, args.epochs + 1):
        model.train()
        vec_net.train()
        total_loss, correct, total = 0.0, 0, 0

        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)
            s = e_T.detach()
            hL = hiddens[-1].detach()

            # Train Vec on the layers it's responsible for
            if vec_layers:
                _train_vec_step(model, vec_net, vec_opt, hiddens, hL, s, y,
                                vec_layers, L, args.M, eps, device)

            # Compute credits: Vec for vec_layers, DFA for others
            with torch.no_grad():
                vec_credits = {
                    l: vec_net(hiddens[l].detach(),
                               torch.full((batch,), l / L, device=device), s).detach()
                    for l in vec_layers
                }
                dfa_credits = {l: (e_T @ Bs[l].T).detach() for l in range(L)}

            credits = []
            for l in range(L):
                if l in vec_layer_set:
                    # Blend Vec + DFA for vec layers (each RMS-normalised first)
                    a_vec = vec_credits[l]
                    a_dfa = dfa_credits[l]
                    credits.append(args.blend_alpha * a_vec / _rms(a_vec)
                                   + (1 - args.blend_alpha) * a_dfa / _rms(a_dfa))
                else:
                    credits.append(dfa_credits[l])

            # Update head
            _update_head(model, head_opt, hL, y)

            # Update blocks
            for l in range(L):
                _update_block(model.blocks[l], block_opts[l], hiddens[l], credits[l])

            # Update embed (always DFA credit)
            _update_embed(model, embed_opt, x, dfa_credits[0])

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        # `sched`, not `s`: `s` is the error-signal tensor inside the batch loop.
        for sched in scheds:
            sched.step()
        test_acc = evaluate(model, test_loader, device)
        log['test_acc'].append(test_acc)
        log['train_loss'].append(total_loss / total)
        log['train_acc'].append(correct / total)

        if epoch % 10 == 0 or epoch <= 5 or epoch == args.epochs:
            print(f" [{name}] Ep {epoch}: acc={test_acc:.4f}")

    return log


def train_dfa_only(model, train_loader, test_loader, device, args):
    """Train every block (and the embedding) with DFA credit only — the baseline.

    Takes the same ``model``, loaders, ``device`` and ``args`` as
    ``train_topdown``, minus the Vec-specific fields.

    Returns:
        Dict with per-epoch lists ``test_acc``, ``train_loss`` and
        ``train_acc``.
    """
    d = model.d_hidden
    L = model.num_blocks
    Bs = _make_feedback_matrices(d, L, device)
    block_opts, embed_opt, head_opt, scheds = _build_backbone_optimizers(model, args)

    log = {'test_acc': [], 'train_loss': [], 'train_acc': []}
    for epoch in range(1, args.epochs + 1):
        model.train()
        total_loss, correct, total = 0.0, 0, 0

        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = _output_error(logits, y)
            hL = hiddens[-1].detach()

            _update_head(model, head_opt, hL, y)
            for l in range(L):
                _update_block(model.blocks[l], block_opts[l], hiddens[l],
                              (e_T @ Bs[l].T).detach())
            _update_embed(model, embed_opt, x, (e_T @ Bs[0].T).detach())

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sched in scheds:
            sched.step()
        test_acc = evaluate(model, test_loader, device)
        log['test_acc'].append(test_acc)
        log['train_loss'].append(total_loss / total)
        log['train_acc'].append(correct / total)

        if epoch % 10 == 0 or epoch == 1 or epoch == args.epochs:
            print(f" [DFA] Ep {epoch}: acc={test_acc:.4f}")

    return log


def _select_device(gpu_index):
    """Pick the requested CUDA device, falling back to cuda:0 or CPU when it is unavailable."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    n_gpus = torch.cuda.device_count()
    if not 0 <= gpu_index < n_gpus:
        print(f"Requested cuda:{gpu_index} but only {n_gpus} CUDA device(s) visible; using cuda:0")
        gpu_index = 0
    return torch.device(f'cuda:{gpu_index}')


def _seed_everything(seed):
    """Seed torch (CPU and all CUDA devices) and NumPy."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)


def run_experiment(args):
    """Run the DFA baseline and the top-down configs, print a summary, save JSON.

    Every configuration re-seeds the RNGs and builds a fresh ``ResidualMLP``.
    Results are written to ``<output_dir>/topdown_s<seed>.json`` with the
    per-epoch ``test_acc`` and ``train_loss`` of each configuration.
    """
    device = _select_device(args.gpu)
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)
    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3  # flattened CIFAR-10 image
    L = args.num_blocks

    all_results = {}

    # DFA baseline
    print(f"\n{'='*60}\nDFA_only\n{'='*60}")
    _seed_everything(args.seed)
    model = ResidualMLP(input_dim, args.d_hidden, NUM_CLASSES, L).to(device)
    log = train_dfa_only(model, train_loader, test_loader, device, args)
    all_results['DFA_only'] = log

    # Top-down configs
    configs = [
        ('last1_vec', [L - 1]),
        ('last2_vec', [L - 2, L - 1]),
    ]

    for cname, vec_layers in configs:
        print(f"\n{'='*60}\n{cname}\n{'='*60}")
        _seed_everything(args.seed)
        model = ResidualMLP(input_dim, args.d_hidden, NUM_CLASSES, L).to(device)
        log = train_topdown(model, train_loader, test_loader, device, args, vec_layers, cname)
        all_results[cname] = log

    # Summary
    print(f"\n{'='*60}\nSUMMARY\n{'='*60}")
    dfa_final = all_results['DFA_only']['test_acc'][-1]
    print(f"{'Config':<25} {'final':>7} {'diff':>7}")
    print("-" * 42)
    for name, log in all_results.items():
        final = log['test_acc'][-1]
        diff = final - dfa_final if name != 'DFA_only' else 0
        print(f"{name:<25} {final:>7.4f} {diff:>+7.4f}")

    out_path = os.path.join(args.output_dir, f'topdown_s{args.seed}.json')
    save_data = {n: {'test_acc': l['test_acc'], 'train_loss': l['train_loss']}
                 for n, l in all_results.items()}
    with open(out_path, 'w') as f:
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")


def main():
    """Parse command-line arguments and run the experiment."""
    parser = argparse.ArgumentParser(description='Phase 9C: Top-Down Curriculum')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--M', type=int, default=4)
    parser.add_argument('--blend_alpha', type=float, default=0.75)
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/topdown_curriculum')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()