# Provenance (recovered from the git patch this module was extracted from):
#   commit:  3ec9a5cd63b4578999d89b49f5223024a1acb723
#   author:  YurenHao0426
#   date:    Wed, 25 Mar 2026 14:23:13 -0500
#   subject: Add Phase 8: schedule timing test — online co-learning is the
#            remaining bottleneck
#   path:    experiments/online_schedule_timing.py (new file)
#
# Results reported in the commit message:
#   Vec_only_from_0:  15.4% (cold-start failure, can't learn credit on random features)
#   DFA_only:         31.2% (remains best non-BP method)
#   DFA_then_Vec_T20: 12.9% (switching to Vec destroys DFA-built features)
#   Vec_T5_then_DFA:  26.6% (partial recovery but still worse than pure DFA)
#
# Phase 7A's early-window finding doesn't transfer: it required offline-trained
# Vec on frozen features. Online Vec estimator faces cold-start paradox — needs
# structured features to learn credit, but structured features need good credit
# to form.
#
# Co-Authored-By: Claude Opus 4.6 (1M context)
"""
Phase 8: Schedule Hypothesis Test.

Test whether high-quality local credit should be used from epoch 0
rather than after a DFA warmup period.

Schedules (names accepted by ``--schedules``):
1. DFA_only: full DFA baseline
2. Vec_only_from_0: Vec from epoch 0, no warmup
3. Vec_early_then_DFA_T{k}: Vec for first k epochs, then DFA
4. DFA_then_Vec_T{k}: DFA for first k epochs, then Vec
5. Hybrid_blend_{alpha}: alpha*Vec + (1-alpha)*DFA from epoch 0
"""
import os
import sys
import json
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import copy

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from models.residual_mlp import ResidualMLP
from models.value_net import SinusoidalTimeEmbed
from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test


class VectorCreditNet(nn.Module):
    """MLP mapping (hidden state, layer time, output error) to a credit vector.

    The hidden state is layer-normalised, concatenated with a time embedding
    and the signal ``s``, and passed through ``num_layers`` Linear+GELU layers
    followed by a linear projection back to ``d_hidden``.

    Args:
        d_hidden: width of the hidden state and of the returned credit vector.
        s_dim: width of the signal ``s`` (10 is used by the training loop).
        time_embed_dim: size passed to ``SinusoidalTimeEmbed``; assumed to be
            the width of its output — TODO confirm against models.value_net.
        hidden_dim: width of the internal MLP layers.
        num_layers: number of Linear+GELU layers before the output projection.
    """

    def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
        super().__init__()
        self.ln = nn.LayerNorm(d_hidden)
        self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
        # Track the running width so the output projection is still correctly
        # sized when num_layers == 0.
        width = d_hidden + time_embed_dim + s_dim
        layers = []
        for _ in range(num_layers):
            layers.append(nn.Linear(width, hidden_dim))
            layers.append(nn.GELU())
            width = hidden_dim
        layers.append(nn.Linear(width, d_hidden))
        self.net = nn.Sequential(*layers)

    def forward(self, h, t, s):
        """Return the predicted credit for hidden state ``h`` at time ``t`` given ``s``."""
        h_normed = self.ln(h)
        t_emb = self.time_embed(t)
        inp = torch.cat([h_normed, t_emb, s], dim=-1)
        return self.net(inp)


def get_cifar10(batch_size=128, root='./data', num_workers=4):
    """Build CIFAR-10 train/test loaders (downloads the dataset if missing).

    Args:
        batch_size: batch size for both loaders.
        root: dataset directory handed to torchvision.
        num_workers: DataLoader worker processes for each loader.

    Returns:
        ``(train_loader, test_loader)``. Training data is augmented with a
        padded random crop and horizontal flip and is shuffled; test data is
        only normalised and is not shuffled.
    """
    normalize = transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616))
    transform_train = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    transform_test = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])
    trainset = torchvision.datasets.CIFAR10(root=root, train=True, download=True,
                                            transform=transform_train)
    testset = torchvision.datasets.CIFAR10(root=root, train=False, download=True,
                                           transform=transform_test)
    train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True,
                              num_workers=num_workers, pin_memory=True)
    test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False,
                             num_workers=num_workers, pin_memory=True)
    return train_loader, test_loader


