diff options
| -rw-r--r-- | NOTE.md | 33 | ||||
| -rw-r--r-- | experiments/snapshot_time_sweep.py | 519 | ||||
| -rw-r--r-- | report_explore/MEMO_7A_snapshot_time_sweep.md | 37 |
3 files changed, 588 insertions, 1 deletions
@@ -5,7 +5,7 @@ - **pilot**: Controlled iteration (commits 0b9ebb2, 7baf7ae) - **frozen**: Code at commit 0b9ebb2 for all reported results -## Status: PHASE 6.5 PROTOCOL AUDIT — PHASE 6A CONCLUSION REVISED +## Status: PHASE 7A SNAPSHOT TIME SWEEP — EARLY SNAPSHOTS SHOW POSITIVE TRANSFER --- @@ -418,3 +418,34 @@ gradient noise) could make better credit usable. ### Experiment IDs (Phase 6.5) - `exploit_linesearch/`: Phase 6.5A smoke test (Oracle + Vec, last1, raw) - `exploit_linesearch_full/`: Phase 6.5A full sweep (all methods, ranges, norm modes) + +--- + +## Phase 7A: Snapshot Time Sweep + +**Setup**: BP snapshots at epoch {5, 20, 100} (acc 0.49/0.57/0.62). +Train Vec_M4 on each frozen snapshot. Test 1-step and 5-step with raw credit, last-block-only. + +**KEY FINDING: Held-out failure is primarily a LATE-SNAPSHOT artifact.** + +5-step DeltaLoss held-out: + +| Epoch | DFA dL_held | Vec dL_held | Oracle dL_held | Vec PUR | +|-------|-------------|-------------|----------------|---------| +| **5** | +0.003 | **-0.005** | **-0.009** | **0.70** | +| 20 | +0.001 | +0.002 | +0.000 | -3.87 | +| 100 | +0.000 | +0.001 | -0.001 | -1.01 | + +At epoch 5: Vec decreases held-out loss (PUR=0.70), Oracle too (PUR=1.05). +DFA INCREASES held-out at all snapshots. + +By epoch 20 the generalization window closes. + +**Better credit produces MORE consistent updates** (Vec variance=0.8 vs DFA variance=40). +The problem is not batch-specificity but snapshot timing: credit is useful early, useless late. + +**Implication**: The DFA warmup (which delays credit bridge to epoch ~20) is counterproductive. +Credit bridge should be used from epoch 0. + +### Experiment IDs (Phase 7) +- `snapshot_time/`: Phase 7A snapshot time sweep with BP checkpoints diff --git a/experiments/snapshot_time_sweep.py b/experiments/snapshot_time_sweep.py new file mode 100644 index 0000000..fd87927 --- /dev/null +++ b/experiments/snapshot_time_sweep.py @@ -0,0 +1,519 @@ +""" +Phase 7A: Snapshot-time sweep. 
+ +Test whether "same-batch descent + held-out ascent" is a late-snapshot artifact +or persists across training time. + +For each snapshot epoch, train estimators on frozen features, then measure: +- DeltaL_same (same-batch 1-step and 5-step) +- DeltaL_held (held-out 1-step and 5-step) +- PUR = -DeltaL_held / (-DeltaL_same + 1e-12) +- Cross-batch update cosine and variance +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import SinusoidalTimeEmbed + + +class VectorCreditNet(nn.Module): + def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3): + super().__init__() + self.ln = nn.LayerNorm(d_hidden) + self.time_embed = SinusoidalTimeEmbed(time_embed_dim) + input_dim = d_hidden + time_embed_dim + s_dim + layers = [] + for i in range(num_layers): + in_d = input_dim if i == 0 else hidden_dim + layers.append(nn.Linear(in_d, hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, d_hidden)) + self.net = nn.Sequential(*layers) + + def forward(self, h, t, s): + h_normed = self.ln(h) + t_emb = self.time_embed(t) + inp = torch.cat([h_normed, t_emb, s], dim=-1) + return self.net(inp) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = 
torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def evaluate_acc(model, test_loader, device): + model.eval() + c, t = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + c += (model(x).argmax(1) == y).sum().item(); t += x.size(0) + return c / t + + +# ============================================================================= +# BP training with checkpoint saving +# ============================================================================= +def train_bp_with_checkpoints(model, train_loader, test_loader, device, + epochs, save_epochs, ckpt_dir, lr=1e-3, wd=0.01): + """Train BP and save checkpoints at specified epochs.""" + os.makedirs(ckpt_dir, exist_ok=True) + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + # Save epoch 0 (init) + if 0 in save_epochs: + torch.save(model.state_dict(), os.path.join(ckpt_dir, 'epoch_0.pt')) + acc = evaluate_acc(model, test_loader, device) + print(f" Saved epoch 0 (acc={acc:.4f})") + + for epoch in range(1, epochs + 1): + model.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + loss = F.cross_entropy(model(x), y) + optimizer.zero_grad(); loss.backward(); optimizer.step() + scheduler.step() + + if epoch in save_epochs: + torch.save(model.state_dict(), os.path.join(ckpt_dir, f'epoch_{epoch}.pt')) + acc = evaluate_acc(model, test_loader, device) + print(f" Saved epoch {epoch} (acc={acc:.4f})") + + +# 
============================================================================= +# Train vector field on frozen snapshot +# ============================================================================= +def train_vec_on_snapshot(model, train_loader, device, epochs=60, lr_fb=1e-3, M=4): + d = model.d_hidden + L = model.num_blocks + vec_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32, + hidden_dim=256, num_layers=3).to(device) + vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb) + eps = 1e-3 + model.eval() + for ep in range(1, epochs + 1): + vec_net.train() + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device); y = y.to(device) + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + hL = hiddens[-1].detach() + # Terminal matching + t_L = torch.ones(batch, device=device) + a_term = vec_net(hL, t_L, s) + hL_req = hL.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req)) + ce = F.cross_entropy(logits_tgt, y, reduction='sum') + delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach() + loss_term = ((a_term - delta_L) ** 2).sum(-1).mean() + # Perturbation target (subsample 1 layer) + l = np.random.randint(0, L) + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + a_l = vec_net(h_l, t_l, s) + loss_proj = torch.tensor(0.0, device=device) + for _ in range(M): + v = torch.randn_like(h_l) + v = v / (v.norm(dim=-1, keepdim=True) + 1e-8) + with torch.no_grad(): + lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none') + lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none') + g_j = (lp - lm) / (2 * eps) + loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean() + loss_proj /= M + vloss = loss_term + loss_proj + vec_opt.zero_grad(); vloss.backward() + 
torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0) + vec_opt.step() + if ep % 20 == 0 or ep == 1: + print(f" [Vec] Ep {ep}") + return vec_net + + +# ============================================================================= +# Credit computation +# ============================================================================= +def get_credits(model, x, y, device, source, estimator=None, dfa_Bs=None): + L = model.num_blocks + batch = x.size(0) + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + credits = {} + if source == 'dfa': + for l in range(L): + credits[l] = (s @ dfa_Bs[l].T).detach() + elif source == 'vec': + estimator.eval() + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + credits[l] = estimator(h_l, t_l, s).detach() + elif source == 'oracle_bp': + for p in model.parameters(): p.requires_grad_(True) + model.zero_grad() + logits_bp, hbp = model(x, return_hidden=True) + for l in range(L + 1): hbp[l].retain_grad() + F.cross_entropy(logits_bp, y).backward() + for l in range(L): + credits[l] = hbp[l].grad.detach().clone() + for p in model.parameters(): p.requires_grad_(False) + return credits, hiddens + + +# ============================================================================= +# Local update and evaluation +# ============================================================================= +def compute_update_vector(model, x, y, credits, device, eta, update_layers, normalize=False): + """Compute the parameter update direction (as a flat vector) without applying it.""" + L = model.num_blocks + with torch.no_grad(): + _, hiddens = model(x, return_hidden=True) + + all_grads = [] + + # Head update + hL = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL)) + loss_out = F.cross_entropy(logits_out, y) + head_params = list(model.out_head.parameters()) + 
list(model.out_ln.parameters()) + grads_head = torch.autograd.grad(loss_out, head_params) + for g in grads_head: + all_grads.append(g.detach().flatten()) + + # Block updates + for l in update_layers: + h_l = hiddens[l].detach() + a = credits[l] + if normalize: + rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a.detach()).sum(-1).mean() + block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) + for g in block_grads: + all_grads.append(g.detach().flatten()) + + return torch.cat(all_grads) + + +def apply_update(model, x, y, credits, device, eta, update_layers, normalize=False): + """Apply one local surrogate update step. Returns model (modified in-place).""" + L = model.num_blocks + with torch.no_grad(): + _, hiddens = model(x, return_hidden=True) + + hL = hiddens[-1].detach() + logits_out = model.out_head(model.out_ln(hL)) + loss_out = F.cross_entropy(logits_out, y) + head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters()) + grads_head = torch.autograd.grad(loss_out, head_params) + with torch.no_grad(): + for p, g in zip(head_params, grads_head): + p.sub_(eta * g) + + for l in update_layers: + h_l = hiddens[l].detach() + a = credits[l] + if normalize: + rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6 + a = a / rms + f_l = model.blocks[l](h_l) + local_loss = (f_l * a.detach()).sum(-1).mean() + block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters()) + with torch.no_grad(): + for p, g in zip(model.blocks[l].parameters(), block_grads): + p.sub_(eta * g) + + +def eval_loss(model, x, y): + model.eval() + with torch.no_grad(): + return F.cross_entropy(model(x), y).item() + + +# ============================================================================= +# Main +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if 
torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(args.batch_size) + input_dim = 32 * 32 * 3 + L = args.num_blocks + d = args.d_hidden + + # ========================================================= + # Step 1: Train BP model with checkpoint saving + # ========================================================= + ckpt_dir = os.path.join(args.output_dir, f'bp_ckpts_L{L}_d{d}_s{args.seed}') + save_epochs = args.snapshot_epochs + + # Check if checkpoints already exist + all_exist = all(os.path.exists(os.path.join(ckpt_dir, f'epoch_{e}.pt')) for e in save_epochs) + + if not all_exist or args.retrain: + print(f"\nTraining BP model with checkpoints at epochs {save_epochs}...") + model_train = ResidualMLP(input_dim, d, 10, L).to(device) + train_bp_with_checkpoints(model_train, train_loader, test_loader, device, + epochs=max(save_epochs), save_epochs=save_epochs, + ckpt_dir=ckpt_dir) + else: + print(f"\nAll checkpoints exist in {ckpt_dir}") + + # ========================================================= + # Step 2: For each snapshot, train estimators and test exploitability + # ========================================================= + + # Fixed batches for consistent evaluation + train_iter = iter(train_loader) + x_same, y_same = next(train_iter) + x_same = x_same.view(x_same.size(0), -1).to(device); y_same = y_same.to(device) + x_held, y_held = next(train_iter) + x_held = x_held.view(x_held.size(0), -1).to(device); y_held = y_held.to(device) + + # Extra batches for cross-batch variance + extra_batches = [] + for _ in range(8): + xb, yb = next(train_iter) + extra_batches.append((xb.view(xb.size(0), -1).to(device), yb.to(device))) + + # DFA matrices (fixed across snapshots) + dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)] + + 
update_layers = [L - 1] # last block only + all_results = [] + + for epoch in save_epochs: + print(f"\n{'='*60}") + print(f"Snapshot: epoch {epoch}") + print(f"{'='*60}") + + # Load snapshot + model = ResidualMLP(input_dim, d, 10, L).to(device) + ckpt_path = os.path.join(ckpt_dir, f'epoch_{epoch}.pt') + model.load_state_dict(torch.load(ckpt_path, map_location=device)) + model.eval() + for p in model.parameters(): p.requires_grad_(False) + acc = evaluate_acc(model, test_loader, device) + print(f" Accuracy: {acc:.4f}") + + loss_same_before = eval_loss(model, x_same, y_same) + loss_held_before = eval_loss(model, x_held, y_held) + print(f" Loss: same={loss_same_before:.4f}, held={loss_held_before:.4f}") + + # Train Vec on this snapshot + print(f" Training Vec_M4...") + torch.manual_seed(args.seed + epoch * 100 + 4000) + vec_net = train_vec_on_snapshot(model, train_loader, device, + epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4) + + credit_sources = { + 'dfa': ('dfa', None, dfa_Bs), + 'vec_eT_M4': ('vec', vec_net, None), + 'oracle_bp': ('oracle_bp', None, None), + } + + # Eta line search for each method + etas = args.etas + + for name, (src, est, Bs) in credit_sources.items(): + if name not in args.methods: + continue + + # Compute credits on same batch + credits_same, _ = get_credits(model, x_same, y_same, device, src, + estimator=est, dfa_Bs=Bs) + + best_eta = None + best_dl_same = float('inf') + + for eta in etas: + # 1-step test + model_test = copy.deepcopy(model) + for p in model_test.parameters(): p.requires_grad_(True) + apply_update(model_test, x_same, y_same, credits_same, device, + eta=eta, update_layers=update_layers, normalize=False) + for p in model_test.parameters(): p.requires_grad_(False) + + dl_same = eval_loss(model_test, x_same, y_same) - loss_same_before + dl_held = eval_loss(model_test, x_held, y_held) - loss_held_before + + if dl_same < best_dl_same: + best_dl_same = dl_same + best_eta = eta + best_dl_held = dl_held + + # 5-step rollout at 
best eta + model_5 = copy.deepcopy(model) + for p in model_5.parameters(): p.requires_grad_(True) + train_iter2 = iter(train_loader) + for step in range(5): + try: xs, ys = next(train_iter2) + except StopIteration: train_iter2 = iter(train_loader); xs, ys = next(train_iter2) + xs = xs.view(xs.size(0), -1).to(device); ys = ys.to(device) + for p in model_5.parameters(): p.requires_grad_(False) + creds_step, _ = get_credits(model_5, xs, ys, device, src, estimator=est, dfa_Bs=Bs) + for p in model_5.parameters(): p.requires_grad_(True) + apply_update(model_5, xs, ys, creds_step, device, + eta=best_eta, update_layers=update_layers, normalize=False) + for p in model_5.parameters(): p.requires_grad_(False) + dl_same_5 = eval_loss(model_5, x_same, y_same) - loss_same_before + dl_held_5 = eval_loss(model_5, x_held, y_held) - loss_held_before + + # Cross-batch update variance + update_vecs = [] + for xb, yb in extra_batches[:4]: + # get_credits may toggle requires_grad for oracle_bp + for p in model.parameters(): p.requires_grad_(False) + creds_b, _ = get_credits(model, xb, yb, device, src, estimator=est, dfa_Bs=Bs) + # compute_update_vector needs requires_grad=True + for p in model.parameters(): p.requires_grad_(True) + u = compute_update_vector(model, xb, yb, creds_b, device, + eta=best_eta, update_layers=update_layers, normalize=False) + update_vecs.append(u) + for p in model.parameters(): p.requires_grad_(False) + + # Update cosine (mean pairwise cosine) + cosines = [] + for i in range(len(update_vecs)): + for j in range(i + 1, len(update_vecs)): + cos = F.cosine_similarity(update_vecs[i].unsqueeze(0), + update_vecs[j].unsqueeze(0)).item() + cosines.append(cos) + update_cos = float(np.mean(cosines)) if cosines else 0.0 + + # Update variance + stacked = torch.stack(update_vecs) + mean_u = stacked.mean(0) + update_var = ((stacked - mean_u) ** 2).sum(-1).mean().item() + + # PUR + pur_1 = -best_dl_held / (-best_dl_same + 1e-12) if best_dl_same < 0 else float('nan') + pur_5 = 
-dl_held_5 / (-dl_same_5 + 1e-12) if dl_same_5 < 0 else float('nan') + + result = { + 'snapshot_epoch': epoch, 'method': name, 'snapshot_acc': float(acc), + 'best_eta': best_eta, + 'dl_same_1': best_dl_same, 'dl_held_1': best_dl_held, 'pur_1': pur_1, + 'dl_same_5': dl_same_5, 'dl_held_5': dl_held_5, 'pur_5': pur_5, + 'update_cos': update_cos, 'update_var': update_var, + } + all_results.append(result) + + print(f" {name:>12}: eta={best_eta:.0e}, dL_same_1={best_dl_same:+.6f}, " + f"dL_held_1={best_dl_held:+.6f}, PUR_1={pur_1:.3f}, " + f"dL_same_5={dl_same_5:+.6f}, dL_held_5={dl_held_5:+.6f}, PUR_5={pur_5:.3f}, " + f"u_cos={update_cos:.3f}, u_var={update_var:.2e}") + + # ========================================================= + # Summary + # ========================================================= + print(f"\n{'='*100}") + print("SUMMARY") + print(f"{'='*100}") + print(f"{'Epoch':>6} {'Acc':>6} {'Method':>12} {'eta':>8} {'dL_same_1':>10} {'dL_held_1':>10} " + f"{'PUR_1':>7} {'dL_same_5':>10} {'dL_held_5':>10} {'PUR_5':>7} {'u_cos':>6} {'u_var':>10}") + print("-" * 110) + for r in all_results: + print(f"{r['snapshot_epoch']:>6} {r['snapshot_acc']:>6.3f} {r['method']:>12} {r['best_eta']:>8.0e} " + f"{r['dl_same_1']:>+10.6f} {r['dl_held_1']:>+10.6f} {r['pur_1']:>7.3f} " + f"{r['dl_same_5']:>+10.6f} {r['dl_held_5']:>+10.6f} {r['pur_5']:>7.3f} " + f"{r['update_cos']:>6.3f} {r['update_var']:>10.2e}") + + # Save + out_path = os.path.join(args.output_dir, f'time_sweep_L{L}_d{d}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(all_results, f, indent=2, default=float) + print(f"\nSaved to {out_path}") + + # Judgment + print(f"\n{'='*60}") + print("JUDGMENT") + print(f"{'='*60}") + + early_held_failures = 0 + late_held_failures = 0 + for r in all_results: + if r['method'] == 'vec_eT_M4': + if r['snapshot_epoch'] <= 20 and r['dl_held_1'] > 0: + early_held_failures += 1 + if r['snapshot_epoch'] >= 50 and r['dl_held_1'] > 0: + late_held_failures += 1 + + 
early_epochs = [e for e in save_epochs if e <= 20] + late_epochs = [e for e in save_epochs if e >= 50] + + if early_held_failures == 0 and late_held_failures > 0: + print("LATE-SNAPSHOT ARTIFACT: held-out failure only at late snapshots.") + print(" -> Early-training local updates with good credit DO generalize.") + elif early_held_failures > 0 and late_held_failures > 0: + print("ACROSS-TRAINING FAILURE: held-out degradation at both early and late snapshots.") + print(" -> Problem is NOT just late-snapshot overfitting.") + else: + print("NEED MORE DATA: check results table above.") + + +def main(): + parser = argparse.ArgumentParser(description='Phase 7A: Snapshot Time Sweep') + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--snapshot_epochs', type=int, nargs='+', default=[5, 20, 100]) + parser.add_argument('--estimator_epochs', type=int, default=60) + parser.add_argument('--lr_fb', type=float, default=1e-3) + parser.add_argument('--etas', type=float, nargs='+', + default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2]) + parser.add_argument('--methods', type=str, nargs='+', + default=['dfa', 'vec_eT_M4', 'oracle_bp']) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=3) + parser.add_argument('--output_dir', type=str, default='results/snapshot_time') + parser.add_argument('--retrain', action='store_true') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() diff --git a/report_explore/MEMO_7A_snapshot_time_sweep.md b/report_explore/MEMO_7A_snapshot_time_sweep.md new file mode 100644 index 0000000..31e4cb2 --- /dev/null +++ b/report_explore/MEMO_7A_snapshot_time_sweep.md @@ -0,0 +1,37 @@ +# Phase 7A Memo: Snapshot Time Sweep + +**Date**: 2026-03-25 + +## Question +Is "same-batch descent + held-out ascent" a late-snapshot 
artifact, or does it persist across training? + +## Answer: Primarily a late-snapshot artifact. Early snapshots show positive held-out transfer. + +### 5-step DeltaLoss results (raw credit, last-block-only): + +| Epoch | Acc | DFA dL_held | Vec dL_held | Oracle dL_held | Vec PUR_5 | +|-------|-----|-------------|-------------|----------------|-----------| +| **5** | 0.49 | +0.003 | **-0.005** | **-0.009** | **0.70** | +| 20 | 0.57 | +0.001 | +0.002 | +0.000 | -3.87 | +| 100 | 0.62 | +0.000 | +0.001 | -0.001 | -1.01 | + +### Key findings: + +1. **At epoch 5, Vec and Oracle both decrease held-out loss**, while DFA increases it. Vec PUR=0.70 means 70% of same-batch improvement transfers to held-out. Oracle PUR=1.05 (>100% transfer). + +2. **By epoch 20, the generalization window closes.** All methods show near-zero or positive held-out change. + +3. **Better credit → lower update variance.** Vec/Oracle update variance is roughly 50-150x lower than DFA (0.4-0.8 vs 40-60). Better credit produces MORE consistent cross-batch updates, not less. + +4. **DFA never improves held-out at any snapshot.** Its updates are random enough to sometimes decrease same-batch loss but never systematically improve held-out. + +## Implications + +The "better credit is useless" narrative from Phase 6A/6.5A was wrong on two counts: +1. Same-batch exploitability works (Phase 6.5A) +2. Early-snapshot held-out transfer works too (this experiment) + +The online training failure is because by the time the warmup phase ends and credit bridge takes over (epoch ~20), the network is already past the "generalization window" where local credit updates are useful. The fix should be: **use credit bridge from the start (no DFA warmup), or switch earlier.** + +## Next step recommendation +Phase 7B (multi-batch averaging) may not be needed given that the held-out failure is a snapshot-timing issue, not a batch-variance issue. 
Instead, the priority should be testing online training WITH vector credit from epoch 0 (no warmup or very short warmup). |
