diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:21:32 -0500 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-04-07 23:21:32 -0500 |
| commit | 8f67bdeebac543961871b9896a62cd07b7a5be26 (patch) | |
| tree | 63fec268bf894b61875ccf90e173af4e4264cb81 /experiments | |
| parent | 5771a122300f9d30a6290fcbfc9bffb5f380e648 (diff) | |
Add fast direction-quality measurement on existing DFA checkpoints
3-seed result on the existing dfa_s{42,123,456}.pt checkpoints from
results/confirmatory/checkpoints_A2/, computing per-layer cosine of
DFA's local credit signal e_T@B_l^T vs the true BP gradient at h_l.
Key findings:
per-layer cos (3-seed mean):
l0: +0.42 (high — embedding alignment)
l1: +0.006 (essentially zero)
l2: -0.015 (essentially zero)
l3: -0.004 (essentially zero)
l4: -0.004 (essentially zero)
layer-mean across all 5: +0.07-0.10
The deep blocks (l1-l4) have essentially zero alignment with BP grad in
the vanilla scale-failure regime. Layer 0 dominates the headline.
The script reconstructs the training-time random Bs by replaying the RNG
sequence (torch.manual_seed + ResidualMLP construction + randn draws),
since the existing checkpoints don't save Bs. For the still-running
direction-quality experiment which DOES save Bs, the script auto-detects
the dict format and uses the saved Bs directly.
Diffstat (limited to 'experiments')
| -rw-r--r-- | experiments/measure_direction_quality_existing_ckpt.py | 143 |
1 files changed, 143 insertions, 0 deletions
diff --git a/experiments/measure_direction_quality_existing_ckpt.py b/experiments/measure_direction_quality_existing_ckpt.py new file mode 100644 index 0000000..d150eb3 --- /dev/null +++ b/experiments/measure_direction_quality_existing_ckpt.py @@ -0,0 +1,143 @@ +""" +Fast direction-quality measurement using EXISTING checkpoints. No training. + +Loads the vanilla DFA s42 checkpoint from +`results/confirmatory/checkpoints_A2/dfa_s42.pt` and computes the per-layer +cosine between DFA's local credit signal `a_l = e_T @ B_l^T` and the BP +gradient `g_l = ∂L/∂h_l`. This is the "Γ on the degenerate reference" +measurement — what the field-standard FA evaluation reports. + +The catch: the trained-time random feedback Bs were not saved in the +existing checkpoint. We reconstruct them by replaying the training-time +RNG sequence (`torch.manual_seed(seed); ResidualMLP(...); randn(d, C)`), +matching what the original DFA trainer did. + +For the "scale-fixed" comparison case, we'll need a penalized DFA checkpoint +(which `experiments/dfa_direction_quality_test.py` is currently saving in +the background). Once that lands, this script can be reused with --ckpt +pointing to the penalized one. + +Run: + CUDA_VISIBLE_DEVICES=2 python experiments/measure_direction_quality_existing_ckpt.py +""" +import os +import sys +import argparse + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as transforms +from torch.utils.data import DataLoader + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from models.residual_mlp import ResidualMLP + + +def load_eval(n=2048, device="cuda:0"): + tv = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), + ]) + te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) + loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) + xs, ys = [], [] + for x, y in loader: + xs.append(x.view(x.size(0), -1)); ys.append(y) + if sum(xb.size(0) for xb in xs) >= n: + break + return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device) + + +def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"): + """Replay the training-time RNG sequence to recover the Bs.""" + torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) + _ = ResidualMLP(3072, d_hidden, num_classes, num_blocks) # consume model init RNG + Bs = [torch.randn(d_hidden, num_classes, device=device) / np.sqrt(num_classes) for _ in range(num_blocks)] + return Bs + + +def per_layer_bp_grads(model, x, y): + with torch.enable_grad(): + h = model.embed(x) + hiddens = [h] + for block in model.blocks: + h = h + block(h) + hiddens.append(h) + logits = model.out_head(model.out_ln(h)) + loss = F.cross_entropy(logits, y) + grads = torch.autograd.grad(loss, hiddens) + return list(grads), logits.detach() + + +def cosine_no_clamp(a, b): + eps = 1e-30 + an = a.norm(dim=-1, keepdim=True).clamp_min(eps) + bn = b.norm(dim=-1, keepdim=True).clamp_min(eps) + return ((a / an) * (b / bn)).sum(dim=-1) + + +def measure(model, Bs, x, y): + L = model.num_blocks + grads, logits = per_layer_bp_grads(model, x, y) + e_T = F.softmax(logits, dim=-1).clone() + e_T[torch.arange(len(y), device=y.device), y] -= 1 + out = [] + for l in range(L + 1): + b_idx = min(l, L - 1) + a_l = (e_T @ Bs[b_idx].T).detach() + g_l = grads[l].detach() + cos = cosine_no_clamp(a_l, g_l) + out.append({ + "layer": l, + "cos_mean": float(cos.mean().item()), + "cos_std": float(cos.std().item()), + "g_l_norm_median": float(g_l.norm(dim=-1).median().item()), + "a_l_norm_median": float(a_l.norm(dim=-1).median().item()), + }) + return out + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--seed", type=int, default=42) + p.add_argument("--ckpt", type=str, + default="results/confirmatory/checkpoints_A2/dfa_s42.pt") + p.add_argument("--label", type=str, default="vanilla DFA") + args = p.parse_args() + + device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + print(f"Device: {device}") + + x, y = load_eval(n=2048, device=device) + sd = torch.load(args.ckpt, map_location=device, weights_only=False) + if isinstance(sd, dict) and "state_dict" in sd: + # Direction-quality script saves dict with state_dict + Bs + model = ResidualMLP(3072, 256, 10, 4).to(device) + model.load_state_dict(sd["state_dict"]) + Bs = [b.to(device) for b in sd["Bs"]] if "Bs" in sd else None + else: + model = ResidualMLP(3072, 256, 10, 4).to(device) + model.load_state_dict(sd) + Bs = None + + if Bs is None: + print(f" Reconstructing training Bs from RNG seed {args.seed}...") + Bs = reconstruct_training_Bs(args.seed, device=device) + + print(f"\n=== {args.label} (seed {args.seed}) ===") + print(f" ckpt: {args.ckpt}") + out = measure(model, Bs, x, y) + print(f" per-layer DFA-credit vs BP-grad cosine:") + for entry in out: + print(f" l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f} " + f"(±{entry['cos_std']:.4f}) ‖g‖={entry['g_l_norm_median']:.2e} " + f"‖a‖={entry['a_l_norm_median']:.2e}") + mean_cos = np.mean([e["cos_mean"] for e in out]) + print(f" layer-mean cos: {mean_cos:+.4f}") + + +if __name__ == "__main__": + main() |