def evaluate(model, test_loader, device):
    """Return top-1 accuracy of ``model`` over ``test_loader``.

    Inputs are flattened to ``(batch, -1)`` before the forward pass. The model
    is left in eval mode on return.

    Raises:
        ValueError: if ``test_loader`` yields no samples.
    """
    model.eval()
    n_correct, n_seen = 0, 0
    with torch.no_grad():
        for x, y in test_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            n_correct += (model(x).argmax(1) == y).sum().item()
            n_seen += x.size(0)
    if n_seen == 0:
        raise ValueError("test_loader yielded no samples; accuracy is undefined")
    return n_correct / n_seen


def _rms(a, eps=1e-6):
    """Per-row root-mean-square of ``a`` (last dim kept), offset by ``eps``."""
    return (a ** 2).mean(-1, keepdim=True).sqrt() + eps


def compute_epoch_diagnostics(model, vector_net, dfa_Bs, test_loader, device, credit_mode):
    """Compute Gamma and rho for current epoch's credit source.

    Uses only the first batch of ``test_loader``. For each block ``l`` the
    credit ``a_l`` is compared against the backprop gradient w.r.t. the hidden
    state via ``cosine_similarity_batch`` (Gamma) and scored with
    ``perturbation_correlation`` (rho); both are averaged over blocks.

    Args:
        model: network exposing ``num_blocks``, ``blocks``, ``out_ln``,
            ``out_head`` and ``model(x, return_hidden=True)``.
            NOTE(review): assumes ``hiddens[l]`` is the input to block ``l`` —
            confirm against ResidualMLP.
        vector_net: learned credit network (needed unless ``credit_mode == 'dfa'``).
        dfa_Bs: per-block fixed feedback matrices used for DFA credit.
        test_loader: loader to draw the diagnostic batch from.
        device: torch device for the batch.
        credit_mode: ``'dfa'``, ``'vec'``, or a number ``alpha`` selecting the
            RMS-normalised blend ``alpha * vec + (1 - alpha) * dfa``.

    Returns:
        ``(mean_gamma, mean_rho)`` as Python floats.

    Raises:
        ValueError: if the loader is empty or ``credit_mode`` is not recognised.
    """
    model.eval()
    if vector_net is not None:
        vector_net.eval()
    L = model.num_blocks

    try:
        x, y = next(iter(test_loader))
    except StopIteration:
        raise ValueError("test_loader is empty; cannot compute diagnostics") from None
    x = x.view(x.size(0), -1).to(device)
    y = y.to(device)
    batch = x.size(0)

    # BP gradients (eval only). autograd.grad returns the gradients directly
    # instead of accumulating into every parameter's .grad as .backward() would.
    logits_bp, hbp = model(x, return_hidden=True)
    bp_grads = [g.detach() for g in
                torch.autograd.grad(F.cross_entropy(logits_bp, y), [hbp[l] for l in range(L)])]

    # Reuse the same forward pass for the error signal and hidden states.
    with torch.no_grad():
        e_T = logits_bp.detach().softmax(-1)
        e_T[torch.arange(batch, device=e_T.device), y] -= 1
    s = e_T.detach()
    hiddens = [h.detach() for h in hbp]

    if credit_mode not in ('dfa', 'vec'):
        try:
            alpha = float(credit_mode)  # numeric blend factor
        except (TypeError, ValueError):
            raise ValueError(f"Unknown credit mode: {credit_mode!r}") from None

    def make_fwd(sl):
        """Per-sample loss as a function of the hidden state entering block ``sl``."""
        def f(h):
            with torch.no_grad():
                c = h
                for i in range(sl, L):
                    c = c + model.blocks[i](c)
                return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
        return f

    gammas, rhos = [], []
    for l in range(L):
        h_l = hiddens[l]
        t_l = torch.full((batch,), l / L, device=device)

        with torch.no_grad():
            if credit_mode == 'dfa':
                a_l = s @ dfa_Bs[l].T
            elif credit_mode == 'vec':
                a_l = vector_net(h_l, t_l, s)
            else:  # blend
                a_dfa = s @ dfa_Bs[l].T
                a_vec = vector_net(h_l, t_l, s)
                a_l = alpha * a_vec / _rms(a_vec) + (1 - alpha) * a_dfa / _rms(a_dfa)

        gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
        rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))

    return float(np.mean(gammas)), float(np.mean(rhos))


