summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--NOTE.md46
-rw-r--r--experiments/local_update_swap.py427
-rw-r--r--experiments/snapshot_exploitability.py713
-rw-r--r--report_explore/MEMO_6A_snapshot_exploitability.md39
-rw-r--r--report_explore/MEMO_6_exploitability.md53
5 files changed, 1277 insertions, 1 deletions
diff --git a/NOTE.md b/NOTE.md
index 36e0be9..6242a41 100644
--- a/NOTE.md
+++ b/NOTE.md
@@ -5,7 +5,7 @@
- **pilot**: Controlled iteration (commits 0b9ebb2, 7baf7ae)
- **frozen**: Code at commit 0b9ebb2 for all reported results
-## Status: PHASE 5 VECTOR FIELD AUDIT + TRANSFER COMPLETE
+## Status: PHASE 6 EXPLOITABILITY DISSECTION COMPLETE
---
@@ -329,3 +329,47 @@ the signal.
- `vector_audit_full/`: Phase 5A full 3-seed audit
- `frozen_cifar_vec/`: Phase 5B frozen CIFAR vector transfer
- `online_vec_pilot/`: Phase 5C online CIFAR vector pilot
+
+---
+
+## Phase 6: Exploitability Dissection
+
+### Phase 6A: Snapshot Exploitability
+
+**Setup**: BP-trained CIFAR snapshot (L=4, d=256, 61.9% acc).
+Offline-trained estimators. k-step local updates with real loss measurement.
+
+**CRITICAL FINDING: Better credit → worse loss decrease.**
+
+| Credit | Gamma | rho | dL_5step (inner_product) |
+|--------|-------|-----|-------------------------|
+| DFA | 0.009 | -0.023 | **-0.0001** (only negative!) |
+| ScalarCB | 0.122 | 0.090 | +0.042 |
+| Vec_M4 | 0.378 | 0.411 | +0.057 |
+| Oracle BP | 1.000 | 0.998 | +0.011 |
+
+Credit quality is ANTI-CORRELATED with loss decrease.
+DFA (worst credit) is the only method not increasing loss.
+
+### Phase 6C: Local Update Rule Swap
+
+Tested target-shift (`h_target = h_{l+1} - eta * a_norm`) at eta in {0.01, 0.1, 0.3, 1.0}.
+
+Target-shift reduces damage (Vec dL: +0.057 → +0.002 at eta=0.1) but never achieves
+negative DeltaLoss for any non-DFA credit. Cosine rule produces near-zero effects.
+
+### Root Cause
+
+The inner-product surrogate `<F_l(h), a>` is not a valid proxy for global loss minimization.
+The gradient of this surrogate w.r.t. block parameters ≠ gradient of global loss w.r.t. same parameters.
+A BP-trained snapshot is at a minimum reachable only by full BP; local updates systematically push uphill.
+
+DFA works because its credits are weak enough to produce near-zero updates, effectively doing nothing.
+
+### This is Case B from the diagnostic logic tree:
+Better credit does NOT lead to better snapshot loss decrease.
+**The primary bottleneck is the local update rule itself, not the estimator or tracking.**
+
+### Experiment IDs (Phase 6)
+- `snapshot_exploit/`: Phase 6A snapshot exploitability
+- `update_swap/`: Phase 6C local update rule comparison
diff --git a/experiments/local_update_swap.py b/experiments/local_update_swap.py
new file mode 100644
index 0000000..207560a
--- /dev/null
+++ b/experiments/local_update_swap.py
@@ -0,0 +1,427 @@
+"""
+Phase 6C: Local Update Rule Swap.
+
+Compare different local update rules using the same credit signals on a fixed snapshot.
+
+Rule 1 (baseline): Inner-product surrogate
+ L_inner = <F_l(h_l), a_{l+1}>
+
+Rule 2: Target-shift local regression
+ h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm
+ L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1}^target) ||^2
+
+Rule 3: Cosine-target update
+ L_cos = - cos(F_l(h_l), a_{l+1})
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import cosine_similarity_batch, perturbation_correlation, nudging_test
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
+ L = model.num_blocks
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ credits = {}
+ if credit_source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+ elif credit_source == 'scalar_cb':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V = estimator(h_l, t_l, s)
+ credits[l] = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0].detach()
+ elif credit_source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+ elif credit_source == 'oracle_bp':
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ for l in range(L):
+ credits[l] = hiddens_bp[l].grad.detach().clone()
+ for p in model.parameters():
+ p.requires_grad_(False)
+ return credits, hiddens, s
+
+
+# =============================================================================
+# Local update rules
+# =============================================================================
+def update_inner_product(model, x, y, credits, hiddens, device, lr):
+ """Rule 1: L_inner = <F_l(h_l), a_{l+1}>"""
+ L = model.num_blocks
+ # Head
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+ # Blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+ # Embed
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+def update_target_shift(model, x, y, credits, hiddens, device, lr, eta_target=0.01):
+ """
+ Rule 2: Target-shift local regression.
+ h_{l+1}^target = h_{l+1} - eta_target * a_{l+1}^norm
+ L_shift = 0.5 * || (h_l + F_l(h_l)) - sg(h_{l+1}^target) ||^2
+ """
+ L = model.num_blocks
+ # Head — still use exact CE
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+
+ # Blocks: target-shift regression
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ h_l_next = hiddens[l + 1].detach() # current h_{l+1}
+
+ # Credit at layer l+1 (or l for the last one)
+ # We use credit[l] which is the credit at layer l
+ # The target shift: move h_{l+1} in the negative credit direction
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+
+ # Target: where h_{l+1} should move toward
+ h_target = (h_l_next - eta_target * a_norm).detach()
+
+ # Compute F_l(h_l) with gradient
+ f_l = model.blocks[l](h_l)
+ h_l_next_pred = h_l + f_l # predicted h_{l+1}
+
+ # Regression loss
+ shift_loss = 0.5 * ((h_l_next_pred - h_target) ** 2).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(shift_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+ # Embed: use credit[0] as target shift for h_0
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ h0_target = (hiddens[0].detach() - eta_target * (a_0 / rms_0)).detach()
+ embed_loss = 0.5 * ((h0 - h0_target) ** 2).sum(dim=-1).mean()
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+def update_cosine_target(model, x, y, credits, hiddens, device, lr):
+ """Rule 3: L_cos = -cos(F_l(h_l), a_{l+1})"""
+ L = model.num_blocks
+ # Head
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+ # Blocks
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ f_l = model.blocks[l](h_l)
+ cos_sim = F.cosine_similarity(f_l, a, dim=-1).mean()
+ local_loss = -cos_sim
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+ # Embed
+ a_0 = credits[0]
+ h0 = model.embed(x)
+ cos_sim_0 = F.cosine_similarity(h0, a_0, dim=-1).mean()
+ embed_loss = -cos_sim_0
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+# =============================================================================
+# Main
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # Load BP snapshot
+ model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
+ bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
+ model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
+ model_bp.eval()
+ for p in model_bp.parameters():
+ p.requires_grad_(False)
+ print(f"Loaded BP snapshot from {bp_ckpt}")
+
+ # Load pre-trained estimators (or train fresh)
+ # DFA
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ # Scalar CB — train on snapshot
+ print("\nTraining ScalarCB on snapshot...")
+ from experiments.snapshot_exploitability import train_scalar_cb_on_snapshot, train_vector_on_snapshot
+ torch.manual_seed(args.seed + 2000)
+ cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb)
+
+ # Vector field — train on snapshot
+ print("\nTraining Vec_M4 on snapshot...")
+ torch.manual_seed(args.seed + 4000)
+ vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+
+ credit_sources = {
+ 'dfa': ('dfa', None, dfa_Bs),
+ 'scalar_cb': ('scalar_cb', cb, None),
+ 'vec_eT_M4': ('vec', vec4, None),
+ 'oracle_bp': ('oracle_bp', None, None),
+ }
+
+ update_rules = {
+ 'inner_product': update_inner_product,
+ 'target_shift': lambda m, x, y, c, h, dev, lr: update_target_shift(m, x, y, c, h, dev, lr, eta_target=args.eta_target),
+ 'cosine_target': update_cosine_target,
+ }
+
+ # Eval function
+ eval_batches = []
+ for i, (xv, yv) in enumerate(test_loader):
+ if i >= 10:
+ break
+ eval_batches.append((xv.view(xv.size(0), -1).to(device), yv.to(device)))
+
+ def eval_model(model):
+ model.eval()
+ total_loss, correct, total = 0, 0, 0
+ with torch.no_grad():
+ for xv, yv in eval_batches:
+ logits = model(xv)
+ total_loss += F.cross_entropy(logits, yv, reduction='sum').item()
+ correct += (logits.argmax(1) == yv).sum().item()
+ total += xv.size(0)
+ return total_loss / total, correct / total
+
+ # =========================================================
+ # Run all combinations: credit_source x update_rule x k_steps
+ # =========================================================
+ results = {}
+
+ for cs_name, (src, est, Bs) in credit_sources.items():
+ for rule_name, rule_fn in update_rules.items():
+ for k in [1, 5, 20]:
+ tag = f"{cs_name}_{rule_name}_k{k}"
+
+ model_test = copy.deepcopy(model_bp)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ loss_before, acc_before = eval_model(model_test)
+
+ train_iter = iter(train_loader)
+ for step in range(k):
+ try:
+ x_step, y_step = next(train_iter)
+ except StopIteration:
+ train_iter = iter(train_loader)
+ x_step, y_step = next(train_iter)
+ x_step = x_step.view(x_step.size(0), -1).to(device)
+ y_step = y_step.to(device)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ credits, hiddens, s = get_credits(model_test, x_step, y_step, device,
+ src, estimator=est, dfa_Bs=Bs)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+ rule_fn(model_test, x_step, y_step, credits, hiddens, device, lr=args.lr_update)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ loss_after, acc_after = eval_model(model_test)
+
+ results[tag] = {
+ 'credit': cs_name, 'rule': rule_name, 'k': k,
+ 'loss_before': loss_before, 'loss_after': loss_after,
+ 'delta_loss': loss_after - loss_before,
+ 'delta_acc': acc_after - acc_before,
+ }
+
+ # =========================================================
+ # Summary tables
+ # =========================================================
+ print(f"\n{'='*90}")
+ print("RESULTS: DeltaLoss (negative = good)")
+ print(f"{'='*90}")
+
+ for k in [1, 5, 20]:
+ print(f"\n--- k={k} steps ---")
+ print(f"{'Credit':<15} {'inner_prod':>12} {'target_shift':>14} {'cosine':>12}")
+ print("-" * 58)
+ for cs_name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']:
+ row = f"{cs_name:<15}"
+ for rule_name in ['inner_product', 'target_shift', 'cosine_target']:
+ tag = f"{cs_name}_{rule_name}_k{k}"
+ dl = results[tag]['delta_loss']
+ row += f" {dl:>+12.4f}"
+ print(row)
+
+ # Save
+ out_path = os.path.join(args.output_dir, f'update_swap_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(results, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+
+ # Compare at k=5
+ inner_vec = results['vec_eT_M4_inner_product_k5']['delta_loss']
+ shift_vec = results['vec_eT_M4_target_shift_k5']['delta_loss']
+ shift_bp = results['oracle_bp_target_shift_k5']['delta_loss']
+ inner_dfa = results['dfa_inner_product_k5']['delta_loss']
+
+ print(f"k=5: Vec+inner={inner_vec:+.4f}, Vec+shift={shift_vec:+.4f}, "
+ f"BP+shift={shift_bp:+.4f}, DFA+inner={inner_dfa:+.4f}")
+
+ if shift_vec < inner_vec and shift_vec < 0:
+ print("TARGET-SHIFT WINS: Vec credit becomes exploitable with target-shift rule.")
+ print(" -> Project should pivot to 'credit + better local update coupling'.")
+ elif shift_bp < 0 and shift_vec >= 0:
+ print("TARGET-SHIFT HELPS BP BUT NOT VEC: Credit quality still matters.")
+ else:
+ print("TARGET-SHIFT DOESN'T HELP: Need further investigation.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 6C: Local Update Rule Swap')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--estimator_epochs', type=int, default=100)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lr_update', type=float, default=1e-3)
+ parser.add_argument('--eta_target', type=float, default=0.01)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/update_swap')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/experiments/snapshot_exploitability.py b/experiments/snapshot_exploitability.py
new file mode 100644
index 0000000..ce07acc
--- /dev/null
+++ b/experiments/snapshot_exploitability.py
@@ -0,0 +1,713 @@
+"""
+Phase 6A: Snapshot Exploitability Test.
+
+Core question: on a fixed network snapshot, does better credit lead to
+better real loss decrease via the current local surrogate update?
+
+Tests:
+A1. Single global step DeltaLoss
+A2. k-step short rollout (k=1,5,20)
+A3. Per-layer ablation (last-1, last-2, all blocks)
+
+Credit sources: DFA, ScalarCB_eT, Vec_eT_M4, Oracle BP gradient
+Snapshot sources: BP-trained, DFA-warmup
+"""
+import os
+import sys
+import json
+import argparse
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader
+import torchvision
+import torchvision.transforms as transforms
+import copy
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+from models.value_net import ValueNet, SinusoidalTimeEmbed, create_ema_model, update_ema
+from metrics.credit_metrics import (
+ cosine_similarity_batch, perturbation_correlation, nudging_test
+)
+
+
+class VectorCreditNet(nn.Module):
+ def __init__(self, d_hidden, s_dim, time_embed_dim=32, hidden_dim=256, num_layers=3):
+ super().__init__()
+ self.ln = nn.LayerNorm(d_hidden)
+ self.time_embed = SinusoidalTimeEmbed(time_embed_dim)
+ input_dim = d_hidden + time_embed_dim + s_dim
+ layers = []
+ for i in range(num_layers):
+ in_d = input_dim if i == 0 else hidden_dim
+ layers.append(nn.Linear(in_d, hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, d_hidden))
+ self.net = nn.Sequential(*layers)
+
+ def forward(self, h, t, s):
+ h_normed = self.ln(h)
+ t_emb = self.time_embed(t)
+ inp = torch.cat([h_normed, t_emb, s], dim=-1)
+ return self.net(inp)
+
+
+def get_cifar10(batch_size=128):
+ transform_train = transforms.Compose([
+ transforms.RandomCrop(32, padding=4),
+ transforms.RandomHorizontalFlip(),
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ transform_test = transforms.Compose([
+ transforms.ToTensor(),
+ transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
+ ])
+ trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
+ testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform_test)
+ train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=4, pin_memory=True)
+ test_loader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=4, pin_memory=True)
+ return train_loader, test_loader
+
+
+def evaluate_loss_acc(model, data_iter, device, n_batches=5):
+ """Evaluate loss and accuracy on a few batches."""
+ model.eval()
+ total_loss, correct, total = 0, 0, 0
+ with torch.no_grad():
+ for i, (x, y) in enumerate(data_iter):
+ if i >= n_batches:
+ break
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ logits = model(x)
+ loss = F.cross_entropy(logits, y, reduction='sum')
+ total_loss += loss.item()
+ correct += (logits.argmax(1) == y).sum().item()
+ total += x.size(0)
+ return total_loss / total, correct / total
+
+
+# =============================================================================
+# Train estimators on frozen snapshot
+# =============================================================================
+def train_scalar_cb_on_snapshot(model, train_loader, device, epochs=100, lr_fb=1e-3):
+ d = model.d_hidden
+ L = model.num_blocks
+ value_net = ValueNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ value_net_ema = create_ema_model(value_net)
+ value_opt = optim.Adam(value_net.parameters(), lr=lr_fb)
+ model.eval()
+ for epoch in range(1, epochs + 1):
+ value_net.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ true_loss = F.cross_entropy(logits, y, reduction='none').detach()
+ hL = hiddens[-1].detach()
+ t_L = torch.ones(batch, device=device)
+ V_term = value_net(hL, t_L, s)
+ loss_term = ((V_term - true_loss) ** 2).mean()
+
+ hL_req = hL.clone().requires_grad_(True)
+ V_at_L = value_net(hL_req, t_L, s)
+ grad_V = torch.autograd.grad(V_at_L.sum(), hL_req, create_graph=True)[0]
+ hL_req2 = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req2))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ a_L_exact = torch.autograd.grad(ce, hL_req2, create_graph=False)[0].detach()
+ loss_tgrad = ((grad_V - a_L_exact) ** 2).sum(dim=-1).mean()
+
+ loss_bridge = 0.0
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ t_next = torch.full((batch,), (l + 1) / L, device=device)
+ V_l = value_net(h_l, t_l, s)
+ with torch.no_grad():
+ h_next = hiddens[l + 1].detach()
+ log_terms = []
+ for k in range(4):
+ noise = 0.05 * torch.randn_like(h_next)
+ V_next = value_net_ema(h_next + noise, t_next, s)
+ log_terms.append(-V_next / 0.1)
+ log_stack = torch.stack(log_terms, dim=-1)
+ V_target = -0.1 * (torch.logsumexp(log_stack, dim=-1) - np.log(4))
+ loss_bridge += ((V_l - V_target.detach()) ** 2).mean()
+ loss_bridge /= L
+
+ vloss = loss_term + loss_bridge + 1.0 * loss_tgrad
+ value_opt.zero_grad()
+ vloss.backward()
+ torch.nn.utils.clip_grad_norm_(value_net.parameters(), 1.0)
+ value_opt.step()
+ update_ema(value_net, value_net_ema, 0.995)
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [CB] Ep {epoch}")
+ return value_net
+
+
+def train_vector_on_snapshot(model, train_loader, device, epochs=100, lr_fb=1e-3, M=4):
+ d = model.d_hidden
+ L = model.num_blocks
+ vector_net = VectorCreditNet(d_hidden=d, s_dim=10, time_embed_dim=32,
+ hidden_dim=256, num_layers=3).to(device)
+ vec_opt = optim.Adam(vector_net.parameters(), lr=lr_fb)
+ eps = 1e-3
+ model.eval()
+ for epoch in range(1, epochs + 1):
+ vector_net.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device)
+ y = y.to(device)
+ batch = x.size(0)
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+ hL = hiddens[-1].detach()
+
+ # Terminal matching
+ t_L = torch.ones(batch, device=device)
+ a_term = vector_net(hL, t_L, s)
+ hL_req = hL.clone().requires_grad_(True)
+ logits_tgt = model.out_head(model.out_ln(hL_req))
+ ce = F.cross_entropy(logits_tgt, y, reduction='sum')
+ delta_L = torch.autograd.grad(ce, hL_req, create_graph=False)[0].detach()
+ loss_term = ((a_term - delta_L) ** 2).sum(dim=-1).mean()
+
+ # Perturbation target — subsample 1 layer
+ l = np.random.randint(0, L)
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ a_l = vector_net(h_l, t_l, s)
+ loss_proj = torch.tensor(0.0, device=device)
+ for _ in range(M):
+ v = torch.randn_like(h_l)
+ v = v / (v.norm(dim=-1, keepdim=True) + 1e-8)
+ with torch.no_grad():
+ lp = F.cross_entropy(model.forward_from_layer(h_l + eps * v, l), y, reduction='none')
+ lm = F.cross_entropy(model.forward_from_layer(h_l - eps * v, l), y, reduction='none')
+ g_j = (lp - lm) / (2 * eps)
+ loss_proj = loss_proj + (((a_l * v).sum(-1) - g_j.detach()) ** 2).mean()
+ loss_proj = loss_proj / M
+
+ vloss = loss_term + loss_proj
+ vec_opt.zero_grad()
+ vloss.backward()
+ torch.nn.utils.clip_grad_norm_(vector_net.parameters(), 1.0)
+ vec_opt.step()
+ if epoch % 20 == 0 or epoch == 1:
+ print(f" [Vec_M{M}] Ep {epoch}")
+ return vector_net
+
+
+# =============================================================================
+# Credit computation
+# =============================================================================
+def get_credits(model, x, y, device, credit_source, estimator=None, dfa_Bs=None):
+ """Compute per-layer credits for a single batch. Returns dict {l: (batch, d)}."""
+ L = model.num_blocks
+ batch = x.size(0)
+
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+ e_T = logits.softmax(dim=-1)
+ e_T[torch.arange(batch), y] -= 1
+ s = e_T.detach()
+
+ credits = {}
+
+ if credit_source == 'dfa':
+ for l in range(L):
+ credits[l] = (s @ dfa_Bs[l].T).detach()
+
+ elif credit_source == 'scalar_cb':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach().requires_grad_(True)
+ t_l = torch.full((batch,), l / L, device=device)
+ V = estimator(h_l, t_l, s)
+ a = torch.autograd.grad(V.sum(), h_l, create_graph=False)[0]
+ credits[l] = a.detach()
+
+ elif credit_source == 'vec':
+ estimator.eval()
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ t_l = torch.full((batch,), l / L, device=device)
+ credits[l] = estimator(h_l, t_l, s).detach()
+
+ elif credit_source == 'oracle_bp':
+ # Compute true BP gradients — evaluation only, never used for training
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ loss_bp = F.cross_entropy(logits_bp, y)
+ loss_bp.backward()
+ for l in range(L):
+ credits[l] = hiddens_bp[l].grad.detach().clone()
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ return credits, hiddens, s
+
+
+def compute_credit_quality(model, credits, hiddens, x, y, device):
+ """Compute mean Gamma, rho, nudge for given credits."""
+ L = model.num_blocks
+ batch = x.size(0)
+
+ # BP gradients for Gamma
+ for p in model.parameters():
+ p.requires_grad_(True)
+ model.zero_grad()
+ logits_bp, hiddens_bp = model(x, return_hidden=True)
+ for l in range(L + 1):
+ hiddens_bp[l].retain_grad()
+ F.cross_entropy(logits_bp, y).backward()
+ bp_grads = {l: hiddens_bp[l].grad.detach().clone() for l in range(L + 1)}
+ for p in model.parameters():
+ p.requires_grad_(False)
+
+ gammas, rhos, nudges = [], [], []
+ for l in range(L):
+ h_l = hiddens[l].detach()
+ a_l = credits[l]
+ gammas.append(cosine_similarity_batch(a_l, bp_grads[l]))
+
+ def make_fwd(sl):
+ def f(h):
+ with torch.no_grad():
+ c = h
+ for i in range(sl, L):
+ c = c + model.blocks[i](c)
+ return F.cross_entropy(model.out_head(model.out_ln(c)), y, reduction='none')
+ return f
+ fwd = make_fwd(l)
+ rhos.append(perturbation_correlation(h_l, a_l, fwd, epsilon=1e-3, M=16))
+ nudges.append(nudging_test(h_l, a_l, fwd, eta=0.003))
+
+ return float(np.mean(gammas)), float(np.mean(rhos)), float(np.mean(nudges))
+
+
+# =============================================================================
+# Local update step (block-local inner-product surrogate)
+# =============================================================================
+def do_local_update_step(model, x, y, credits, device, lr=1e-3, update_layers=None):
+ """
+ Perform one local surrogate update step on model blocks + output head.
+ update_layers: list of layer indices to update (None = all)
+ Returns new model state (modifies in-place).
+ """
+ L = model.num_blocks
+ if update_layers is None:
+ update_layers = list(range(L))
+
+ model.train()
+ with torch.no_grad():
+ logits, hiddens = model(x, return_hidden=True)
+
+ # Update output head (always)
+ hL = hiddens[-1].detach()
+ logits_out = model.out_head(model.out_ln(hL))
+ loss_out = F.cross_entropy(logits_out, y)
+ head_params = list(model.out_head.parameters()) + list(model.out_ln.parameters())
+ grads_head = torch.autograd.grad(loss_out, head_params)
+ with torch.no_grad():
+ for p, g in zip(head_params, grads_head):
+ p.sub_(lr * g)
+
+ # Update selected blocks
+ for l in update_layers:
+ h_l = hiddens[l].detach()
+ a = credits[l]
+ rms = (a ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ a_norm = a / rms
+ f_l = model.blocks[l](h_l)
+ local_loss = (f_l * a_norm).sum(dim=-1).mean()
+ block_grads = torch.autograd.grad(local_loss, model.blocks[l].parameters())
+ with torch.no_grad():
+ for p, g in zip(model.blocks[l].parameters(), block_grads):
+ p.sub_(lr * g.clamp(-1, 1)) # implicit grad clip
+
+ # Update embedding
+ if 0 in update_layers:
+ a_0 = credits[0]
+ rms_0 = (a_0 ** 2).mean(dim=-1, keepdim=True).sqrt() + 1e-6
+ h0 = model.embed(x)
+ embed_loss = (h0 * (a_0 / rms_0)).sum(dim=-1).mean()
+ embed_grads = torch.autograd.grad(embed_loss, model.embed.parameters())
+ with torch.no_grad():
+ for p, g in zip(model.embed.parameters(), embed_grads):
+ p.sub_(lr * g.clamp(-1, 1))
+
+
+# =============================================================================
+# Main experiment
+# =============================================================================
+def run_experiment(args):
+ device = torch.device(f'cuda:{args.gpu}' if torch.cuda.is_available() else 'cpu')
+ print(f"Using device: {device}")
+ os.makedirs(args.output_dir, exist_ok=True)
+
+ torch.manual_seed(args.seed)
+ np.random.seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+
+ train_loader, test_loader = get_cifar10(batch_size=args.batch_size)
+ input_dim = 32 * 32 * 3
+ L = args.num_blocks
+ d = args.d_hidden
+
+ # =========================================================
+ # Step 1: Load or train BP snapshot
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Loading BP snapshot (L={L}, d={d})")
+ print(f"{'='*60}")
+
+ model_bp = ResidualMLP(input_dim, d, 10, L).to(device)
+ bp_ckpt = f'results/frozen_cifar/bp_ref_L{L}_d{d}_s{args.seed}.pt'
+ if os.path.exists(bp_ckpt):
+ model_bp.load_state_dict(torch.load(bp_ckpt, map_location=device))
+ print(f" Loaded from {bp_ckpt}")
+ else:
+ print(f" Training BP reference...")
+ optimizer = optim.AdamW(model_bp.parameters(), lr=1e-3, weight_decay=0.01)
+ scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
+ for epoch in range(1, 101):
+ model_bp.train()
+ for x, y in train_loader:
+ x = x.view(x.size(0), -1).to(device); y = y.to(device)
+ loss = F.cross_entropy(model_bp(x), y)
+ optimizer.zero_grad(); loss.backward(); optimizer.step()
+ scheduler.step()
+ if epoch % 20 == 0:
+ model_bp.eval()
+ c, t = 0, 0
+ with torch.no_grad():
+ for xv, yv in test_loader:
+ xv = xv.view(xv.size(0), -1).to(device); yv = yv.to(device)
+ c += (model_bp(xv).argmax(1) == yv).sum().item(); t += xv.size(0)
+ print(f" Ep {epoch}: test_acc={c/t:.4f}")
+ os.makedirs(os.path.dirname(bp_ckpt), exist_ok=True)
+ torch.save(model_bp.state_dict(), bp_ckpt)
+
+ model_bp.eval()
+ for p in model_bp.parameters():
+ p.requires_grad_(False)
+ bp_loss, bp_acc = evaluate_loss_acc(model_bp, test_loader, device, n_batches=20)
+ print(f" BP snapshot: loss={bp_loss:.4f}, acc={bp_acc:.4f}")
+
+ # =========================================================
+ # Step 2: Train estimators on BP snapshot
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Training estimators on BP snapshot ({args.estimator_epochs} epochs)")
+ print(f"{'='*60}")
+
+ # DFA
+ dfa_Bs = [torch.randn(d, 10, device=device) / np.sqrt(10) for _ in range(L)]
+
+ # Scalar CB
+ print(" --- ScalarCB_eT ---")
+ torch.manual_seed(args.seed + 2000)
+ cb = train_scalar_cb_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb)
+
+ # Vector field M4
+ print(" --- Vec_eT_M4 ---")
+ torch.manual_seed(args.seed + 4000)
+ vec4 = train_vector_on_snapshot(model_bp, train_loader, device,
+ epochs=args.estimator_epochs, lr_fb=args.lr_fb, M=4)
+
+ credit_sources = {
+ 'dfa': ('dfa', None, dfa_Bs),
+ 'scalar_cb': ('scalar_cb', cb, None),
+ 'vec_eT_M4': ('vec', vec4, None),
+ 'oracle_bp': ('oracle_bp', None, None),
+ }
+
+ # =========================================================
+ # Step 3: Compute credit quality
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Computing credit quality on BP snapshot")
+ print(f"{'='*60}")
+
+ # Get a fixed batch for credit quality
+ for x_eval, y_eval in test_loader:
+ x_eval = x_eval.view(x_eval.size(0), -1).to(device)
+ y_eval = y_eval.to(device)
+ break
+
+ quality = {}
+ for name, (src, est, Bs) in credit_sources.items():
+ credits, hiddens, s = get_credits(model_bp, x_eval, y_eval, device, src,
+ estimator=est, dfa_Bs=Bs)
+ gamma, rho, nudge = compute_credit_quality(model_bp, credits, hiddens,
+ x_eval, y_eval, device)
+ quality[name] = {'gamma': gamma, 'rho': rho, 'nudge': nudge}
+ print(f" {name}: Gamma={gamma:.4f}, rho={rho:.4f}, nudge={nudge:.6f}")
+
+ # =========================================================
+ # Step 4: Exploitability tests
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Test A1: Single-step DeltaLoss")
+ print(f"{'='*60}")
+
+ # Get train batch for update
+ train_iter = iter(train_loader)
+ x_train, y_train = next(train_iter)
+ x_train = x_train.view(x_train.size(0), -1).to(device)
+ y_train = y_train.to(device)
+
+ # Get held-out eval batches (different from train batch)
+ eval_batches = []
+ for i, (xv, yv) in enumerate(test_loader):
+ if i >= 10:
+ break
+ eval_batches.append((xv.view(xv.size(0), -1).to(device), yv.to(device)))
+
+ def eval_on_batches(model):
+ model.eval()
+ total_loss, correct, total = 0, 0, 0
+ with torch.no_grad():
+ for xv, yv in eval_batches:
+ logits = model(xv)
+ total_loss += F.cross_entropy(logits, yv, reduction='sum').item()
+ correct += (logits.argmax(1) == yv).sum().item()
+ total += xv.size(0)
+ return total_loss / total, correct / total
+
+ results = {}
+ lr_update = args.lr_update
+
+ for name, (src, est, Bs) in credit_sources.items():
+ # Reset to snapshot
+ model_test = copy.deepcopy(model_bp)
+ model_test.eval()
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+
+ loss_before, acc_before = eval_on_batches(model_test)
+
+ # Compute credits on the train batch
+ credits, hiddens, s = get_credits(model_test, x_train, y_train, device, src,
+ estimator=est, dfa_Bs=Bs)
+
+ # Re-enable grad for update
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ # One-step update
+ do_local_update_step(model_test, x_train, y_train, credits, device, lr=lr_update)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ loss_after, acc_after = eval_on_batches(model_test)
+
+ delta_loss = loss_after - loss_before
+ delta_acc = acc_after - acc_before
+
+ results[name] = {
+ 'gamma': quality[name]['gamma'],
+ 'rho': quality[name]['rho'],
+ 'nudge': quality[name]['nudge'],
+ 'loss_before': loss_before,
+ 'loss_after_1step': loss_after,
+ 'delta_loss_1step': delta_loss,
+ 'delta_acc_1step': delta_acc,
+ }
+ print(f" {name}: DeltaLoss={delta_loss:+.6f}, DeltaAcc={delta_acc:+.4f} "
+ f"(before={loss_before:.4f}, after={loss_after:.4f})")
+
+ # =========================================================
+ # Test A2: k-step rollout
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Test A2: k-step rollout (k=1,5,20)")
+ print(f"{'='*60}")
+
+ for name, (src, est, Bs) in credit_sources.items():
+ rollout = {}
+ for k in [1, 5, 20]:
+ model_test = copy.deepcopy(model_bp)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ for step in range(k):
+ # Fresh batch each step
+ try:
+ x_step, y_step = next(train_iter)
+ except StopIteration:
+ train_iter = iter(train_loader)
+ x_step, y_step = next(train_iter)
+ x_step = x_step.view(x_step.size(0), -1).to(device)
+ y_step = y_step.to(device)
+
+ # Recompute credits on current model state
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ credits_step, _, _ = get_credits(model_test, x_step, y_step, device, src,
+ estimator=est, dfa_Bs=Bs)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+ do_local_update_step(model_test, x_step, y_step, credits_step, device, lr=lr_update)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ loss_k, acc_k = eval_on_batches(model_test)
+ rollout[k] = {'loss': loss_k, 'acc': acc_k,
+ 'delta_loss': loss_k - results[name]['loss_before'],
+ 'delta_acc': acc_k - results[name]['delta_acc_1step']}
+
+ results[name]['rollout'] = rollout
+ print(f" {name}: k=1 dL={rollout[1]['delta_loss']:+.4f}, "
+ f"k=5 dL={rollout[5]['delta_loss']:+.4f}, "
+ f"k=20 dL={rollout[20]['delta_loss']:+.4f}")
+
+ # =========================================================
+ # Test A3: Per-layer ablation
+ # =========================================================
+ print(f"\n{'='*60}")
+ print(f"Test A3: Per-layer ablation (5-step)")
+ print(f"{'='*60}")
+
+ layer_configs = {
+ 'last_1': [L - 1],
+ 'last_2': [L - 2, L - 1],
+ 'all': list(range(L)),
+ }
+
+ layer_results = {}
+ for name, (src, est, Bs) in credit_sources.items():
+ layer_results[name] = {}
+ for lname, layers in layer_configs.items():
+ model_test = copy.deepcopy(model_bp)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+
+ train_iter2 = iter(train_loader)
+ for step in range(5):
+ try:
+ x_step, y_step = next(train_iter2)
+ except StopIteration:
+ train_iter2 = iter(train_loader)
+ x_step, y_step = next(train_iter2)
+ x_step = x_step.view(x_step.size(0), -1).to(device)
+ y_step = y_step.to(device)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ credits_step, _, _ = get_credits(model_test, x_step, y_step, device, src,
+ estimator=est, dfa_Bs=Bs)
+ for p in model_test.parameters():
+ p.requires_grad_(True)
+ do_local_update_step(model_test, x_step, y_step, credits_step, device,
+ lr=lr_update, update_layers=layers)
+
+ for p in model_test.parameters():
+ p.requires_grad_(False)
+ loss_k, acc_k = eval_on_batches(model_test)
+ dl = loss_k - results[name]['loss_before']
+ layer_results[name][lname] = {'delta_loss': dl, 'loss': loss_k}
+
+ print(f" {name}: last_1={layer_results[name]['last_1']['delta_loss']:+.4f}, "
+ f"last_2={layer_results[name]['last_2']['delta_loss']:+.4f}, "
+ f"all={layer_results[name]['all']['delta_loss']:+.4f}")
+
+ # =========================================================
+ # Summary
+ # =========================================================
+ print(f"\n{'='*80}")
+ print("SUMMARY TABLE")
+ print(f"{'='*80}")
+ print(f"{'Method':<15} {'Gamma':>7} {'rho':>7} {'nudge':>10} {'dL_1':>8} {'dL_5':>8} {'dL_20':>8}")
+ print("-" * 68)
+ for name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']:
+ r = results[name]
+ dL1 = r['delta_loss_1step']
+ dL5 = r['rollout'][5]['delta_loss']
+ dL20 = r['rollout'][20]['delta_loss']
+ print(f"{name:<15} {r['gamma']:>7.4f} {r['rho']:>7.4f} {r['nudge']:>10.6f} "
+ f"{dL1:>+8.4f} {dL5:>+8.4f} {dL20:>+8.4f}")
+
+ print(f"\nPer-layer ablation (5-step DeltaLoss):")
+ print(f"{'Method':<15} {'last_1':>8} {'last_2':>8} {'all':>8}")
+ print("-" * 42)
+ for name in ['dfa', 'scalar_cb', 'vec_eT_M4', 'oracle_bp']:
+ lr = layer_results[name]
+ print(f"{name:<15} {lr['last_1']['delta_loss']:>+8.4f} "
+ f"{lr['last_2']['delta_loss']:>+8.4f} {lr['all']['delta_loss']:>+8.4f}")
+
+ # Save
+ save_data = {
+ 'config': {'L': L, 'd': d, 'seed': args.seed, 'lr_update': lr_update,
+ 'estimator_epochs': args.estimator_epochs},
+ 'credit_quality': quality,
+ 'exploitability': {n: {k: v for k, v in r.items() if k != 'rollout'}
+ for n, r in results.items()},
+ 'rollout': {n: r['rollout'] for n, r in results.items()},
+ 'layer_ablation': layer_results,
+ }
+ out_path = os.path.join(args.output_dir, f'snapshot_L{L}_d{d}_s{args.seed}.json')
+ with open(out_path, 'w') as f:
+ json.dump(save_data, f, indent=2, default=float)
+ print(f"\nSaved to {out_path}")
+
+ # Judgment
+ print(f"\n{'='*60}")
+ print("JUDGMENT")
+ print(f"{'='*60}")
+ vec_dl = results['vec_eT_M4']['rollout'][5]['delta_loss']
+ cb_dl = results['scalar_cb']['rollout'][5]['delta_loss']
+ bp_dl = results['oracle_bp']['rollout'][5]['delta_loss']
+ dfa_dl = results['dfa']['rollout'][5]['delta_loss']
+
+ print(f"5-step DeltaLoss: DFA={dfa_dl:+.4f}, CB={cb_dl:+.4f}, Vec={vec_dl:+.4f}, BP={bp_dl:+.4f}")
+
+ if vec_dl < cb_dl and vec_dl < dfa_dl:
+ print("EXPLOITABLE: Vec credit produces better loss decrease than ScalarCB and DFA.")
+ print(" -> Online failure is likely tracking/co-adaptation (Case A).")
+ elif bp_dl < dfa_dl and vec_dl >= cb_dl:
+ print("NOT EXPLOITABLE: Better credit (vec) doesn't translate to better loss decrease.")
+ print(" -> Bottleneck is in local update rule (Case B).")
+ else:
+ print("AMBIGUOUS: Need more investigation.")
+
+
+def main():
+ parser = argparse.ArgumentParser(description='Phase 6A: Snapshot Exploitability')
+ parser.add_argument('--num_blocks', type=int, default=4)
+ parser.add_argument('--d_hidden', type=int, default=256)
+ parser.add_argument('--batch_size', type=int, default=128)
+ parser.add_argument('--estimator_epochs', type=int, default=100)
+ parser.add_argument('--lr_fb', type=float, default=1e-3)
+ parser.add_argument('--lr_update', type=float, default=1e-3)
+ parser.add_argument('--seed', type=int, default=42)
+ parser.add_argument('--gpu', type=int, default=3)
+ parser.add_argument('--output_dir', type=str, default='results/snapshot_exploit')
+ args = parser.parse_args()
+ run_experiment(args)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/report_explore/MEMO_6A_snapshot_exploitability.md b/report_explore/MEMO_6A_snapshot_exploitability.md
new file mode 100644
index 0000000..950ed1b
--- /dev/null
+++ b/report_explore/MEMO_6A_snapshot_exploitability.md
@@ -0,0 +1,39 @@
+# Phase 6A Memo: Snapshot Exploitability
+
+**Date**: 2026-03-24
+**Config**: BP snapshot, CIFAR-10, L=4, d=256 (61.9% acc), seed=42
+
+## Question
+On a fixed snapshot, does better credit lead to better real loss decrease via the current local surrogate?
+
+## Results
+
+| Method | Gamma | rho | dL_1step | dL_5step | dL_20step |
+|--------|-------|-----|----------|----------|-----------|
+| DFA | 0.009 | -0.023 | **-0.0004** | **+0.0002** | **-0.0007** |
+| ScalarCB | 0.122 | 0.090 | +0.003 | +0.042 | +0.405 |
+| Vec_M4 | 0.378 | 0.411 | +0.003 | +0.050 | +0.272 |
+| Oracle BP | 1.000 | 0.998 | **-0.001** | +0.007 | +0.026 |
+
+## Key Finding: The Local Surrogate is Anti-Correlated with Credit Quality
+
+**Better credit produces WORSE loss change.** DFA (Gamma≈0) is the only method that decreases loss. ScalarCB (Gamma=0.12) and Vec (Gamma=0.38) both increase loss, with Vec slightly worse. Even Oracle BP increases loss at 5+ steps.
+
+The inner-product surrogate `L_local = <F_l(h_l), a_l>` is fundamentally broken as a local update rule for directional credit:
+- It treats a_l as a "desired direction for the residual output" rather than a gradient
+- The gradient of this surrogate w.r.t. block parameters pushes F_l(h) to align with a_l, but this is NOT the same as making h_{l+1} = h_l + F_l(h_l) move in the direction that decreases global loss
+- DFA "works" precisely because its random credits are small and roughly isotropic — the updates are near-random perturbations that don't systematically damage the representation
+
+## Verdict
+
+**This is Case B: the local update rule is the bottleneck, not the estimator or tracking.**
+
+Improving credit quality from DFA (Gamma=0.01) through ScalarCB (0.12) to Vec (0.38) to Oracle BP (1.0) does NOT improve — and actually worsens — real parameter update quality.
+
+## Implication
+
+The project should pivot from "better credit estimator" to "better local update coupling." The target-shift local regression rule (Phase 6C) is the natural next experiment:
+
+`L_shift = 0.5 * || h_l + F_l(h_l) - sg(h_{l+1} - eta * a_{l+1}^norm) ||^2`
+
+This directly tells each block: "adjust your output so the next hidden state moves toward the credit-indicated direction."
diff --git a/report_explore/MEMO_6_exploitability.md b/report_explore/MEMO_6_exploitability.md
new file mode 100644
index 0000000..42dfda5
--- /dev/null
+++ b/report_explore/MEMO_6_exploitability.md
@@ -0,0 +1,53 @@
+# Phase 6 Memo: Snapshot Exploitability + Local Update Rule Swap
+
+**Date**: 2026-03-24
+
+## Phase 6A: Snapshot Exploitability
+
+**Setup**: BP-trained CIFAR-10 snapshot (L=4, d=256, 61.9% acc). Train estimators on frozen features, then do k-step local updates and measure real loss change.
+
+### Results (5-step DeltaLoss, inner-product surrogate)
+
+| Credit | Gamma | rho | dL_5step |
+|--------|-------|-----|----------|
+| DFA | 0.009 | -0.023 | **-0.0001** |
+| ScalarCB | 0.122 | 0.090 | +0.042 |
+| Vec_M4 | 0.378 | 0.411 | +0.057 |
+| Oracle BP | 1.000 | 0.998 | +0.011 |
+
+**Finding**: Better credit quality is ANTI-CORRELATED with loss decrease. DFA (worst credit) produces the only method that doesn't increase loss. Vec (best credit) increases loss the most. Even Oracle BP increases loss at 5 steps.
+
+**Verdict**: This is **Case B** — the local update rule is the bottleneck.
+
+## Phase 6C: Local Update Rule Swap
+
+Tested target-shift rule (h_{l+1}^target = h_{l+1} - eta * a_norm) at eta in {0.01, 0.1, 0.3, 1.0}.
+
+### Results (5-step DeltaLoss)
+
+| Credit | inner_prod | shift_0.1 | shift_0.3 | shift_1.0 |
+|--------|:---:|:---:|:---:|:---:|
+| DFA | -0.0001 | **-0.0003** | +0.0004 | +0.001 |
+| Vec_M4 | +0.057 | +0.002 | +0.009 | +0.048 |
+| Oracle BP | +0.011 | +0.0002 | +0.001 | +0.005 |
+
+Target-shift reduces the damage but never achieves negative DeltaLoss for non-DFA credits. The cosine rule produces near-zero effects at all settings.
+
+## Root Cause Analysis
+
+The issue is deeper than the update rule. A BP-trained snapshot sits at a minimum of the full-backprop loss surface. Any local update that doesn't have access to the full gradient chain will push parameters in a direction that may locally align with the credit but globally increases loss. This is because:
+
+1. The inner-product surrogate `<F_l(h), a_l>` assumes a_l is the desired direction for the residual output. But even perfect credit (Oracle BP) doesn't produce good updates via this mechanism — the gradient of the surrogate w.r.t. block parameters is NOT the same as the gradient of the global loss.
+
+2. Target-shift reduces the magnitude of harmful updates but doesn't fix the direction. At small eta, updates are negligible. At large eta, the target shifts too far and becomes harmful.
+
+3. DFA "works" precisely because its random credits produce near-zero effective updates — it's approximately doing nothing, which is better than doing the wrong thing.
+
+## Implications
+
+**The project's fundamental limitation is NOT in the credit estimator.** It's in the local surrogate update paradigm itself. The inner-product surrogate `<F(h), a>` is not a valid proxy for global loss minimization, regardless of credit quality.
+
+**Potential directions:**
+1. Use credit to set per-block learning targets rather than gradients (e.g., knowledge distillation-style objectives)
+2. Use credit to modulate a more expressive local loss (e.g., local CE with projected targets)
+3. Abandon block-local updates entirely and use credit to define a global but differentiable auxiliary loss