summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NOTE.md33
-rw-r--r--experiments/snapshot_time_sweep.py519
-rw-r--r--report_explore/MEMO_7A_snapshot_time_sweep.md37
3 files changed, 588 insertions, 1 deletions
diff --git a/NOTE.md b/NOTE.md
index a57fd30..892cf1e 100644
--- a/NOTE.md
+++ b/NOTE.md
@@ -5,7 +5,7 @@
- **pilot**: Controlled iteration (commits 0b9ebb2, 7baf7ae)
- **frozen**: Code at commit 0b9ebb2 for all reported results
-## Status: PHASE 6.5 PROTOCOL AUDIT — PHASE 6A CONCLUSION REVISED
+## Status: PHASE 7A SNAPSHOT TIME SWEEP — EARLY SNAPSHOTS SHOW POSITIVE TRANSFER
---
@@ -418,3 +418,34 @@ gradient noise) could make better credit usable.
### Experiment IDs (Phase 6.5)
- `exploit_linesearch/`: Phase 6.5A smoke test (Oracle + Vec, last1, raw)
- `exploit_linesearch_full/`: Phase 6.5A full sweep (all methods, ranges, norm modes)
+
+---
+
+## Phase 7A: Snapshot Time Sweep
+
+**Setup**: BP snapshots at epoch {5, 20, 100} (acc 0.49/0.57/0.62).
+Train Vec_M4 on each frozen snapshot. Test 1-step and 5-step with raw credit, last-block-only.
+
+**KEY FINDING: Held-out failure is primarily a LATE-SNAPSHOT artifact.**
+
+5-step DeltaLoss held-out:
+
+| Epoch | DFA dL_held | Vec dL_held | Oracle dL_held | Vec PUR |
+|-------|-------------|-------------|----------------|---------|
+| **5** | +0.003 | **-0.005** | **-0.009** | **0.70** |
+| 20 | +0.001 | +0.002 | +0.000 | -3.87 |
+| 100 | +0.000 | +0.001 | -0.001 | -1.01 |
+
+At epoch 5: Vec decreases held-out loss (PUR=0.70), Oracle too (PUR=1.05).
+DFA INCREASES held-out at all snapshots.
+
+By epoch 20 the generalization window closes.
+
+**Better credit produces MORE consistent updates** (Vec variance=0.8 vs DFA variance=40).
+The problem is not batch-specificity but snapshot timing: credit is useful early, useless late.
+
+**Implication**: The DFA warmup (which delays credit bridge to epoch ~20) is counterproductive.
+Credit bridge should be used from epoch 0.
+
+### Experiment IDs (Phase 7)
+- `snapshot_time/`: Phase 7A snapshot time sweep with BP checkpoints
diff --git a/experiments/snapshot_time_sweep.py b/experiments/snapshot_time_sweep.py
new file mode 100644
index 0000000..fd87927
--- /dev/null
+++ b/experiments/snapshot_time_sweep.py
@@ -0,0 +1,519 @@
+"""
+Phase 7A: Snapshot-time sweep.
+
+Test whether "same-batch descent + held-out ascent" is a late-snapshot artifact
+or persists across training time.
+
+For each snapshot epoch, train estimators on frozen features, then measure:
+- DeltaL_same (same-batch 1-step and 5-step)
+- DeltaL_held (held-out 1-step and 5-step)
+- PUR = -DeltaL_held / (-DeltaL_same + 1e-12)
+- Cross-batch update cosine and variance
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import SinusoidalTimeEmbed
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate_acc(model, test_loader, device):
+ model.eval()
+ c, t = 0, 0
+ with torch.no_grad():
+ for x, y in test_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ c += (model(x).argmax(1) == y).sum().item(); t += x.size(0)
+ return c / t
+
+
+# =============================================================================
+# BP training with checkpoint saving
+# =============================================================================
+def train_bp_with_checkpoints(model, train_loader, test_loader, device,
+ epochs, save_epochs, ckpt_dir, lr=1e-3, wd=0.01):
+ """Train BP and save checkpoints at specified epochs."""
+ os.makedirs(ckpt_dir, exist_ok=True)
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)
+
+ # Save epoch 0 (init)
+ if 0 in save_epochs:
+ torch.save(model.state_dict(), os.path.join(ckpt_dir, 'epoch_0.pt'))
+ acc = evaluate_acc(model, test_loader, device)
+ print(f" Saved epoch 0 (acc={acc:.4f})")
+
+ for epoch in range(1, epochs + 1):
+ model.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ loss = F.cross_entropy(model(x), y)
+ optimizer.zero_grad(); loss.backward(); optimizer.step()
+ scheduler.step()
+
+ if epoch in save_epochs:
+ torch.save(model.state_dict(), os.path.join(ckpt_dir, f'epoch_{epoch}.pt'))
+ acc = evaluate_acc(model, test_loader, device)
+ print(f" Saved epoch {epoch} (acc={acc:.4f})")
+
+
+# =============================================================================
+# Train vector field on frozen snapshot
+# =============================================================================
+def train_vec_on_snapshot(model, train_loader, device, epochs=60, lr_fb=1e-3, M=4):
+ d = model.d_hidden
+ L = model.num_blocks
+ vec_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ vec_opt = optim.Adam(vec_net.parameters(), lr=lr_fb)
+ eps = 1e-3
+ model.eval()
+ for ep in range(1, epochs + 1):
+ vec_net.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1].detach()
+ # Terminal matching
+ t_L = torch.ones(batch, device=device)
+ a_term = vec_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L) ** 2).sum(-1).mean()
+ # Perturbation target (subsample 1 layer)
+ l = np.random.randint(0, L)
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vec_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none')
+ g_j = (lp - lm) / (2 * eps)
+ loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
+ loss_proj /= M
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad(); vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vec_net.parameters(), 1.0)
+ vec_opt.step()
+ if ep % 20 == 0 or ep == 1:
+ print(f" [Vec] Ep {ep}")
+ return vec_net
+
+
+# =============================================================================
+# Credit computation
+# =============================================================================
+def get_credits(model, x, y, device, source, estimator=None, dfa_Bs=None):
+ L = model.num_blocks
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ credits = {}
+ if source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+ elif source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+ elif source == 'oracle_bp':
+ for p in model.parameters(): p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hbp = model(x, return_hidden=True)
+ for l in range(L + 1): hbp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ for l in range(L):
+ credits[l] = hbp[l].grad.detach().clone()
+ for p in model.parameters(): p.requires_grad_(False)
+ return credits, hiddens
+
+
+# =============================================================================
+# Local update and evaluation
+# =============================================================================
+def compute_update_vector(model, x, y, credits, device, eta, update_layers, normalize=False):
+ """Compute the parameter update direction (as a flat vector) without applying it."""
+ L = model.num_blocks
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ all_grads = []
+
+ # Head update
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ for g in grads_head:
+ all_grads.append(g.detach().flatten())
+
+ # Block updates
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ if normalize:
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a.detach()).sum(-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ for g in block_grads:
+ all_grads.append(g.detach().flatten())
+
+ return torch.cat(all_grads)
+
+
+def apply_update(model, x, y, credits, device, eta, update_layers, normalize=False):
+ """Apply one local surrogate update step. Returns model (modified in-place)."""
+ L = model.num_blocks
+ with torch.no_grad():
+ _, hiddens = model(x, return_hidden=True)
+
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(eta * g)
+
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ if normalize:
+ rms = (a ** 2).mean(-1, keepdim=True).sqrt() + 1e-6
+ a = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a.detach()).sum(-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(eta * g)
+
+
+def eval_loss(model, x, y):
+ model.eval()
+ with torch.no_grad():
+ return F.cross_entropy(model(x), y).item()
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # =========================================================
+ # Step 1: Train BP model with checkpoint saving
+ # =========================================================
+ ckpt_dir = os.path.join(args.output_dir, f'bp_ckpts_L{L}_d{d}_s{args.seed}')
+ save_epochs = args.snapshot_epochs
+
+ # Check if checkpoints already exist
+ all_exist = all(os.path.exists(os.path.join(ckpt_dir, f'epoch_{e}.pt')) for e in save_epochs)
+
+ if not all_exist or args.retrain:
+ print(f"\nTraining BP model with checkpoints at epochs {save_epochs}...")
+ model_train = ResidualMLP(input_dim, d, 10, L).to(device)
+ train_bp_with_checkpoints(model_train, train_loader, test_loader, device,
+ epochs=max(save_epochs), save_epochs=save_epochs,
+ ckpt_dir=ckpt_dir)
+ else:
+ print(f"\nAll checkpoints exist in {ckpt_dir}")
+
+ # =========================================================
+ # Step 2: For each snapshot, train estimators and test exploitability
+ # =========================================================
+
+ # Fixed batches for consistent evaluation
+ train_iter = iter(train_loader)
+ x_same, y_same = next(train_iter)
+ x_same = x_same.view(x_same.size(0), -1).to(device); y_same = y_same.to(device)
+ x_held, y_held = next(train_iter)
+ x_held = x_held.view(x_held.size(0), -1).to(device); y_held = y_held.to(device)
+
+ # Extra batches for cross-batch variance
+ extra_batches = []
+ for _ in range(8):
+ xb, yb = next(train_iter)
+ extra_batches.append((xb.view(xb.size(0), -1).to(device), yb.to(device)))
+
+ # DFA matrices (fixed across snapshots)
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ update_layers = [L - 1] # last block only
+ all_results = []
+
+ for epoch in save_epochs:
+ print(f"\n{'='*60}")
+ print(f"Snapshot: epoch {epoch}")
+ print(f"{'='*60}")
+
+ # Load snapshot
+ model = ResidualMLP(input_dim, d, 10, L).to(device)
+ ckpt_path = os.path.join(ckpt_dir, f'epoch_{epoch}.pt')
+ model.load_state_dict(torch.load(ckpt_path, map_location=device))
+ model.eval()
+ for p in model.parameters(): p.requires_grad_(False)
+ acc = evaluate_acc(model, test_loader, device)
+ print(f" Accuracy: {acc:.4f}")
+
+ loss_same_before = eval_loss(model, x_same, y_same)
+ loss_held_before = eval_loss(model, x_held, y_held)
+ print(f" Loss: same={loss_same_before:.4f}, held={loss_held_before:.4f}")
+
+ # Train Vec on this snapshot
+ print(f" Training Vec_M4...")
+ torch.manual_seed(args.seed + epoch * 100 + 4000)
+ vec_net = train_vec_on_snapshot(model, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+
+ credit_sources = {
+ 'dfa': ('dfa', None, dfa_Bs),
+ 'vec_eT_M4': ('vec', vec_net, None),
+ 'oracle_bp': ('oracle_bp', None, None),
+ }
+
+ # Eta line search for each method
+ etas = args.etas
+
+ for name, (src, est, Bs) in credit_sources.items():
+ if name not in args.methods:
+ continue
+
+ # Compute credits on same batch
+ credits_same, _ = get_credits(model, x_same, y_same, device, src,
+ estimator=est, dfa_Bs=Bs)
+
+ best_eta = None
+ best_dl_same = float('inf')
+
+ for eta in etas:
+ # 1-step test
+ model_test = copy.deepcopy(model)
+ for p in model_test.parameters(): p.requires_grad_(True)
+ apply_update(model_test, x_same, y_same, credits_same, device,
+ eta=eta, update_layers=update_layers, normalize=False)
+ for p in model_test.parameters(): p.requires_grad_(False)
+
+ dl_same = eval_loss(model_test, x_same, y_same) - loss_same_before
+ dl_held = eval_loss(model_test, x_held, y_held) - loss_held_before
+
+ if dl_same < best_dl_same:
+ best_dl_same = dl_same
+ best_eta = eta
+ best_dl_held = dl_held
+
+ # 5-step rollout at best eta
+ model_5 = copy.deepcopy(model)
+ for p in model_5.parameters(): p.requires_grad_(True)
+ train_iter2 = iter(train_loader)
+ for step in range(5):
+ try: xs, ys = next(train_iter2)
+ except StopIteration: train_iter2 = iter(train_loader); xs, ys = next(train_iter2)
+ xs = xs.view(xs.size(0), -1).to(device); ys = ys.to(device)
+ for p in model_5.parameters(): p.requires_grad_(False)
+ creds_step, _ = get_credits(model_5, xs, ys, device, src, estimator=est, dfa_Bs=Bs)
+ for p in model_5.parameters(): p.requires_grad_(True)
+ apply_update(model_5, xs, ys, creds_step, device,
+ eta=best_eta, update_layers=update_layers, normalize=False)
+ for p in model_5.parameters(): p.requires_grad_(False)
+ dl_same_5 = eval_loss(model_5, x_same, y_same) - loss_same_before
+ dl_held_5 = eval_loss(model_5, x_held, y_held) - loss_held_before
+
+ # Cross-batch update variance
+ update_vecs = []
+ for xb, yb in extra_batches[:4]:
+ # get_credits may toggle requires_grad for oracle_bp
+ for p in model.parameters(): p.requires_grad_(False)
+ creds_b, _ = get_credits(model, xb, yb, device, src, estimator=est, dfa_Bs=Bs)
+ # compute_update_vector needs requires_grad=True
+ for p in model.parameters(): p.requires_grad_(True)
+ u = compute_update_vector(model, xb, yb, creds_b, device,
+ eta=best_eta, update_layers=update_layers, normalize=False)
+ update_vecs.append(u)
+ for p in model.parameters(): p.requires_grad_(False)
+
+ # Update cosine (mean pairwise cosine)
+ cosines = []
+ for i in range(len(update_vecs)):
+ for j in range(i + 1, len(update_vecs)):
+ cos = F.cosine_similarity(update_vecs[i].unsqueeze(0),
+ update_vecs[j].unsqueeze(0)).item()
+ cosines.append(cos)
+ update_cos = float(np.mean(cosines)) if cosines else 0.0
+
+ # Update variance
+ stacked = torch.stack(update_vecs)
+ mean_u = stacked.mean(0)
+ update_var = ((stacked - mean_u) ** 2).sum(-1).mean().item()
+
+ # PUR
+ pur_1 = -best_dl_held / (-best_dl_same + 1e-12) if best_dl_same < 0 else float('nan')
+ pur_5 = -dl_held_5 / (-dl_same_5 + 1e-12) if dl_same_5 < 0 else float('nan')
+
+ result = {
+ 'snapshot_epoch': epoch, 'method': name, 'snapshot_acc': float(acc),
+ 'best_eta': best_eta,
+ 'dl_same_1': best_dl_same, 'dl_held_1': best_dl_held, 'pur_1': pur_1,
+ 'dl_same_5': dl_same_5, 'dl_held_5': dl_held_5, 'pur_5': pur_5,
+ 'update_cos': update_cos, 'update_var': update_var,
+ }
+ all_results.append(result)
+
+ print(f" {name:>12}: eta={best_eta:.0e}, dL_same_1={best_dl_same:+.6f}, "
+ f"dL_held_1={best_dl_held:+.6f}, PUR_1={pur_1:.3f}, "
+ f"dL_same_5={dl_same_5:+.6f}, dL_held_5={dl_held_5:+.6f}, PUR_5={pur_5:.3f}, "
+ f"u_cos={update_cos:.3f}, u_var={update_var:.2e}")
+
+ # =========================================================
+ # Summary
+ # =========================================================
+ print(f"\n{'='*100}")
+ print("SUMMARY")
+ print(f"{'='*100}")
+ print(f"{'Epoch':>6} {'Acc':>6} {'Method':>12} {'eta':>8} {'dL_same_1':>10} {'dL_held_1':>10} "
+ f"{'PUR_1':>7} {'dL_same_5':>10} {'dL_held_5':>10} {'PUR_5':>7} {'u_cos':>6} {'u_var':>10}")
+ print("-" * 110)
+ for r in all_results:
+ print(f"{r['snapshot_epoch']:>6} {r['snapshot_acc']:>6.3f} {r['method']:>12} {r['best_eta']:>8.0e} "
+ f"{r['dl_same_1']:>+10.6f} {r['dl_held_1']:>+10.6f} {r['pur_1']:>7.3f} "
+ f"{r['dl_same_5']:>+10.6f} {r['dl_held_5']:>+10.6f} {r['pur_5']:>7.3f} "
+ f"{r['update_cos']:>6.3f} {r['update_var']:>10.2e}")
+
+ # Save
+ out_path = os.path.join(args.output_dir, f'time_sweep_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(all_results, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ early_held_failures = 0
+ late_held_failures = 0
+ for r in all_results:
+ if r['method'] == 'vec_eT_M4':
+ if r['snapshot_epoch'] <= 20 and r['dl_held_1'] > 0:
+ early_held_failures += 1
+ if r['snapshot_epoch'] >= 50 and r['dl_held_1'] > 0:
+ late_held_failures += 1
+
+ early_epochs = [e for e in save_epochs if e <= 20]
+ late_epochs = [e for e in save_epochs if e >= 50]
+
+ if early_held_failures == 0 and late_held_failures > 0:
+ print("LATE-SNAPSHOT ARTIFACT: held-out failure only at late snapshots.")
+ print(" -> Early-training local updates with good credit DO generalize.")
+ elif early_held_failures > 0 and late_held_failures > 0:
+ print("ACROSS-TRAINING FAILURE: held-out degradation at both early and late snapshots.")
+ print(" -> Problem is NOT just late-snapshot overfitting.")
+ else:
+ print("NEED MORE DATA: check results table above.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 7A: Snapshot Time Sweep')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--snapshot_epochs', type=int, nargs='+', default=[5, 20, 100])
+ parser.add_argument('--estimator_epochs', type=int, default=60)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--etas', type=float, nargs='+',
+ default=[1e-5, 3e-5, 1e-4, 3e-4, 1e-3, 3e-3, 1e-2])
+ parser.add_argument('--methods', type=str, nargs='+',
+ default=['dfa', 'vec_eT_M4', 'oracle_bp'])
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/snapshot_time')
+ parser.add_argument('--retrain', action='store_true')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/report_explore/MEMO_7A_snapshot_time_sweep.md b/report_explore/MEMO_7A_snapshot_time_sweep.md
new file mode 100644
index 0000000..31e4cb2
--- /dev/null
+++ b/report_explore/MEMO_7A_snapshot_time_sweep.md
@@ -0,0 +1,37 @@
+# Phase 7A Memo: Snapshot Time Sweep
+
+**Date**: 2026-03-25
+
+## Question
+Is "same-batch descent + held-out ascent" a late-snapshot artifact, or does it persist across training?
+
+## Answer: Primarily a late-snapshot artifact. Early snapshots show positive held-out transfer.
+
+### 5-step DeltaLoss results (raw credit, last-block-only):
+
+| Epoch | Acc | DFA dL_held | Vec dL_held | Oracle dL_held | Vec PUR_5 |
+|-------|-----|-------------|-------------|----------------|-----------|
+| **5** | 0.49 | +0.003 | **-0.005** | **-0.009** | **0.70** |
+| 20 | 0.57 | +0.001 | +0.002 | +0.000 | -3.87 |
+| 100 | 0.62 | +0.000 | +0.001 | -0.001 | -1.01 |
+
+### Key findings:
+
+1. **At epoch 5, Vec and Oracle both decrease held-out loss**, while DFA increases it. Vec PUR=0.70 means 70% of same-batch improvement transfers to held-out. Oracle PUR=1.05 (>100% transfer).
+
+2. **By epoch 20, the generalization window closes.** All methods show near-zero or positive held-out change.
+
+3. **Better credit → lower update variance.** Vec/Oracle update variance is 50x lower than DFA (0.4-0.8 vs 40-60). Better credit produces MORE consistent cross-batch updates, not less.
+
+4. **DFA never improves held-out at any snapshot.** Its updates are random enough to sometimes decrease same-batch loss but never systematically improve held-out.
+
+## Implications
+
+The "better credit is useless" narrative from Phase 6A/6.5A was wrong on two counts:
+1. Same-batch exploitability works (Phase 6.5A)
+2. Early-snapshot held-out transfer works too (this experiment)
+
+The online training failure is because by the time the warmup phase ends and credit bridge takes over (epoch ~20), the network is already past the "generalization window" where local credit updates are useful. The fix should be: **use credit bridge from the start (no DFA warmup), or switch earlier.**
+
+## Next step recommendation
+Phase 7B (multi-batch averaging) may not be needed given that the held-out failure is a snapshot-timing issue, not a batch-variance issue. Instead, the priority should be testing online training WITH vector credit from epoch 0 (no warmup or very short warmup).