# =============================================================================
# Unified training loop with configurable credit schedule
# =============================================================================
def _credit_mode_for_epoch(schedule, epoch):
    """Return ``(use_vec, use_dfa, tag)`` for ``epoch`` (1-based) under ``schedule``."""
    stype = schedule['type']
    if stype == 'dfa_only':
        return False, True, 'dfa'
    if stype == 'vec_only':
        return True, False, 'vec'
    if stype in ('vec_then_dfa', 'dfa_then_vec'):
        in_first_phase = epoch <= schedule['switch_epoch']
        use_vec = (stype == 'vec_then_dfa') == in_first_phase
        return use_vec, not use_vec, ('vec' if use_vec else 'dfa')
    if stype == 'blend':
        # Same default as the blend computation, so a missing key cannot make
        # the tag and the credits disagree.
        return True, True, f"blend_{schedule.get('blend_alpha', 0.5):.2f}"
    raise ValueError(f"Unknown schedule type: {stype}")


def _vector_net_step(model, vector_net, vec_opt, hiddens, s, y, L, M, eps_pert, device):
    """Run one optimiser step on the vector credit net.

    The loss is the sum of (a) terminal matching: the prediction at ``t = 1``
    on the final hidden state should equal the gradient of the summed
    cross-entropy w.r.t. that state, and (b) a projection loss on one randomly
    chosen layer: the credit projected on ``M`` random unit directions should
    match central finite-difference estimates of the directional derivative.
    """
    batch = y.size(0)
    hL = hiddens[-1].detach()

    # Terminal matching
    a_term = vector_net(hL, torch.ones(batch, device=device), s)
    hL_req = hL.clone().requires_grad_(True)
    ce = F.cross_entropy(model.out_head(model.out_ln(hL_req)), y, reduction='sum')
    delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
    loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()

    # Perturbation target (subsample 1 layer)
    l_train = np.random.randint(0, L)
    h_l = hiddens[l_train].detach()
    t_l = torch.full((batch,), l_train / L, device=device)
    a_l = vector_net(h_l, t_l, s)
    loss_proj = torch.tensor(0.0, device=device)
    for _ in range(M):
        v = torch.randn_like(h_l)
        v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
        with torch.no_grad():
            lp = F.cross_entropy(model.forward_from_layer(h_l + eps_pert * v, l_train), y, reduction='none')
            lm = F.cross_entropy(model.forward_from_layer(h_l - eps_pert * v, l_train), y, reduction='none')
            g_j = (lp - lm) / (2 * eps_pert)
        loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j) ** 2).mean()
    if M > 0:  # M == 0 would otherwise turn the whole loss into NaN (0 / 0)
        loss_proj = loss_proj / M

    vloss = loss_term + loss_proj
    vec_opt.zero_grad()
    vloss.backward()
    torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
    vec_opt.step()


