diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 01:15:08 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-08 01:15:08 -0500 |
| commit | 671d9823668197c21b2d35d08d15da0d5c3c4161 (patch) | |
| tree | 89d42a252b7522196cc992224a62e632d3038d3e /experiments | |
| parent | df9f69bc9172b3473be144ff8a17370bc7a68e64 (diff) | |
Add null calibration script: training-Bs vs fresh-Bs cos on penalized DFA
Codex round 19's #1 critical control. Result on penalized DFA s42 (lam=1e-2, 30 ep):
training-Bs deep-layer cos: +0.1627
fresh-Bs deep-layer cos: +0.0022 ± 0.0220 (n=20 draws; std pooled over all draws and deep layers)
The +0.17 measurement is a REAL signal, not an artifact. The network specifically
adapted to its training-time Bs during the penalized run. Fresh Bs give
essentially zero cosine (within noise).
This validates the walk-back interpretation: in the rescued regime where
||g_l|| is meaningful, DFA's local credit signal shows partial alignment
with BP grad — and this alignment is specifically the network learning to
align with its specific Bs.
Round 19 caveat preserved: cannot yet distinguish whether the alignment
was always present in vanilla but hidden by measurement degeneracy, OR
whether it was created by the penalty intervention. The early-epoch
vanilla checkpoint sweep (round 19's other proposed control) would
disambiguate.
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/null_calibration_penalized_cos.py | 154 |
1 files changed, 154 insertions, 0 deletions
# experiments/null_calibration_penalized_cos.py  (new file in this commit)
"""
Null calibration of the +0.17 deep-layer cosine on penalized DFA.

Codex round 19 critical control: same penalized checkpoint, but compute the
cosine with FRESH random Bs (not the training-time Bs). If +0.17 was real
signal that the network adapted to its training-time Bs, fresh Bs should
give cosine ≈ 0. If +0.17 was an artifact of how the cosine is computed
(e.g., a property of the penalized network independent of the Bs), fresh
Bs should also give ~+0.17.

Run:
    CUDA_VISIBLE_DEVICES=2 python experiments/null_calibration_penalized_cos.py
"""
import os
import sys
import argparse

import numpy as np
import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from models.residual_mlp import ResidualMLP


def load_eval(n=2048, device="cuda:0"):
    """Return the first ``n`` CIFAR-10 test samples as flat tensors on ``device``.

    Images are normalized with the per-channel mean/std constants below and
    flattened to one row per image. The dataset is read from (and, if
    missing, downloaded to) ``./data``.

    Args:
        n: number of evaluation samples to keep.
        device: device the returned tensors are moved to.

    Returns:
        ``(x, y)``: flattened images and their integer labels, both
        truncated to the first ``n`` samples in dataset order.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv)
    # shuffle=False keeps the evaluation subset identical across runs.
    loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0)
    xs, ys = [], []
    n_seen = 0  # running count instead of re-summing every batch size
    for x, y in loader:
        xs.append(x.view(x.size(0), -1))
        ys.append(y)
        n_seen += x.size(0)
        if n_seen >= n:
            break
    return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)


def per_layer_bp_grads(model, x, y):
    """Compute backprop gradients of the loss w.r.t. every hidden state.

    The forward pass is re-implemented here (embed -> residual blocks ->
    out_ln -> out_head) so that each intermediate hidden state is available
    as a differentiation target.

    Args:
        model: network exposing ``embed``, ``blocks``, ``out_ln``, ``out_head``.
        x: input batch.
        y: integer class labels for ``x``.

    Returns:
        ``(grads, logits)``: ``grads[i]`` is d(loss)/d(hiddens[i]), where
        ``hiddens[0]`` is the embedding output and ``hiddens[i]`` the output
        of residual block ``i - 1``; ``logits`` is detached from the graph.
    """
    # enable_grad so this works even if the caller is inside torch.no_grad().
    with torch.enable_grad():
        h = model.embed(x)
        hiddens = [h]
        for block in model.blocks:
            h = h + block(h)
            hiddens.append(h)
        logits = model.out_head(model.out_ln(h))
        loss = F.cross_entropy(logits, y)
        grads = torch.autograd.grad(loss, hiddens)
    return list(grads), logits.detach()


def cosine_no_clamp(a, b):
    """Row-wise cosine similarity between ``a`` and ``b`` along the last dim.

    Norms are floored at 1e-30 only to avoid division by zero; the tiny
    floor means very small (but non-zero) vectors are still normalized by
    their true norm rather than by a larger epsilon.
    """
    eps = 1e-30
    an = a.norm(dim=-1, keepdim=True).clamp_min(eps)
    bn = b.norm(dim=-1, keepdim=True).clamp_min(eps)
    return ((a / an) * (b / bn)).sum(dim=-1)


def measure_with_Bs(model, Bs, x, y):
    """Per-layer mean cosine between the DFA signal and the BP gradient.

    For each hidden state ``l`` in ``0..L`` (``L = model.num_blocks``) the
    DFA signal is ``a_l = e_T @ Bs[b_idx].T`` with
    ``e_T = softmax(logits) - onehot(y)``, and it is compared against the
    true backprop gradient at that hidden state.

    Args:
        model: network accepted by ``per_layer_bp_grads`` with ``num_blocks``.
        Bs: sequence of feedback matrices, indexed ``0..L-1``.
        x: input batch.
        y: integer class labels for ``x``.

    Returns:
        List of ``{"layer": l, "cos_mean": float}`` dicts, one per hidden
        state, where ``cos_mean`` is the cosine averaged over the batch.
    """
    L = model.num_blocks
    grads, logits = per_layer_bp_grads(model, x, y)
    # Output error: softmax probabilities minus the one-hot target.
    e_T = F.softmax(logits, dim=-1).clone()
    e_T[torch.arange(len(y), device=y.device), y] -= 1
    out = []
    for layer in range(L + 1):
        # There are L + 1 hidden states but Bs is indexed only up to L - 1,
        # so the final hidden state reuses the last feedback matrix.
        b_idx = min(layer, L - 1)
        a_l = (e_T @ Bs[b_idx].T).detach()
        g_l = grads[layer].detach()
        cos = cosine_no_clamp(a_l, g_l)
        out.append({"layer": layer, "cos_mean": float(cos.mean().item())})
    return out


def main():
    """Compare the deep-layer cosine under training-time Bs vs. fresh random Bs."""
    p = argparse.ArgumentParser()
    p.add_argument("--ckpt", type=str, default="results/dfa_pen_short/dfa_pen_lam0.01_s42.pt")
    p.add_argument("--n_fresh", type=int, default=20, help="number of fresh Bs draws")
    args = p.parse_args()
    if args.n_fresh < 1:
        # With zero draws the aggregate array below has no layer axis.
        p.error("--n_fresh must be >= 1")
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    print(f"Loading {args.ckpt}")
    # SECURITY NOTE: weights_only=False unpickles arbitrary objects. Only load
    # checkpoints produced by this project; never an untrusted file.
    sd = torch.load(args.ckpt, map_location=device, weights_only=False)
    model = ResidualMLP(3072, 256, 10, 4).to(device)
    model.load_state_dict(sd["state_dict"])
    # Measurement only: eval mode keeps any train-time stochastic layers
    # (if ResidualMLP has them -- not visible from this file) from adding
    # noise that is unrelated to the choice of Bs.
    model.eval()
    Bs_train = [b.to(device) for b in sd["Bs"]]
    print(f"Test acc: {sd.get('test_acc', 'unknown')}")
    print()

    x, y = load_eval(n=2048, device=device)

    print("=" * 72)
    print("REFERENCE: training-time Bs")
    print("=" * 72)
    out_train = measure_with_Bs(model, Bs_train, x, y)
    for entry in out_train:
        print(f" l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f}")
    train_mean = np.mean([e['cos_mean'] for e in out_train])
    # "Deep" excludes layer 0 (the embedding output).
    train_deep = np.mean([e['cos_mean'] for e in out_train[1:]])
    print(f" layer-mean: {train_mean:+.4f}")
    print(f" deep-layer mean (l1-l4): {train_deep:+.4f}")
    print()

    print("=" * 72)
    print(f"NULL CALIBRATION: {args.n_fresh} fresh random Bs draws")
    print("=" * 72)
    fresh_results = []
    for k in range(args.n_fresh):
        # Fixed per-draw seed so every fresh draw is reproducible.
        torch.manual_seed(10000 + k)
        # Mirror the checkpoint's own Bs (count and shape) instead of
        # hard-coding 4 x (256, 10); scale is 1/sqrt(second dim), which is
        # 1/sqrt(10) for the shapes the original script assumed.
        Bs_fresh = [
            torch.randn(b.shape, device=device) / np.sqrt(b.shape[1])
            for b in Bs_train
        ]
        out_fresh = measure_with_Bs(model, Bs_fresh, x, y)
        fresh_results.append(out_fresh)
        deep_mean = np.mean([e['cos_mean'] for e in out_fresh[1:]])
        per_layer_str = ", ".join(f"{e['cos_mean']:+.4f}" for e in out_fresh)
        print(f" fresh #{k}: per-layer = [{per_layer_str}], deep mean {deep_mean:+.4f}")

    # Aggregate
    arr = np.array([[e['cos_mean'] for e in r] for r in fresh_results])  # (n_fresh, n_layers)
    print()
    print(f" Across {args.n_fresh} fresh Bs draws (mean ± std per layer):")
    for layer in range(arr.shape[1]):
        print(f" l{layer}: {arr[:, layer].mean():+.4f} ± {arr[:, layer].std():.4f}")
    # The reference statistic is a per-Bs-set mean over deep layers, so the
    # matching null spread is the std of the per-draw deep means -- not the
    # std pooled over every (draw, layer) entry.
    deep_per_draw = arr[:, 1:].mean(axis=1)
    fresh_deep_mean = deep_per_draw.mean()
    fresh_deep_std = deep_per_draw.std()
    pooled_std = arr[:, 1:].std()
    print(f" fresh-Bs deep-layer mean: {fresh_deep_mean:+.4f} ± {fresh_deep_std:.4f}")
    print(f" (± is std across per-draw deep means; pooled over draws x layers: {pooled_std:.4f})")
    print()
    print("=" * 72)
    print("INTERPRETATION")
    print("=" * 72)
    print(f" Training-Bs deep cos: {train_deep:+.4f}")
    print(f" Fresh-Bs deep cos: {fresh_deep_mean:+.4f}")
    print()
    if abs(fresh_deep_mean) < 0.05:
        print(f" Fresh Bs give ~0 cosine (|{fresh_deep_mean:.4f}| < 0.05)")
        # {:+.4f} rather than a literal '+' so a negative value is not
        # rendered as "+-0.xxxx".
        print(f" → The {train_deep:+.4f} on training Bs is REAL signal that the network")
        print(" adapted to its specific Bs during training.")
    elif abs(fresh_deep_mean) > 0.10:
        print(f" Fresh Bs give SIMILAR cosine to training Bs (|{fresh_deep_mean:.4f}| > 0.10)")
        print(f" → The {train_deep:+.4f} is NOT specifically about the training Bs.")
        print(" It could be a property of the BP grad direction itself in the")
        print(" penalized regime — i.e. the BP grad and ANY random direction give")
        print(" a similar partial alignment. This would weaken the 'partial credit")
        print(" quality' interpretation.")
    else:
        print(f" Fresh Bs give intermediate cosine ({fresh_deep_mean:.4f})")
        print(" → Mixed: the training Bs are partially specific, partially generic.")


if __name__ == "__main__":
    main()
