""" Fast direction-quality measurement using EXISTING checkpoints. No training. Loads the vanilla DFA s42 checkpoint from `results/confirmatory/checkpoints_A2/dfa_s42.pt` and computes the per-layer cosine between DFA's local credit signal `a_l = e_T @ B_l^T` and the BP gradient `g_l = ∂L/∂h_l`. This is the "Γ on the degenerate reference" measurement — what the field-standard FA evaluation reports. The catch: the trained-time random feedback Bs were not saved in the existing checkpoint. We reconstruct them by replaying the training-time RNG sequence (`torch.manual_seed(seed); ResidualMLP(...); randn(d, C)`), matching what the original DFA trainer did. For the "scale-fixed" comparison case, we'll need a penalized DFA checkpoint (which `experiments/dfa_direction_quality_test.py` is currently saving in the background). Once that lands, this script can be reused with --ckpt pointing to the penalized one. Run: CUDA_VISIBLE_DEVICES=2 python experiments/measure_direction_quality_existing_ckpt.py """ import os import sys import argparse import numpy as np import torch import torch.nn.functional as F import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) from models.residual_mlp import ResidualMLP def load_eval(n=2048, device="cuda:0"): tv = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)), ]) te = torchvision.datasets.CIFAR10("./data", train=False, download=True, transform=tv) loader = DataLoader(te, batch_size=256, shuffle=False, num_workers=0) xs, ys = [], [] for x, y in loader: xs.append(x.view(x.size(0), -1)); ys.append(y) if sum(xb.size(0) for xb in xs) >= n: break return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device) def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10, device="cuda:0"): """Replay the training-time RNG sequence to recover the Bs.""" torch.manual_seed(seed); np.random.seed(seed); torch.cuda.manual_seed_all(seed) _ = ResidualMLP(3072, d_hidden, num_classes, num_blocks) # consume model init RNG Bs = [torch.randn(d_hidden, num_classes, device=device) / np.sqrt(num_classes) for _ in range(num_blocks)] return Bs def per_layer_bp_grads(model, x, y): with torch.enable_grad(): h = model.embed(x) hiddens = [h] for block in model.blocks: h = h + block(h) hiddens.append(h) logits = model.out_head(model.out_ln(h)) loss = F.cross_entropy(logits, y) grads = torch.autograd.grad(loss, hiddens) return list(grads), logits.detach() def cosine_no_clamp(a, b): eps = 1e-30 an = a.norm(dim=-1, keepdim=True).clamp_min(eps) bn = b.norm(dim=-1, keepdim=True).clamp_min(eps) return ((a / an) * (b / bn)).sum(dim=-1) def measure(model, Bs, x, y): L = model.num_blocks grads, logits = per_layer_bp_grads(model, x, y) e_T = F.softmax(logits, dim=-1).clone() e_T[torch.arange(len(y), device=y.device), y] -= 1 out = [] for l in range(L + 1): b_idx = min(l, L - 1) a_l = (e_T @ Bs[b_idx].T).detach() g_l = grads[l].detach() cos = cosine_no_clamp(a_l, g_l) out.append({ "layer": l, "cos_mean": float(cos.mean().item()), "cos_std": float(cos.std().item()), "g_l_norm_median": float(g_l.norm(dim=-1).median().item()), "a_l_norm_median": float(a_l.norm(dim=-1).median().item()), }) return out def main(): p = argparse.ArgumentParser() p.add_argument("--seed", type=int, default=42) p.add_argument("--ckpt", type=str, default="results/confirmatory/checkpoints_A2/dfa_s42.pt") p.add_argument("--label", type=str, default="vanilla DFA") args = p.parse_args() device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") print(f"Device: {device}") x, y = load_eval(n=2048, device=device) sd = torch.load(args.ckpt, map_location=device, weights_only=False) if isinstance(sd, dict) and "state_dict" in sd: # Direction-quality script saves dict with state_dict + Bs model = ResidualMLP(3072, 256, 10, 4).to(device) model.load_state_dict(sd["state_dict"]) Bs = [b.to(device) for b in sd["Bs"]] if "Bs" in sd else None else: model = ResidualMLP(3072, 256, 10, 4).to(device) model.load_state_dict(sd) Bs = None if Bs is None: print(f" Reconstructing training Bs from RNG seed {args.seed}...") Bs = reconstruct_training_Bs(args.seed, device=device) print(f"\n=== {args.label} (seed {args.seed}) ===") print(f" ckpt: {args.ckpt}") out = measure(model, Bs, x, y) print(f" per-layer DFA-credit vs BP-grad cosine:") for entry in out: print(f" l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f} " f"(±{entry['cos_std']:.4f}) ‖g‖={entry['g_l_norm_median']:.2e} " f"‖a‖={entry['a_l_norm_median']:.2e}") mean_cos = np.mean([e["cos_mean"] for e in out]) print(f" layer-mean cos: {mean_cos:+.4f}") if __name__ == "__main__": main()