def train_with_schedule(model, train_loader, test_loader, device, args, schedule):
    """
    Train with a configurable credit schedule.

    Blocks and the embedding are updated with local losses
    ``(output * normalised_credit).sum(-1).mean()``; only the output head is
    trained with cross-entropy. Credits come from DFA (fixed random feedback
    matrices), from the learned ``VectorCreditNet``, or from a blend of both,
    as selected per epoch by ``schedule``.

    schedule: dict with keys:
        'name': str
        'type': one of 'dfa_only', 'vec_only', 'vec_then_dfa', 'dfa_then_vec', 'blend'
        'switch_epoch': int (for vec_then_dfa, dfa_then_vec)
        'blend_alpha': float (for blend; defaults to 0.5)

    args must provide: epochs, lr, wd, lr_fb, pert_eps, M.

    Returns:
        ``(log, vector_net, Bs)`` where ``log`` holds per-epoch 'train_loss',
        'test_acc' and 'credit_mode', plus 'gamma' / 'rho' as lists of
        ``(epoch, value)`` for the epochs where diagnostics were run.

    Raises:
        ValueError: if ``schedule['type']`` is not recognised.
    """
    d = model.d_hidden
    L = model.num_blocks
    epochs = args.epochs
    sname = schedule['name']
    stype = schedule['type']

    # Vector net (always created, trained when active)
    vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
                                 hidden_dim=256, num_layers=3).to(device)
    Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]

    block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in model.blocks]
    embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
    head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
                           lr=args.lr, weight_decay=args.wd)
    vec_opt = optim.Adam(vector_net.parameters(), lr=args.lr_fb)

    scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs)
              for o in block_opts + [embed_opt, head_opt]]

    eps_pert = args.pert_eps
    M = args.M
    blend_alpha = schedule.get('blend_alpha', 0.5)

    log = {'train_loss': [], 'test_acc': [], 'gamma': [], 'rho': [], 'credit_mode': []}

    for epoch in range(1, epochs + 1):
        # Determine credit mode for this epoch
        use_vec, use_dfa, credit_mode_tag = _credit_mode_for_epoch(schedule, epoch)

        # Always train vec net when it's active (or will be active soon)
        train_vec = use_vec or (stype == 'dfa_then_vec' and epoch >= schedule['switch_epoch'] - 5)

        model.train()
        vector_net.train()
        total_loss, correct, total = 0, 0, 0

        for x, y in train_loader:
            x = x.view(x.size(0), -1).to(device)
            y = y.to(device)
            batch = x.size(0)

            with torch.no_grad():
                logits, hiddens = model(x, return_hidden=True)
                loss_val = F.cross_entropy(logits, y)
                e_T = logits.softmax(dim=-1)
                e_T[torch.arange(batch, device=e_T.device), y] -= 1
            s = e_T.detach()

            hL = hiddens[-1].detach()

            # --- Train vector net (when needed) ---
            if train_vec:
                _vector_net_step(model, vector_net, vec_opt, hiddens, s, y, L, M, eps_pert, device)

            # --- Compute credits (only the sources this epoch actually uses) ---
            with torch.no_grad():
                vec_credits = [vector_net(hiddens[l].detach(),
                                          torch.full((batch,), l / L, device=device), s)
                               for l in range(L)] if use_vec else None
                dfa_credits = [s @ Bs[l].T for l in range(L)] if use_dfa else None

                # Select credits based on schedule. Every credit is
                # RMS-normalised again just before it is applied below.
                if use_vec and use_dfa:
                    credits = [blend_alpha * v / _rms(v) + (1 - blend_alpha) * g / _rms(g)
                               for v, g in zip(vec_credits, dfa_credits)]
                elif use_vec:
                    credits = vec_credits
                else:
                    credits = dfa_credits

            # --- Update output head ---
            logits_out = model.out_head(model.out_ln(hL))
            loss_out = F.cross_entropy(logits_out, y)
            head_opt.zero_grad()
            loss_out.backward()
            head_opt.step()

            # --- Update blocks ---
            for l in range(L):
                h_l_det = hiddens[l].detach()
                a = credits[l]
                a_norm = a / _rms(a)
                f_l = model.blocks[l](h_l_det)
                local_loss = (f_l * a_norm).sum(-1).mean()
                block_opts[l].zero_grad()
                local_loss.backward()
                torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
                block_opts[l].step()

            # --- Update embedding ---
            a0 = credits[0]
            embed_loss = (model.embed(x) * (a0 / _rms(a0))).sum(-1).mean()
            embed_opt.zero_grad()
            embed_loss.backward()
            embed_opt.step()

            total_loss += loss_val.item() * batch
            correct += (logits.argmax(1) == y).sum().item()
            total += batch

        for sch in scheds:
            sch.step()

        train_loss = total_loss / total
        test_acc = evaluate(model, test_loader, device)
        log['train_loss'].append(train_loss)
        log['test_acc'].append(test_acc)
        log['credit_mode'].append(credit_mode_tag)

        # Diagnostics every 5 epochs (or at key epochs)
        if epoch % 5 == 0 or epoch <= 5 or epoch == epochs:
            if use_vec and use_dfa:
                diag_mode = blend_alpha
            else:
                diag_mode = 'vec' if use_vec else 'dfa'
            gamma, rho = compute_epoch_diagnostics(
                model, vector_net, Bs, test_loader, device, diag_mode
            )
            log['gamma'].append((epoch, gamma))
            log['rho'].append((epoch, rho))
        else:
            gamma, rho = None, None

        if epoch % 10 == 0 or epoch <= 5 or epoch == epochs:
            g_str = f", Gamma={gamma:.4f}, rho={rho:.4f}" if gamma is not None else ""
            print(f" [{sname}] Ep {epoch} ({credit_mode_tag}): loss={train_loss:.4f}, "
                  f"test={test_acc:.4f}{g_str}")

    return log, vector_net, Bs


