summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 01:15:08 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-08 01:15:08 -0500
commit671d9823668197c21b2d35d08d15da0d5c3c4161 (patch)
tree89d42a252b7522196cc992224a62e632d3038d3e /experiments
parentdf9f69bc9172b3473be144ff8a17370bc7a68e64 (diff)
Add null calibration script: training-Bs vs fresh-Bs cos on penalized DFA
Codex round 19's #1 critical control. Result on penalized DFA s42 (lam=1e-2, 30 ep): training-Bs deep-layer cos: +0.1627; fresh-Bs deep-layer cos: +0.0022 ± 0.0220 (n=20 draws). The +0.17 measurement is REAL signal, not an artifact. The network specifically adapted to its training-time Bs during the penalized run. Fresh Bs give essentially zero cosine (within noise). This validates the walk-back interpretation: in the rescued regime where ||g_l|| is meaningful, DFA's local credit signal shows partial alignment with the BP grad — and this alignment is specifically the network learning to align with its specific Bs. Round 19 caveat preserved: we cannot yet distinguish whether the alignment was always present in vanilla but hidden by measurement degeneracy, OR whether it was created by the penalty intervention. The early-epoch vanilla checkpoint sweep (round 19's other proposed control) would disambiguate.
Diffstat (limited to 'experiments')
-rw-r--r--experiments/null_calibration_penalized_cos.py154
1 files changed, 154 insertions, 0 deletions
diff --git a/experiments/null_calibration_penalized_cos.py b/experiments/null_calibration_penalized_cos.py
new file mode 100644
index 0000000..d0d3472
--- /dev/null
+++ b/experiments/null_calibration_penalized_cos.py
@@ -0,0 +1,154 @@
+"""
+Null calibration of the +0.17 deep-layer cosine on penalized DFA.
+
+Codex round 19 critical control: same penalized checkpoint, but compute the
+cosine with FRESH random Bs (not the training-time Bs). If +0.17 was real
+signal that the network adapted to its training-time Bs, fresh Bs should
+give cosine ≈ 0. If +0.17 was an artifact of how the cosine is computed
+(e.g., a property of the penalized network independent of the Bs), fresh
+Bs should also give ~+0.17.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/null_calibration_penalized_cos.py
+"""
+import os
+import sys
+import argparse
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+from models.residual_mlp import ResidualMLP
+
+
def load_eval(n=2048, device="cuda:0"):
    """Return the first ``n`` CIFAR-10 test images (flattened) and their labels.

    The test split is read from ``./data`` (downloaded when missing),
    converted to tensors, normalized per channel, and iterated in dataset
    order without shuffling, so repeated calls yield the same samples.

    Args:
        n: Number of evaluation samples to keep.
        device: Device the returned tensors are moved to.

    Returns:
        A tuple ``(x, y)``: ``x`` holds one flattened image per row and ``y``
        the matching labels, both truncated to at most ``n`` rows.
    """
    normalize = transforms.Normalize(
        (0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)
    )
    preprocess = transforms.Compose([transforms.ToTensor(), normalize])
    test_set = torchvision.datasets.CIFAR10(
        "./data", train=False, download=True, transform=preprocess
    )
    batches = DataLoader(test_set, batch_size=256, shuffle=False, num_workers=0)

    images, labels = [], []
    collected = 0
    for batch_x, batch_y in batches:
        # Flatten each image to a single feature vector.
        images.append(batch_x.view(batch_x.size(0), -1))
        labels.append(batch_y)
        collected += batch_x.size(0)
        # Stop reading as soon as enough samples have been gathered.
        if collected >= n:
            break

    x_all = torch.cat(images)[:n]
    y_all = torch.cat(labels)[:n]
    return x_all.to(device), y_all.to(device)
+
+
def per_layer_bp_grads(model, x, y):
    """Compute backprop gradients of the loss w.r.t. every hidden state.

    The forward pass is re-run by hand (embed, then each block applied
    residually, then ``out_ln`` and ``out_head``) so that each intermediate
    hidden state can be differentiated against the cross-entropy loss.

    Args:
        model: Object exposing ``embed``, ``blocks``, ``out_ln`` and
            ``out_head`` callables.
        x: Input batch fed to ``model.embed``.
        y: Target class indices for the cross-entropy loss.

    Returns:
        A tuple ``(grads, logits)``: ``grads`` is a list with one gradient
        per hidden state (embedding output first, then the state after each
        block) and ``logits`` is the detached model output.
    """
    # Gradients are needed even if the caller is inside a no_grad context.
    with torch.enable_grad():
        state = model.embed(x)
        trajectory = [state]
        for residual_fn in model.blocks:
            state = state + residual_fn(state)
            trajectory.append(state)
        logits = model.out_head(model.out_ln(state))
        objective = F.cross_entropy(logits, y)
        per_state = torch.autograd.grad(objective, trajectory)
    return list(per_state), logits.detach()
+
+
def cosine_no_clamp(a, b):
    """Return the cosine similarity of ``a`` and ``b`` along the last axis.

    Each input is divided by its own L2 norm before the dot product. The
    norms are floored at a tiny value (1e-30) only to avoid division by
    zero, so an all-zero vector yields a cosine of 0.

    Args:
        a: Tensor of vectors along the last dimension.
        b: Tensor broadcast-compatible with ``a``.

    Returns:
        Tensor of cosines with the last dimension reduced away.
    """
    tiny = 1e-30
    unit_a = a / a.norm(dim=-1, keepdim=True).clamp_min(tiny)
    unit_b = b / b.norm(dim=-1, keepdim=True).clamp_min(tiny)
    return (unit_a * unit_b).sum(dim=-1)
+
+
def measure_with_Bs(model, Bs, x, y):
    """Measure per-layer cosine between the B-projected error and BP grads.

    For every hidden state (``model.num_blocks + 1`` of them) the output
    error ``softmax(logits) - onehot(y)`` is projected through the
    transpose of a feedback matrix from ``Bs`` and compared, sample by
    sample, with the true backprop gradient at that hidden state.

    Args:
        model: Model accepted by ``per_layer_bp_grads``; must also expose
            ``num_blocks``.
        Bs: Sequence of feedback matrices, indexed per block.
        x: Input batch.
        y: Target class indices.

    Returns:
        A list of dicts ``{"layer": l, "cos_mean": float}``, one per hidden
        state, where ``cos_mean`` is the batch mean of the cosine.
    """
    num_blocks = model.num_blocks
    grads, logits = per_layer_bp_grads(model, x, y)

    # Output error: softmax probabilities minus the one-hot target.
    err = F.softmax(logits, dim=-1).clone()
    rows = torch.arange(len(y), device=y.device)
    err[rows, y] -= 1

    results = []
    for layer in range(num_blocks + 1):
        # The final hidden state has no matrix of its own; it reuses the last one.
        feedback = Bs[min(layer, num_blocks - 1)]
        projected = (err @ feedback.T).detach()
        bp_grad = grads[layer].detach()
        mean_cos = cosine_no_clamp(projected, bp_grad).mean().item()
        results.append({"layer": layer, "cos_mean": float(mean_cos)})
    return results
+
+
def main():
    """Run the null-calibration experiment and print a report.

    Loads a penalized-DFA checkpoint together with its training-time
    feedback matrices, measures the per-layer cosine against BP gradients
    with those matrices, then repeats the measurement with ``--n_fresh``
    freshly sampled random matrices and prints per-layer statistics plus a
    threshold-based interpretation.

    Command-line arguments:
        --ckpt: Path to the checkpoint file (expects ``state_dict`` and
            ``Bs`` entries; ``test_acc`` is optional).
        --n_fresh: Number of fresh random draws; must be at least 1.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="results/dfa_pen_short/dfa_pen_lam0.01_s42.pt")
    p.add_argument("--n_fresh", type=int, default=20, help="number of fresh Bs draws")
    args = p.parse_args()
    # Fail fast: with zero draws the aggregation below would crash with an
    # IndexError only after the checkpoint and dataset have been loaded.
    if args.n_fresh < 1:
        p.error("--n_fresh must be at least 1")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(f"Loading {args.ckpt}")
    # NOTE: weights_only=False unpickles arbitrary objects; only load
    # checkpoints from a trusted source.
    sd = torch.load(args.ckpt, map_location=device, weights_only=False)
    # NOTE(review): the model is not switched to eval mode here — confirm
    # ResidualMLP has no train/eval-dependent layers.
    model = ResidualMLP(3072, 256, 10, 4).to(device)
    model.load_state_dict(sd["state_dict"])
    Bs_train = [b.to(device) for b in sd["Bs"]]
    print(f"Test acc: {sd.get('test_acc', 'unknown')}")
    print()

    x, y = load_eval(n=2048, device=device)

    print("=" * 72)
    print("REFERENCE: training-time Bs")
    print("=" * 72)
    out_train = measure_with_Bs(model, Bs_train, x, y)
    for entry in out_train:
        print(f" l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f}")
    train_mean = np.mean([e['cos_mean'] for e in out_train])
    # "Deep" layers exclude index 0 (the embedding output).
    train_deep = np.mean([e['cos_mean'] for e in out_train[1:]])
    print(f" layer-mean: {train_mean:+.4f}")
    print(f" deep-layer mean (l1-l4): {train_deep:+.4f}")
    print()

    print("=" * 72)
    print(f"NULL CALIBRATION: {args.n_fresh} fresh random Bs draws")
    print("=" * 72)
    fresh_results = []
    for k in range(args.n_fresh):
        # Fixed per-draw seed keeps the null distribution reproducible.
        torch.manual_seed(10000 + k)
        Bs_fresh = [torch.randn(256, 10, device=device) / np.sqrt(10) for _ in range(4)]
        out_fresh = measure_with_Bs(model, Bs_fresh, x, y)
        fresh_results.append(out_fresh)
        deep_mean = np.mean([e['cos_mean'] for e in out_fresh[1:]])
        per_layer_str = ", ".join(f"{e['cos_mean']:+.4f}" for e in out_fresh)
        print(f" fresh #{k}: per-layer = [{per_layer_str}], deep mean {deep_mean:+.4f}")

    # Aggregate
    arr = np.array([[e['cos_mean'] for e in r] for r in fresh_results])  # (n_fresh, n_layers)
    print()
    print(f" Across {args.n_fresh} fresh Bs draws (mean ± std per layer):")
    for l in range(arr.shape[1]):
        print(f" l{l}: {arr[:,l].mean():+.4f} ± {arr[:,l].std():.4f}")
    # Mean and std are pooled over every (draw, deep layer) entry.
    fresh_deep_mean = arr[:, 1:].mean()
    fresh_deep_std = arr[:, 1:].std()
    print(f" fresh-Bs deep-layer mean: {fresh_deep_mean:+.4f} ± {fresh_deep_std:.4f}")
    print()
    print("=" * 72)
    print("INTERPRETATION")
    print("=" * 72)
    print(f" Training-Bs deep cos: {train_deep:+.4f}")
    print(f" Fresh-Bs deep cos: {fresh_deep_mean:+.4f}")
    print()
    # The ':+.4f' format prints the real sign; a literal '+' prefix would
    # render a negative value as '+-0.xxxx'.
    if abs(fresh_deep_mean) < 0.05:
        print(f" Fresh Bs give ~0 cosine (|{fresh_deep_mean:.4f}| < 0.05)")
        print(f" → The {train_deep:+.4f} on training Bs is REAL signal that the network")
        print(f" adapted to its specific Bs during training.")
    elif abs(fresh_deep_mean) > 0.10:
        print(f" Fresh Bs give SIMILAR cosine to training Bs (|{fresh_deep_mean:.4f}| > 0.10)")
        print(f" → The {train_deep:+.4f} is NOT specifically about the training Bs.")
        print(f" It could be a property of the BP grad direction itself in the")
        print(f" penalized regime — i.e. the BP grad and ANY random direction give")
        print(f" a similar partial alignment. This would weaken the 'partial credit")
        print(f" quality' interpretation.")
    else:
        print(f" Fresh Bs give intermediate cosine ({fresh_deep_mean:.4f})")
        print(f" → Mixed: the training Bs are partially specific, partially generic.")
+
+
+if __name__ == "__main__":
+ main()