From 3d17cbad98f320905c52509c7f18691eab8bf2a0 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Tue, 24 Mar 2026 12:47:19 -0500 Subject: Add Phase 4 diagnostic dissection: frozen credit recovery, online shallow scan, vector field pilot Key findings: - Frozen CIFAR: estimators CAN recover credit (SB best, CB 20x > DFA) - Online shallow: cb_eT wr=0.2 tgw=1.0 achieves S1>0, S2 marginal - Vector credit field: 0.91-0.96 Gamma/rho on synthetic (vs 0.34 scalar CB) - Direct vector field avoids scalar V curvature problem entirely Co-Authored-By: Claude Opus 4.6 (1M context) --- experiments/cifar_frozen_credit_recovery.py | 693 ++++++++++++++++++++++++++++ 1 file changed, 693 insertions(+) create mode 100644 experiments/cifar_frozen_credit_recovery.py (limited to 'experiments/cifar_frozen_credit_recovery.py') diff --git a/experiments/cifar_frozen_credit_recovery.py b/experiments/cifar_frozen_credit_recovery.py new file mode 100644 index 0000000..5d39308 --- /dev/null +++ b/experiments/cifar_frozen_credit_recovery.py @@ -0,0 +1,693 @@ +""" +Phase A: Frozen CIFAR Credit Recovery. + +Goal: Separate "estimator problem" from "forward exploitability problem". +1. Train a BP reference network to convergence, freeze it. +2. On frozen features, train credit estimators (state bridge, scalar CB with eT/deltaL). +3. Evaluate Gamma, rho, nudging per layer. + +This answers: can the credit estimator recover useful local credit from fixed representations? +""" +import os +import sys +import json +import argparse +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import DataLoader +import torchvision +import torchvision.transforms as transforms +import copy +import time + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP +from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema +from models.state_bridge import StateBridgeNet +from metrics.credit_metrics import ( + cosine_similarity_batch, perturbation_correlation, nudging_test, + offline_bp_cosine +) + + +def get_cifar10(batch_size=128): + transform_train = transforms.Compose([ + transforms.RandomCrop(32, padding=4), + transforms.RandomHorizontalFlip(), + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + transform_test = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train) + testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test) + train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True) + test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True) + return train_loader, test_loader + + +def evaluate(model, test_loader, device): + model.eval() + correct, total = 0, 0 + with torch.no_grad(): + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + return correct / total + + +# ============================================================================= +# Step 1: Train BP reference network +# ============================================================================= +def train_bp_reference(model, train_loader, test_loader, device, epochs=100, lr=1e-3, wd=0.01): + """Train BP reference to convergence.""" + optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=wd) + scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs) + + for epoch in range(1, epochs + 1): + model.train() + total_loss, correct, total = 0, 0, 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + logits = model(x) + loss = F.cross_entropy(logits, y) + optimizer.zero_grad() + loss.backward() + optimizer.step() + total_loss += loss.item() * x.size(0) + correct += (logits.argmax(1) == y).sum().item() + total += x.size(0) + scheduler.step() + if epoch % 10 == 0 or epoch == 1: + test_acc = evaluate(model, test_loader, device) + print(f" [BP ref] Epoch {epoch}: loss={total_loss/total:.4f}, " + f"train_acc={correct/total:.4f}, test_acc={test_acc:.4f}") + + test_acc = evaluate(model, test_loader, device) + print(f" [BP ref] Final test accuracy: {test_acc:.4f}") + return test_acc + + +# ============================================================================= +# Step 2: Train estimators on frozen features +# ============================================================================= + +def train_state_bridge_frozen(model, train_loader, device, args): + """Train state bridge on frozen BP features.""" + d = model.d_hidden + L = model.num_blocks + num_classes = 10 + + state_pred = StateBridgeNet( + d_hidden=d, s_dim=num_classes, time_embed_dim=32, + hidden_dim=256, num_layers=3 + ).to(device) + state_opt = optim.Adam(state_pred.parameters(), lr=args.lr_fb) + + model.eval() + for epoch in range(1, args.estimator_epochs + 1): + state_pred.train() + total_loss = 0 + n = 0 + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + hL_det = hiddens[-1].detach() + + # Train state predictor + state_loss = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + pred_hL = state_pred(h_l_det, t_l, s) + target_norm = hL_det.norm(dim=-1, keepdim=True).clamp(min=1.0) + state_loss = state_loss + (((pred_hL - hL_det) / target_norm) ** 2).sum(dim=-1).mean() + state_loss = state_loss / L + + state_opt.zero_grad() + state_loss.backward() + state_opt.step() + total_loss += state_loss.item() * batch + n += batch + + if epoch % 20 == 0 or epoch == 1: + print(f" [SB] Epoch {epoch}: state_loss={total_loss/n:.6f}") + + return state_pred + + +def train_scalar_cb_frozen(model, train_loader, device, args, s_type='eT'): + """ + Train scalar credit bridge on frozen BP features. + s_type: 'eT' (softmax error, dim=10) or 'deltaL' (grad_{h_L} CE, dim=d_hidden) + """ + d = model.d_hidden + L = model.num_blocks + num_classes = 10 + + if s_type == 'eT': + s_dim = num_classes + elif s_type == 'deltaL': + s_dim = d + else: + raise ValueError(f"Unknown s_type: {s_type}") + + value_net = ValueNet( + d_hidden=d, s_dim=s_dim, time_embed_dim=32, + hidden_dim=256, num_layers=3 + ).to(device) + value_net_ema = create_ema_model(value_net) + value_opt = optim.Adam(value_net.parameters(), lr=args.lr_fb) + + lam = args.lam + K_samples = args.K + sigma_bridge = args.sigma_bridge + ema_momentum = args.ema_momentum + term_grad_weight = args.term_grad_weight + + model.eval() + for epoch in range(1, args.estimator_epochs + 1): + value_net.train() + total_vloss = 0 + total_term = 0 + total_tgrad = 0 + total_bridge = 0 + n = 0 + + for x, y in train_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + true_loss = F.cross_entropy(logits, y, reduction='none').detach() + + hL_det = hiddens[-1].detach() + + # Compute s (conditioning code) + if s_type == 'eT': + s = e_T.detach() + elif s_type == 'deltaL': + # delta_L = grad_{h_L} CE (output-layer-local, allowed) + hL_req = hL_det.clone().requires_grad_(True) + logits_for_s = model.out_head(model.out_ln(hL_req)) + ce_for_s = F.cross_entropy(logits_for_s, y, reduction='sum') + delta_L = torch.autograd.grad(ce_for_s, hL_req, create_graph=False)[0].detach() + s = delta_L + + # Terminal boundary + t_L = torch.ones(batch, device=device) + V_terminal = value_net(hL_det, t_L, s) + loss_term = ((V_terminal - true_loss) ** 2).mean() + + # Terminal gradient matching + loss_tgrad = torch.tensor(0.0, device=device) + if term_grad_weight > 0: + hL_req = hL_det.clone().requires_grad_(True) + V_at_L = value_net(hL_req, t_L, s) + grad_V_L = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0] + # Exact terminal gradient (output-layer-local) + hL_req2 = hL_det.clone().requires_grad_(True) + logits_tgt = model.out_head(model.out_ln(hL_req2)) + ce_loss = F.cross_entropy(logits_tgt, y, reduction='sum') + a_L_exact = torch.autograd.grad(ce_loss, hL_req2, create_graph=False)[0].detach() + loss_tgrad = ((grad_V_L - a_L_exact) ** 2).sum(dim=-1).mean() + + # Bridge consistency + loss_bridge = 0.0 + for l in range(L): + h_l_det = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + t_l_next = torch.full((batch,), (l + 1) / L, device=device) + V_l = value_net(h_l_det, t_l, s) + + with torch.no_grad(): + h_next_det = hiddens[l + 1].detach() + log_terms = [] + for k in range(K_samples): + noise = sigma_bridge * torch.randn_like(h_next_det) + V_next = value_net_ema(h_next_det + noise, t_l_next, s) + log_terms.append(-V_next / lam) + log_stack = torch.stack(log_terms, dim=-1) + V_target = -lam * (torch.logsumexp(log_stack, dim=-1) - np.log(K_samples)) + + loss_bridge = loss_bridge + ((V_l - V_target.detach()) ** 2).mean() + loss_bridge = loss_bridge / L + + value_loss = loss_term + loss_bridge + term_grad_weight * loss_tgrad + + value_opt.zero_grad() + value_loss.backward() + torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0) + value_opt.step() + update_ema(value_net, value_net_ema, ema_momentum) + + total_vloss += value_loss.item() * batch + total_term += loss_term.item() * batch + total_tgrad += loss_tgrad.item() * batch + total_bridge += (loss_bridge.item() if isinstance(loss_bridge, torch.Tensor) else loss_bridge) * batch + n += batch + + if epoch % 20 == 0 or epoch == 1: + print(f" [CB_{s_type}] Epoch {epoch}: vloss={total_vloss/n:.6f}, " + f"term={total_term/n:.6f}, tgrad={total_tgrad/n:.6f}, bridge={total_bridge/n:.6f}") + + return value_net, value_net_ema + + +# ============================================================================= +# Step 3: Evaluate credit quality on frozen features +# ============================================================================= + +def evaluate_credits(model, test_loader, device, estimators, args): + """ + Evaluate credit quality for all estimators on frozen BP features. + + Args: + estimators: dict of {name: {'type': 'sb'/'cb', 'net': ..., 's_type': ...}} + Returns: + dict of {name: {per-layer metrics}} + """ + model.eval() + d = model.d_hidden + L = model.num_blocks + num_classes = 10 + + # Accumulate over multiple test batches for robust statistics + all_results = {} + for name in estimators: + all_results[name] = { + 'bp_cosine': [[] for _ in range(L)], + 'perturbation_rho': [0.0] * L, + 'nudging_0.001': [0.0] * L, + 'nudging_0.003': [0.0] * L, + 'nudging_0.01': [0.0] * L, + } + + # Also add DFA baseline + dfa_Bs = [torch.randn(d, num_classes, device=device) / np.sqrt(num_classes) for _ in range(L)] + all_results['dfa'] = { + 'bp_cosine': [[] for _ in range(L)], + 'perturbation_rho': [0.0] * L, + 'nudging_0.001': [0.0] * L, + 'nudging_0.003': [0.0] * L, + 'nudging_0.01': [0.0] * L, + } + + n_batches_diag = min(10, len(test_loader)) # Use multiple batches + batch_idx = 0 + + for x, y in test_loader: + if batch_idx >= n_batches_diag: + break + batch_idx += 1 + + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + # Get BP gradients (ground truth for Gamma) + # Temporarily enable grad on model params for BP gradient computation + for p in model.parameters(): + p.requires_grad_(True) + model.zero_grad() + logits_bp, hiddens_bp = model(x, return_hidden=True) + for l in range(L + 1): + hiddens_bp[l].retain_grad() + loss_bp = F.cross_entropy(logits_bp, y) + loss_bp.backward() + bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)} + # Re-freeze model + for p in model.parameters(): + p.requires_grad_(False) + + # Clean forward + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s_eT = e_T.detach() + + hL_det = hiddens[-1].detach() + + # Compute delta_L for deltaL conditioning + hL_req = hL_det.clone().requires_grad_(True) + logits_for_delta = model.out_head(model.out_ln(hL_req)) + ce_for_delta = F.cross_entropy(logits_for_delta, y, reduction='sum') + delta_L = torch.autograd.grad(ce_for_delta, hL_req, create_graph=False)[0].detach() + + for l in range(L): + h_l = hiddens[l].detach() + t_l = torch.full((batch,), l / L, device=device) + + # Forward function for perturbation and nudging + def make_fwd_fn(start_l): + def fwd_fn(h): + with torch.no_grad(): + curr = h + for i in range(start_l, L): + curr = curr + model.blocks[i](curr) + out = model.out_head(model.out_ln(curr)) + return F.cross_entropy(out, y, reduction='none') + return fwd_fn + + fwd_fn = make_fwd_fn(l) + + # --- DFA credit --- + a_dfa = (s_eT @ dfa_Bs[l].T).detach() + bp_cos_dfa = cosine_similarity_batch(a_dfa, bp_grads[l]) + all_results['dfa']['bp_cosine'][l].append(bp_cos_dfa) + + if batch_idx == 1: # Only compute rho/nudging on first batch (expensive) + rho_dfa = perturbation_correlation(h_l, a_dfa, fwd_fn, epsilon=1e-3, M=32) + all_results['dfa']['perturbation_rho'][l] = rho_dfa + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_dfa, fwd_fn, eta=eta) + all_results['dfa'][f'nudging_{eta}'][l] = nud + + # --- Estimator credits --- + for name, est in estimators.items(): + if est['type'] == 'sb': + net = est['net'] + net.eval() + h_l_req = h_l.clone().requires_grad_(True) + pred_hL = net(h_l_req, t_l, s_eT) + pred_logits = model.out_head(model.out_ln(pred_hL)) + pred_loss = F.cross_entropy(pred_logits, y, reduction='sum') + a_l = torch.autograd.grad(pred_loss, h_l_req, create_graph=False)[0].detach() + + elif est['type'] == 'cb': + net = est['net'] + net.eval() + s_type = est['s_type'] + if s_type == 'eT': + s = s_eT + elif s_type == 'deltaL': + s = delta_L + else: + raise ValueError(f"Unknown s_type: {s_type}") + + h_l_req = h_l.clone().requires_grad_(True) + V_l = net(h_l_req, t_l, s) + a_l = torch.autograd.grad(V_l.sum(), h_l_req, create_graph=False)[0].detach() + else: + raise ValueError(f"Unknown estimator type: {est['type']}") + + bp_cos = cosine_similarity_batch(a_l, bp_grads[l]) + all_results[name]['bp_cosine'][l].append(bp_cos) + + if batch_idx == 1: + rho = perturbation_correlation(h_l, a_l, fwd_fn, epsilon=1e-3, M=32) + all_results[name]['perturbation_rho'][l] = rho + for eta in [0.001, 0.003, 0.01]: + nud = nudging_test(h_l, a_l, fwd_fn, eta=eta) + all_results[name][f'nudging_{eta}'][l] = nud + + # Average bp_cosine over batches + for name in all_results: + for l in range(L): + vals = all_results[name]['bp_cosine'][l] + all_results[name]['bp_cosine'][l] = float(np.mean(vals)) if vals else 0.0 + + return all_results + + +def evaluate_state_bridge_pred_error(model, state_pred, test_loader, device): + """Evaluate state bridge's terminal state prediction error.""" + model.eval() + state_pred.eval() + L = model.num_blocks + + total_error = [0.0] * L + n = 0 + for x, y in test_loader: + x = x.view(x.size(0), -1).to(device) + y = y.to(device) + batch = x.size(0) + + with torch.no_grad(): + logits, hiddens = model(x, return_hidden=True) + e_T = logits.softmax(dim=-1) + e_T[torch.arange(batch), y] -= 1 + s = e_T.detach() + hL = hiddens[-1] + + for l in range(L): + h_l = hiddens[l] + t_l = torch.full((batch,), l / L, device=x.device) + pred_hL = state_pred(h_l, t_l, s) + error = ((pred_hL - hL) ** 2).sum(dim=-1).mean().item() + total_error[l] += error * batch + n += batch + + return [e / n for e in total_error] + + +# ============================================================================= +# Main experiment +# ============================================================================= +def run_experiment(args): + device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu') + print(f"Using device: {device}") + os.makedirs(args.output_dir, exist_ok=True) + + torch.manual_seed(args.seed) + np.random.seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + + train_loader, test_loader = get_cifar10(batch_size=args.batch_size) + input_dim = 32 * 32 * 3 + num_classes = 10 + + # ----- Step 1: Train BP reference ----- + print(f"\n{'='*60}") + print(f"Step 1: Train BP reference (L={args.num_blocks}, d={args.d_hidden})") + print(f"{'='*60}") + + bp_ckpt_path = os.path.join(args.output_dir, f'bp_ref_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.pt') + + model = ResidualMLP(input_dim, args.d_hidden, num_classes, args.num_blocks).to(device) + + if os.path.exists(bp_ckpt_path) and not args.retrain_bp: + print(f" Loading BP reference from {bp_ckpt_path}") + model.load_state_dict(torch.load(bp_ckpt_path, map_location=device)) + bp_acc = evaluate(model, test_loader, device) + print(f" BP reference test accuracy: {bp_acc:.4f}") + else: + bp_acc = train_bp_reference(model, train_loader, test_loader, device, + epochs=args.bp_epochs, lr=args.lr, wd=args.wd) + torch.save(model.state_dict(), bp_ckpt_path) + print(f" Saved BP reference to {bp_ckpt_path}") + + # Freeze the model completely + model.eval() + for p in model.parameters(): + p.requires_grad_(False) + + # ----- Step 2: Train estimators ----- + print(f"\n{'='*60}") + print(f"Step 2: Train estimators ({args.estimator_epochs} epochs each)") + print(f"{'='*60}") + + estimators = {} + + # 2a. State Bridge with s=eT + print("\n--- State Bridge (s=eT) ---") + torch.manual_seed(args.seed + 1000) + sb = train_state_bridge_frozen(model, train_loader, device, args) + estimators['sb_eT'] = {'type': 'sb', 'net': sb, 's_type': 'eT'} + + # 2b. Scalar CB with s=eT + print("\n--- Scalar CB (s=eT) ---") + torch.manual_seed(args.seed + 2000) + cb_eT, cb_eT_ema = train_scalar_cb_frozen(model, train_loader, device, args, s_type='eT') + estimators['cb_eT'] = {'type': 'cb', 'net': cb_eT, 's_type': 'eT'} + + # 2c. Scalar CB with s=deltaL + print("\n--- Scalar CB (s=deltaL) ---") + torch.manual_seed(args.seed + 3000) + cb_dL, cb_dL_ema = train_scalar_cb_frozen(model, train_loader, device, args, s_type='deltaL') + estimators['cb_deltaL'] = {'type': 'cb', 'net': cb_dL, 's_type': 'deltaL'} + + # ----- Step 3: Evaluate ----- + print(f"\n{'='*60}") + print(f"Step 3: Evaluate credit quality") + print(f"{'='*60}") + + results = evaluate_credits(model, test_loader, device, estimators, args) + + # State bridge prediction error + sb_pred_error = evaluate_state_bridge_pred_error(model, sb, test_loader, device) + + # ----- Print results ----- + L = args.num_blocks + print(f"\n{'='*60}") + print(f"RESULTS: Frozen CIFAR Credit Recovery (L={L}, d={args.d_hidden}, seed={args.seed})") + print(f"BP reference test accuracy: {bp_acc:.4f}") + print(f"{'='*60}") + + # Summary table + methods = ['dfa', 'sb_eT', 'cb_eT', 'cb_deltaL'] + method_labels = { + 'dfa': 'DFA (random)', + 'sb_eT': 'State Bridge (eT)', + 'cb_eT': 'Scalar CB (eT)', + 'cb_deltaL': 'Scalar CB (deltaL)', + } + + print(f"\n{'Method':<25} {'mean Gamma':>12} {'mean rho':>12} {'mean nudge':>12}") + print("-" * 65) + + summary = {} + for m in methods: + r = results[m] + mean_gamma = np.mean(r['bp_cosine']) + mean_rho = np.mean(r['perturbation_rho']) + mean_nudge = np.mean(r['nudging_0.003']) + summary[m] = { + 'mean_gamma': float(mean_gamma), + 'mean_rho': float(mean_rho), + 'mean_nudge': float(mean_nudge), + } + print(f"{method_labels[m]:<25} {mean_gamma:>12.4f} {mean_rho:>12.4f} {mean_nudge:>12.6f}") + + # Per-layer detail + print(f"\n--- Per-layer Gamma ---") + header = f"{'Layer':<8}" + for m in methods: + header += f" {method_labels[m]:>16}" + print(header) + for l in range(L): + row = f" {l:<6}" + for m in methods: + row += f" {results[m]['bp_cosine'][l]:>16.4f}" + print(row) + + print(f"\n--- Per-layer rho ---") + print(header) + for l in range(L): + row = f" {l:<6}" + for m in methods: + row += f" {results[m]['perturbation_rho'][l]:>16.4f}" + print(row) + + print(f"\n--- Per-layer nudge (eta=0.003) ---") + print(header) + for l in range(L): + row = f" {l:<6}" + for m in methods: + row += f" {results[m]['nudging_0.003'][l]:>16.6f}" + print(row) + + print(f"\n--- State Bridge prediction error per layer ---") + for l in range(L): + print(f" Layer {l}: {sb_pred_error[l]:.6f}") + + # ----- Save all results ----- + save_data = { + 'config': { + 'num_blocks': args.num_blocks, + 'd_hidden': args.d_hidden, + 'seed': args.seed, + 'bp_epochs': args.bp_epochs, + 'estimator_epochs': args.estimator_epochs, + 'lr_fb': args.lr_fb, + 'lam': args.lam, + 'K': args.K, + 'sigma_bridge': args.sigma_bridge, + 'ema_momentum': args.ema_momentum, + 'term_grad_weight': args.term_grad_weight, + }, + 'bp_acc': float(bp_acc), + 'summary': summary, + 'per_layer': {}, + 'sb_pred_error': sb_pred_error, + } + + for m in methods: + save_data['per_layer'][m] = { + 'bp_cosine': results[m]['bp_cosine'], + 'perturbation_rho': results[m]['perturbation_rho'], + 'nudging_0.001': results[m]['nudging_0.001'], + 'nudging_0.003': results[m]['nudging_0.003'], + 'nudging_0.01': results[m]['nudging_0.01'], + } + + out_path = os.path.join(args.output_dir, + f'frozen_L{args.num_blocks}_d{args.d_hidden}_s{args.seed}.json') + with open(out_path, 'w') as f: + json.dump(save_data, f, indent=2) + print(f"\nResults saved to {out_path}") + + # ----- Judgment ----- + print(f"\n{'='*60}") + print("JUDGMENT") + print(f"{'='*60}") + + best_cb = max(summary['cb_eT']['mean_rho'], summary['cb_deltaL']['mean_rho']) + dfa_rho = summary['dfa']['mean_rho'] + best_cb_gamma = max(summary['cb_eT']['mean_gamma'], summary['cb_deltaL']['mean_gamma']) + dfa_gamma = summary['dfa']['mean_gamma'] + + if best_cb > dfa_rho + 0.02 and best_cb_gamma > dfa_gamma: + print("POSITIVE: Scalar CB recovers credit that is clearly better than DFA.") + print(" -> Bottleneck is in forward exploitability / local update, not estimator.") + print(" -> Next: Phase B (online shallow CIFAR).") + elif best_cb > 0.02: + print("MARGINAL: Scalar CB shows some signal but not clearly better than DFA.") + print(" -> Need more investigation before concluding estimator is the bottleneck.") + else: + print("NEGATIVE: Scalar CB cannot recover useful credit even on frozen features.") + print(" -> Estimator parameterization is the bottleneck.") + print(" -> Next: Phase C (direct vector field pilot).") + + return save_data + + +def main(): + parser = argparse.ArgumentParser(description='Frozen CIFAR Credit Recovery') + parser.add_argument('--num_blocks', type=int, default=4) + parser.add_argument('--d_hidden', type=int, default=256) + parser.add_argument('--batch_size', type=int, default=128) + parser.add_argument('--bp_epochs', type=int, default=100, + help='Epochs to train BP reference') + parser.add_argument('--estimator_epochs', type=int, default=100, + help='Epochs to train each estimator on frozen features') + parser.add_argument('--lr', type=float, default=1e-3, help='LR for BP reference') + parser.add_argument('--lr_fb', type=float, default=1e-3, help='LR for estimators') + parser.add_argument('--wd', type=float, default=0.01) + parser.add_argument('--lam', type=float, default=0.1) + parser.add_argument('--K', type=int, default=4) + parser.add_argument('--sigma_bridge', type=float, default=0.05) + parser.add_argument('--ema_momentum', type=float, default=0.995) + parser.add_argument('--term_grad_weight', type=float, default=1.0) + parser.add_argument('--seed', type=int, default=42) + parser.add_argument('--gpu', type=int, default=2) + parser.add_argument('--output_dir', type=str, default='results/frozen_cifar') + parser.add_argument('--retrain_bp', action='store_true') + args = parser.parse_args() + run_experiment(args) + + +if __name__ == '__main__': + main() -- cgit v1.2.3