# =============================================================================
# Main
# =============================================================================
def _select_device(gpu):
    """Return ``cuda:<gpu>`` when usable, ``cuda:0`` if the index is out of range, else CPU."""
    if not torch.cuda.is_available():
        return torch.device('cpu')
    n_gpus = torch.cuda.device_count()
    if not 0 <= gpu < n_gpus:
        print(f"Requested GPU {gpu} but only {n_gpus} visible; falling back to cuda:0")
        gpu = 0
    return torch.device(f'cuda:{gpu}')


def _parse_schedule(sname):
    """Translate a schedule name from the command line into a schedule dict.

    Raises:
        ValueError: if the name is not recognised or its numeric suffix is malformed.
    """
    if sname == 'DFA_only':
        return {'name': 'DFA_only', 'type': 'dfa_only'}
    if sname == 'Vec_only_from_0':
        return {'name': 'Vec_only_from_0', 'type': 'vec_only'}
    for prefix, stype in (('Vec_early_then_DFA_T', 'vec_then_dfa'),
                          ('DFA_then_Vec_T', 'dfa_then_vec')):
        if sname.startswith(prefix):
            try:
                switch_epoch = int(sname[len(prefix):])
            except ValueError:
                raise ValueError(f"Unknown schedule: {sname}") from None
            return {'name': sname, 'type': stype, 'switch_epoch': switch_epoch}
    if sname.startswith('Hybrid_blend_'):
        try:
            alpha = float(sname.rsplit('_', 1)[-1])
        except ValueError:
            raise ValueError(f"Unknown schedule: {sname}") from None
        return {'name': sname, 'type': 'blend', 'blend_alpha': alpha}
    raise ValueError(f"Unknown schedule: {sname}")


def _acc_at(accs, epoch):
    """Accuracy after ``epoch`` epochs (1-based), or the last value if the run was shorter."""
    return accs[epoch - 1] if len(accs) >= epoch else accs[-1]


