summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
Diffstat (limited to 'experiments')
-rw-r--r--experiments/online_schedule_timing.py486
1 files changed, 486 insertions, 0 deletions
diff --git a/experiments/online_schedule_timing.py b/experiments/online_schedule_timing.py
new file mode 100644
index 0000000..f36627b
--- /dev/null
+++ b/experiments/online_schedule_timing.py
@@ -0,0 +1,486 @@
+"""
+Phase 8: Schedule Hypothesis Test.
+
+Test whether high-quality local credit should be used from epoch 0
+rather than after a DFA warmup period.
+
+Schedules:
+1. DFA_only: full DFA baseline
+2. Vec_only_from_0: Vec from epoch 0, no warmup
+3. Vec_early_then_DFA_T{k}: Vec for first k epochs, then DFA
+4. DFA_then_Vec_T{k}: DFA for first k epochs, then Vec
+5. Hybrid_blend: alpha*Vec + (1-alpha)*DFA from epoch 0
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate(model, test_loader, device):
+ model.eval()
+ c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+def compute_epoch_diagnostics(model, vector_net, dfa_Bs, test_loader, device, credit_mode):
+ """Compute Gamma and rho for current epoch's credit source."""
+ model.eval()
+ if vector_net is not None:
+ vector_net.eval()
+ L = model.num_blocks
+ d = model.d_hidden
+
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device); break
+ batch = x.size(0)
+
+ # BP gradients (eval only)
+ logits_bp, hbp = model(x, return_hidden=True)
+ for l in range(L + 1): hbp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ bp_grads = {l: hbp[l].grad.detach().clone() for l in range(L + 1)}
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ gammas, rhos = [], []
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+
+ if credit_mode == 'dfa':
+ a_l = (s @ dfa_Bs[l].T).detach()
+ elif credit_mode == 'vec':
+ a_l = vector_net(h_l, t_l, s).detach()
+ else: # blend
+ a_dfa = (s @ dfa_Bs[l].T).detach()
+ a_vec = vector_net(h_l, t_l, s).detach()
+ alpha = credit_mode # numeric blend factor
+ rms_v = (a_vec ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rms_d = (a_dfa ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_l = alpha * a_vec / rms_v + (1 - alpha) * a_dfa / rms_d
+
+ gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
+
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L): c = c + model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
+ return f
+ rhos.append(perturbation_correlation(h_l, a_l, make_fwd(l), epsilon=1e-3, M=16))
+
+ return float(np.mean(gammas)), float(np.mean(rhos))
+
+
+# =============================================================================
+# Unified training loop with configurable credit schedule
+# =============================================================================
+def train_with_schedule(model, train_loader, test_loader, device, args, schedule):
+ """
+ Train with a configurable credit schedule.
+
+ schedule: dict with keys:
+ 'name': str
+ 'type': one of 'dfa_only', 'vec_only', 'vec_then_dfa', 'dfa_then_vec', 'blend'
+ 'switch_epoch': int (for vec_then_dfa, dfa_then_vec)
+ 'blend_alpha': float (for blend)
+ """
+ d = model.d_hidden
+ L = model.num_blocks
+ epochs = args.epochs
+ sname = schedule['name']
+ stype = schedule['type']
+
+ # Vector net (always created, trained when active)
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ block_opts = [optim.AdamW(b.parameters(), lr=args.lr, weight_decay=args.wd) for b in model.blocks]
+ embed_opt = optim.AdamW(model.embed.parameters(), lr=args.lr, weight_decay=args.wd)
+ head_opt = optim.AdamW(list(model.out_head.parameters()) + list(model.out_ln.parameters()),
+ lr=args.lr, weight_decay=args.wd)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=args.lr_fb)
+
+ scheds = [optim.lr_scheduler.CosineAnnealingLR(o, T_max=epochs) for o in block_opts] + \
+ [optim.lr_scheduler.CosineAnnealingLR(embed_opt, T_max=epochs),
+ optim.lr_scheduler.CosineAnnealingLR(head_opt, T_max=epochs)]
+
+ eps_pert = args.pert_eps
+ M = args.M
+
+ log = {'train_loss': [], 'test_acc': [], 'gamma': [], 'rho': [], 'credit_mode': []}
+
+ for epoch in range(1, epochs + 1):
+ # Determine credit mode for this epoch
+ if stype == 'dfa_only':
+ use_vec = False
+ use_dfa = True
+ credit_mode_tag = 'dfa'
+ elif stype == 'vec_only':
+ use_vec = True
+ use_dfa = False
+ credit_mode_tag = 'vec'
+ elif stype == 'vec_then_dfa':
+ T = schedule['switch_epoch']
+ if epoch <= T:
+ use_vec = True; use_dfa = False; credit_mode_tag = 'vec'
+ else:
+ use_vec = False; use_dfa = True; credit_mode_tag = 'dfa'
+ elif stype == 'dfa_then_vec':
+ T = schedule['switch_epoch']
+ if epoch <= T:
+ use_vec = False; use_dfa = True; credit_mode_tag = 'dfa'
+ else:
+ use_vec = True; use_dfa = False; credit_mode_tag = 'vec'
+ elif stype == 'blend':
+ use_vec = True; use_dfa = True
+ credit_mode_tag = f"blend_{schedule['blend_alpha']:.2f}"
+ else:
+ raise ValueError(f"Unknown schedule type: {stype}")
+
+ # Always train vec net when it's active (or will be active soon)
+ train_vec = use_vec or (stype == 'dfa_then_vec' and epoch >= schedule['switch_epoch'] - 5)
+
+ model.train()
+ vector_net.train()
+ total_loss, correct, total = 0, 0, 0
+
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ loss_val = F.cross_entropy(logits, y)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ hL = hiddens[-1].detach()
+
+ # --- Train vector net (when needed) ---
+ if train_vec:
+ # Terminal matching
+ t_L = torch.ones(batch, device=device)
+ a_term = vector_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()
+
+ # Perturbation target (subsample 1 layer)
+ l_train = np.random.randint(0, L)
+ h_l = hiddens[l_train].detach()
+ t_l = torch.full((batch,), l_train / L, device=device)
+ a_l = vector_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps_pert * v, l_train), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps_pert * v, l_train), y, reduction='none')
+ g_j = (lp - lm) / (2 * eps_pert)
+ loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
+ loss_proj /= M
+
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+
+ # --- Compute credits ---
+ with torch.no_grad():
+ vec_credits = [vector_net(hiddens[l].detach(),
+ torch.full((batch,), l / L, device=device), s).detach() for l in range(L)]
+ dfa_credits = [(e_T @ Bs[l].T).detach() for l in range(L)]
+
+ # Select credits based on schedule
+ credits = []
+ for l in range(L):
+ if use_vec and not use_dfa:
+ # Pure vec — use raw credit (no normalization)
+ a = vec_credits[l]
+ elif use_dfa and not use_vec:
+ a = dfa_credits[l]
+ else:
+ # Blend
+ alpha = schedule.get('blend_alpha', 0.5)
+ rms_v = (vec_credits[l] ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ rms_d = (dfa_credits[l] ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a = alpha * vec_credits[l] / rms_v + (1 - alpha) * dfa_credits[l] / rms_d
+ credits.append(a)
+
+ # --- Update output head ---
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_opt.zero_grad(); loss_out.backward(); head_opt.step()
+
+ # --- Update blocks ---
+ for l in range(L):
+ h_l_det = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l_det)
+ local_loss = (f_l * a_norm).sum(-1).mean()
+ block_opts[l].zero_grad(); local_loss.backward()
+ torch.nn.utils.clip_grad_norm_(model.blocks[l].parameters(), 1.0)
+ block_opts[l].step()
+
+ # --- Update embedding ---
+ a0 = credits[0]
+ rms0 = (a0 ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ embed_loss = (model.embed(x) * (a0 / rms0)).sum(-1).mean()
+ embed_opt.zero_grad(); embed_loss.backward(); embed_opt.step()
+
+ total_loss += loss_val.item() * batch
+ correct += (logits.argmax(1) == y).sum().item()
+ total += batch
+
+ for sch in scheds:
+ sch.step()
+
+ train_loss = total_loss / total
+ test_acc = evaluate(model, test_loader, device)
+ log['train_loss'].append(train_loss)
+ log['test_acc'].append(test_acc)
+ log['credit_mode'].append(credit_mode_tag)
+
+ # Diagnostics every 5 epochs (or at key epochs)
+ if epoch % 5 == 0 or epoch <= 5 or epoch == epochs:
+ gamma, rho = compute_epoch_diagnostics(
+ model, vector_net, Bs, test_loader, device,
+ 'vec' if use_vec and not use_dfa else ('dfa' if use_dfa and not use_vec else schedule.get('blend_alpha', 0.5))
+ )
+ log['gamma'].append((epoch, gamma))
+ log['rho'].append((epoch, rho))
+ else:
+ gamma, rho = None, None
+
+ if epoch % 10 == 0 or epoch <= 5 or epoch == epochs:
+ g_str = f", Gamma={gamma:.4f}, rho={rho:.4f}" if gamma is not None else ""
+ print(f" [{sname}] Ep {epoch} ({credit_mode_tag}): loss={train_loss:.4f}, "
+ f"test={test_acc:.4f}{g_str}")
+
+ return log, vector_net, Bs
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # Define schedules
+ schedules = []
+ for sname in args.schedules:
+ if sname == 'DFA_only':
+ schedules.append({'name': 'DFA_only', 'type': 'dfa_only'})
+ elif sname == 'Vec_only_from_0':
+ schedules.append({'name': 'Vec_only_from_0', 'type': 'vec_only'})
+ elif sname.startswith('Vec_early_then_DFA_T'):
+ T = int(sname.split('T')[1])
+ schedules.append({'name': sname, 'type': 'vec_then_dfa', 'switch_epoch': T})
+ elif sname.startswith('DFA_then_Vec_T'):
+ T = int(sname.split('T')[1])
+ schedules.append({'name': sname, 'type': 'dfa_then_vec', 'switch_epoch': T})
+ elif sname.startswith('Hybrid_blend_'):
+ alpha = float(sname.split('_')[-1])
+ schedules.append({'name': sname, 'type': 'blend', 'blend_alpha': alpha})
+ else:
+ raise ValueError(f"Unknown schedule: {sname}")
+
+ all_results = {}
+
+ for schedule in schedules:
+ sname = schedule['name']
+ print(f"\n{'='*60}")
+ print(f"Schedule: {sname}")
+ print(f"{'='*60}")
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ model = ResidualMLP(input_dim, d, 10, L).to(device)
+ log, vec_net, Bs = train_with_schedule(model, train_loader, test_loader, device, args, schedule)
+ all_results[sname] = log
+
+ # =========================================================
+ # Summary table
+ # =========================================================
+ print(f"\n{'='*100}")
+ print("SUMMARY")
+ print(f"{'='*100}")
+
+ # Extract key metrics
+ print(f"\n{'Schedule':<30} {'acc@5':>7} {'acc@10':>7} {'acc@20':>7} {'acc@50':>7} {'final':>7} "
+ f"{'mGamma[0:20]':>13} {'mRho[0:20]':>12}")
+ print("-" * 100)
+
+ for sname, log in all_results.items():
+ accs = log['test_acc']
+ acc5 = accs[4] if len(accs) >= 5 else accs[-1]
+ acc10 = accs[9] if len(accs) >= 10 else accs[-1]
+ acc20 = accs[19] if len(accs) >= 20 else accs[-1]
+ acc50 = accs[49] if len(accs) >= 50 else accs[-1]
+ final = accs[-1]
+
+ # Mean Gamma/rho for epochs 1-20
+ gammas_early = [g for e, g in log['gamma'] if e <= 20]
+ rhos_early = [r for e, r in log['rho'] if e <= 20]
+ mg = np.mean(gammas_early) if gammas_early else float('nan')
+ mr = np.mean(rhos_early) if rhos_early else float('nan')
+
+ print(f"{sname:<30} {acc5:>7.4f} {acc10:>7.4f} {acc20:>7.4f} {acc50:>7.4f} {final:>7.4f} "
+ f"{mg:>13.4f} {mr:>12.4f}")
+
+ # AUC early benefit
+ print(f"\nEarly accuracy AUC (sum of acc for epochs 1-20):")
+ for sname, log in all_results.items():
+ auc = sum(log['test_acc'][:20])
+ print(f" {sname:<30}: AUC_acc(0,20) = {auc:.2f}")
+
+ # Save
+ save_data = {}
+ for sname, log in all_results.items():
+ save_data[sname] = {
+ 'test_acc': log['test_acc'],
+ 'train_loss': log['train_loss'],
+ 'gamma': log['gamma'],
+ 'rho': log['rho'],
+ 'credit_mode': log['credit_mode'],
+ }
+
+ out_path = os.path.join(args.output_dir, f'schedules_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # =========================================================
+ # Judgment
+ # =========================================================
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ if 'Vec_only_from_0' in all_results and 'DFA_only' in all_results:
+ vec0_acc20 = all_results['Vec_only_from_0']['test_acc'][19] if len(all_results['Vec_only_from_0']['test_acc']) >= 20 else 0
+ dfa_acc20 = all_results['DFA_only']['test_acc'][19] if len(all_results['DFA_only']['test_acc']) >= 20 else 0
+ vec0_final = all_results['Vec_only_from_0']['test_acc'][-1]
+ dfa_final = all_results['DFA_only']['test_acc'][-1]
+
+ print(f" Vec_from_0 acc@20={vec0_acc20:.4f} vs DFA acc@20={dfa_acc20:.4f}: "
+ f"{'Vec better' if vec0_acc20 > dfa_acc20 else 'DFA better'}")
+ print(f" Vec_from_0 final={vec0_final:.4f} vs DFA final={dfa_final:.4f}: "
+ f"{'Vec better' if vec0_final > dfa_final else 'DFA better'}")
+
+ if 'DFA_then_Vec_T20' in all_results and 'Vec_only_from_0' in all_results:
+ late_final = all_results['DFA_then_Vec_T20']['test_acc'][-1]
+ early_final = all_results['Vec_only_from_0']['test_acc'][-1]
+ print(f" Vec_from_0 final={early_final:.4f} vs DFA_then_Vec_T20 final={late_final:.4f}")
+ if early_final > late_final + 0.005:
+ print(f" -> WARMUP TIMING HYPOTHESIS SUPPORTED: early Vec is better")
+ elif abs(early_final - late_final) <= 0.005:
+ print(f" -> INCONCLUSIVE: similar final accuracy")
+ else:
+ print(f" -> WARMUP TIMING HYPOTHESIS NOT SUPPORTED")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 8: Schedule Hypothesis Test')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--epochs', type=int, default=100)
+ parser.add_argument('--lr', type=float, default=1e-3)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--wd', type=float, default=0.01)
+ parser.add_argument('--M', type=int, default=4)
+ parser.add_argument('--pert_eps', type=float, default=1e-3)
+ parser.add_argument('--schedules', type=str, nargs='+',
+ default=['DFA_only', 'Vec_only_from_0', 'Vec_early_then_DFA_T5', 'DFA_then_Vec_T20'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/schedule_timing')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()