def run_experiment(args):
    """Train one fresh model per requested schedule, then print and save a comparison.

    Every schedule starts from the same seed. Results are written to
    ``<output_dir>/schedules_s<seed>.json``.

    Raises:
        ValueError: if a name in ``args.schedules`` is not recognised.
    """
    device = _select_device(args.gpu)
    print(f"Using device: {device}")
    os.makedirs(args.output_dir, exist_ok=True)

    train_loader, test_loader = get_cifar10(args.batch_size)
    input_dim = 32 * 32 * 3
    L = args.num_blocks
    d = args.d_hidden

    # Define schedules (parsed up front so a bad name fails before any training)
    schedules = [_parse_schedule(sname) for sname in args.schedules]

    all_results = {}

    for schedule in schedules:
        sname = schedule['name']
        print(f"\n{'='*60}")
        print(f"Schedule: {sname}")
        print(f"{'='*60}")

        torch.manual_seed(args.seed)
        np.random.seed(args.seed)
        torch.cuda.manual_seed_all(args.seed)

        model = ResidualMLP(input_dim, d, 10, L).to(device)
        log, _, _ = train_with_schedule(model, train_loader, test_loader, device, args, schedule)
        all_results[sname] = log

    # =========================================================
    # Summary table
    # =========================================================
    print(f"\n{'='*100}")
    print("SUMMARY")
    print(f"{'='*100}")

    # Extract key metrics
    print(f"\n{'Schedule':<30} {'acc@5':>7} {'acc@10':>7} {'acc@20':>7} {'acc@50':>7} {'final':>7} "
          f"{'mGamma[0:20]':>13} {'mRho[0:20]':>12}")
    print("-" * 100)

    for sname, log in all_results.items():
        accs = log['test_acc']
        acc5, acc10, acc20, acc50 = (_acc_at(accs, k) for k in (5, 10, 20, 50))
        final = accs[-1]

        # Mean Gamma/rho for epochs 1-20
        gammas_early = [g for e, g in log['gamma'] if e <= 20]
        rhos_early = [r for e, r in log['rho'] if e <= 20]
        mg = np.mean(gammas_early) if gammas_early else float('nan')
        mr = np.mean(rhos_early) if rhos_early else float('nan')

        print(f"{sname:<30} {acc5:>7.4f} {acc10:>7.4f} {acc20:>7.4f} {acc50:>7.4f} {final:>7.4f} "
              f"{mg:>13.4f} {mr:>12.4f}")

    # AUC early benefit
    print(f"\nEarly accuracy AUC (sum of acc for epochs 1-20):")
    for sname, log in all_results.items():
        auc = sum(log['test_acc'][:20])
        print(f" {sname:<30}: AUC_acc(0,20) = {auc:.2f}")

    # Save
    save_data = {
        sname: {
            'test_acc': log['test_acc'],
            'train_loss': log['train_loss'],
            'gamma': log['gamma'],
            'rho': log['rho'],
            'credit_mode': log['credit_mode'],
        }
        for sname, log in all_results.items()
    }

    out_path = os.path.join(args.output_dir, f'schedules_s{args.seed}.json')
    with open(out_path, 'w', encoding='utf-8') as f:
        # default=float converts numpy scalars that json cannot serialise itself.
        json.dump(save_data, f, indent=2, default=float)
    print(f"\nSaved to {out_path}")

    # =========================================================
    # Judgment
    # =========================================================
    print(f"\n{'='*60}")
    print("JUDGMENT")
    print(f"{'='*60}")

    if 'Vec_only_from_0' in all_results and 'DFA_only' in all_results:
        vec0_accs = all_results['Vec_only_from_0']['test_acc']
        dfa_accs = all_results['DFA_only']['test_acc']
        # Same fallback as the summary table: runs shorter than 20 epochs are
        # compared at their last epoch rather than at a placeholder 0.
        vec0_acc20 = _acc_at(vec0_accs, 20)
        dfa_acc20 = _acc_at(dfa_accs, 20)
        vec0_final = vec0_accs[-1]
        dfa_final = dfa_accs[-1]

        print(f" Vec_from_0 acc@20={vec0_acc20:.4f} vs DFA acc@20={dfa_acc20:.4f}: "
              f"{'Vec better' if vec0_acc20 > dfa_acc20 else 'DFA better'}")
        print(f" Vec_from_0 final={vec0_final:.4f} vs DFA final={dfa_final:.4f}: "
              f"{'Vec better' if vec0_final > dfa_final else 'DFA better'}")

    if 'DFA_then_Vec_T20' in all_results and 'Vec_only_from_0' in all_results:
        late_final = all_results['DFA_then_Vec_T20']['test_acc'][-1]
        early_final = all_results['Vec_only_from_0']['test_acc'][-1]
        print(f" Vec_from_0 final={early_final:.4f} vs DFA_then_Vec_T20 final={late_final:.4f}")
        if early_final > late_final + 0.005:
            print(f" -> WARMUP TIMING HYPOTHESIS SUPPORTED: early Vec is better")
        elif abs(early_final - late_final) <= 0.005:
            print(f" -> INCONCLUSIVE: similar final accuracy")
        else:
            print(f" -> WARMUP TIMING HYPOTHESIS NOT SUPPORTED")


def main():
    """Parse command-line arguments and run the schedule comparison."""
    parser = argparse.ArgumentParser(description='Phase 8: Schedule Hypothesis Test')
    parser.add_argument('--num_blocks', type=int, default=4)
    parser.add_argument('--d_hidden', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=128)
    parser.add_argument('--epochs', type=int, default=100)
    parser.add_argument('--lr', type=float, default=1e-3)
    parser.add_argument('--lr_fb', type=float, default=1e-3)
    parser.add_argument('--wd', type=float, default=0.01)
    parser.add_argument('--M', type=int, default=4)
    parser.add_argument('--pert_eps', type=float, default=1e-3)
    parser.add_argument('--schedules', type=str, nargs='+',
                        default=['DFA_only', 'Vec_only_from_0', 'Vec_early_then_DFA_T5', 'DFA_then_Vec_T20'])
    parser.add_argument('--seed', type=int, default=42)
    parser.add_argument('--gpu', type=int, default=3)
    parser.add_argument('--output_dir', type=str, default='results/schedule_timing')
    args = parser.parse_args()
    run_experiment(args)


if __name__ == '__main__':
    main()