summaryrefslogtreecommitdiff
path: root/experiments
diff options
context:
space:
mode:
authorYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:21:32 -0500
committerYurenHao0426 <Blackhao0426@gmail.com>2026-04-07 23:21:32 -0500
commit8f67bdeebac543961871b9896a62cd07b7a5be26 (patch)
tree63fec268bf894b61875ccf90e173af4e4264cb81 /experiments
parent5771a122300f9d30a6290fcbfc9bffb5f380e648 (diff)
Add fast direction-quality measurement on existing DFA checkpoints
3-seed result on the existing dfa_s{42,123,456}.pt checkpoints from results/confirmatory/checkpoints_A2/, computing the per-layer cosine of DFA's local credit signal e_T@B_l^T vs. the true BP gradient at h_l. Key findings, per-layer cos (3-seed mean): l0: +0.42 (high — embedding alignment); l1: +0.006 (essentially zero); l2: -0.015 (essentially zero); l3: -0.004 (essentially zero); l4: -0.004 (essentially zero); layer-mean across all 5: +0.07 to +0.10. The deep blocks (l1-l4) have essentially zero alignment with the BP gradient in the vanilla scale-failure regime; layer 0 dominates the headline number. The script reconstructs the training-time random Bs by replaying the RNG sequence (torch.manual_seed + ResidualMLP construction + randn draws), since the existing checkpoints don't save Bs. For the still-running direction-quality experiment, which DOES save Bs, the script auto-detects the dict format and uses the saved Bs directly.
Diffstat (limited to 'experiments')
-rw-r--r--experiments/measure_direction_quality_existing_ckpt.py143
1 files changed, 143 insertions, 0 deletions
diff --git a/experiments/measure_direction_quality_existing_ckpt.py b/experiments/measure_direction_quality_existing_ckpt.py
new file mode 100644
index 0000000..d150eb3
--- /dev/null
+++ b/experiments/measure_direction_quality_existing_ckpt.py
@@ -0,0 +1,143 @@
+"""
+Fast direction-quality measurement using EXISTING checkpoints. No training.
+
+Loads the vanilla DFA s42 checkpoint from
+`results/confirmatory/checkpoints_A2/dfa_s42.pt` and computes the per-layer
+cosine between DFA's local credit signal `a_l = e_T @ B_l^T` and the BP
+gradient `g_l = ∂L/∂h_l`. This is the "Γ on the degenerate reference"
+measurement — what the field-standard FA evaluation reports.
+
+The catch: the training-time random feedback Bs were not saved in the
+existing checkpoint. We reconstruct them by replaying the training-time
+RNG sequence (`torch.manual_seed(seed); ResidualMLP(...); randn(d, C)`),
+matching what the original DFA trainer did.
+
+For the "scale-fixed" comparison case, we'll need a penalized DFA checkpoint
+(which `experiments/dfa_direction_quality_test.py` is currently saving in
+the background). Once that lands, this script can be reused with --ckpt
+pointing to the penalized one.
+
+Run:
+ CUDA_VISIBLE_DEVICES=2 python experiments/measure_direction_quality_existing_ckpt.py
+"""
+import os
+import sys
+import argparse
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision
+import torchvision.transforms as transforms
+from torch.utils.data import DataLoader
+
+sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+from models.residual_mlp import ResidualMLP
+
+
def load_eval(n=2048, device="cuda:0", data_root="./data", batch_size=256):
    """Return the first ``n`` CIFAR-10 test samples as flat tensors on ``device``.

    The test split is read in its stored order (``shuffle=False``), so the
    same ``n`` always yields the same samples.

    Args:
        n: Number of samples to keep from the front of the test split.
        device: Device the returned tensors are moved to.
        data_root: Directory handed to ``torchvision.datasets.CIFAR10``; the
            dataset is downloaded there if missing (``download=True``).
        batch_size: Loader batch size; only affects how many samples are read
            before truncating to ``n``, not the result.

    Returns:
        ``(x, y)`` where ``x`` holds each image flattened to one row and ``y``
        holds the matching labels, both truncated to ``n`` entries.
    """
    tv = transforms.Compose([
        transforms.ToTensor(),
        # Per-channel constants; presumably the CIFAR-10 mean/std used at
        # training time -- verify against the trainer's transform.
        transforms.Normalize((0.4914, 0.4822, 0.4465), (0.2470, 0.2435, 0.2616)),
    ])
    te = torchvision.datasets.CIFAR10(data_root, train=False, download=True, transform=tv)
    loader = DataLoader(te, batch_size=batch_size, shuffle=False, num_workers=0)

    xs, ys = [], []
    n_seen = 0  # running count, so earlier batches are not re-summed each step
    for x, y in loader:
        xs.append(x.view(x.size(0), -1))  # flatten every image to a single row
        ys.append(y)
        n_seen += x.size(0)
        if n_seen >= n:
            break
    return torch.cat(xs)[:n].to(device), torch.cat(ys)[:n].to(device)
+
+
def reconstruct_training_Bs(seed, d_hidden=256, num_blocks=4, num_classes=10,
                            device="cuda:0", d_in=3072):
    """Replay the training-time RNG sequence to recover the feedback matrices.

    The sequence is: seed the torch / NumPy / CUDA generators, construct (and
    discard) a ``ResidualMLP`` so its initialisation consumes random numbers,
    then draw ``num_blocks`` Gaussian matrices. The result only equals the
    training-time Bs if the trainer performed exactly these steps with the
    same arguments and on the same device type -- NOTE(review): confirm
    against the original DFA trainer.

    Args:
        seed: Seed used for the training run being reconstructed.
        d_hidden: Hidden width; first dimension of every B.
        num_blocks: Number of matrices to draw.
        num_classes: Output dimension; second dimension of every B.
        device: Device on which the matrices are sampled.
        d_in: Input width passed to the throwaway ``ResidualMLP``.

    Returns:
        List of ``num_blocks`` tensors of shape ``(d_hidden, num_classes)``,
        each scaled by ``1 / sqrt(num_classes)``.
    """
    torch.manual_seed(seed)
    np.random.seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Built only for its side effect on the RNG stream; the instance itself
    # is not needed.
    ResidualMLP(d_in, d_hidden, num_classes, num_blocks)

    scale = np.sqrt(num_classes)
    return [
        torch.randn(d_hidden, num_classes, device=device) / scale
        for _ in range(num_blocks)
    ]
+
+
def per_layer_bp_grads(model, x, y):
    """Backprop gradients of the cross-entropy loss w.r.t. every hidden state.

    Runs ``model.embed``, then each entry of ``model.blocks`` as a residual
    update (``h + block(h)``), then ``model.out_head(model.out_ln(h))``, and
    differentiates the loss with respect to each recorded hidden state.

    Args:
        model: Object exposing ``embed``, ``blocks``, ``out_ln``, ``out_head``.
        x: Input batch fed to ``model.embed``.
        y: Target class indices for ``F.cross_entropy``.

    Returns:
        ``(grads, logits)``: a list with one gradient per hidden state (the
        embedding output first, then one per block) and the detached logits.
    """
    # Force graph recording even if the caller is inside a no_grad context.
    with torch.enable_grad():
        state = model.embed(x)
        trace = [state]
        for blk in model.blocks:
            state = state + blk(state)
            trace.append(state)
        logits = model.out_head(model.out_ln(state))
        loss = F.cross_entropy(logits, y)
        bp = torch.autograd.grad(loss, trace)
    return list(bp), logits.detach()
+
+
def cosine_no_clamp(a, b):
    """Row-wise cosine similarity along the last dimension.

    Each input is divided by its own norm before the dot product. The norm is
    floored at 1e-30 only to avoid dividing by zero, so an all-zero row gives
    a cosine of 0 while very small (non-zero) vectors keep their direction.

    Args:
        a: Tensor whose last dimension holds the vectors.
        b: Tensor broadcast-compatible with ``a``.

    Returns:
        Tensor of cosines with the last dimension reduced away.
    """
    floor = 1e-30

    def _unit(v):
        # Normalise each vector independently of the other operand.
        return v / v.norm(dim=-1, keepdim=True).clamp_min(floor)

    return (_unit(a) * _unit(b)).sum(dim=-1)
+
+
def measure(model, Bs, x, y):
    """Compare the DFA credit signal with the BP gradient at every hidden state.

    For each of the ``model.num_blocks + 1`` hidden states, the output error
    (softmax minus one-hot target) is projected through a feedback matrix and
    compared, by per-sample cosine, with the true gradient from
    ``per_layer_bp_grads``.

    Args:
        model: Model exposing ``num_blocks`` plus whatever
            ``per_layer_bp_grads`` reads.
        Bs: Sequence of feedback matrices, indexed by layer.
        x: Input batch.
        y: Integer class targets.

    Returns:
        One dict per hidden state with keys ``layer``, ``cos_mean``,
        ``cos_std``, ``g_l_norm_median`` and ``a_l_norm_median``.
    """
    n_blocks = model.num_blocks
    bp_grads, logits = per_layer_bp_grads(model, x, y)

    # Output error: softmax probabilities with 1 subtracted at the target class.
    err = F.softmax(logits, dim=-1).clone()
    rows = torch.arange(len(y), device=y.device)
    err[rows, y] -= 1

    def _scalar(t):
        return float(t.item())

    report = []
    for layer in range(n_blocks + 1):
        # The index is capped at the last matrix, so the final hidden state
        # reuses it. NOTE(review): confirm this pairing matches the trainer.
        feedback = Bs[min(layer, n_blocks - 1)]
        credit = (err @ feedback.T).detach()
        true_grad = bp_grads[layer].detach()
        cos = cosine_no_clamp(credit, true_grad)
        report.append({
            "layer": layer,
            "cos_mean": _scalar(cos.mean()),
            "cos_std": _scalar(cos.std()),
            "g_l_norm_median": _scalar(true_grad.norm(dim=-1).median()),
            "a_l_norm_median": _scalar(credit.norm(dim=-1).median()),
        })
    return report
+
+
def main():
    """Load a checkpoint, obtain its feedback matrices, and print per-layer cosines.

    Command-line options:
        --seed:  Training seed; used to rebuild the Bs when the checkpoint
                 does not contain them, and to pick the default checkpoint.
        --ckpt:  Checkpoint path. When omitted it is derived from ``--seed``
                 so that the weights and the reconstructed Bs always belong
                 to the same run (seed 42 gives the previous default path).
        --label: Free-text label printed in the report header.

    Two checkpoint layouts are accepted: a dict holding ``"state_dict"`` (and
    optionally ``"Bs"``), or a bare state dict.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--ckpt", type=str, default=None,
                   help="checkpoint path (default: "
                        "results/confirmatory/checkpoints_A2/dfa_s<seed>.pt)")
    p.add_argument("--label", type=str, default="vanilla DFA")
    args = p.parse_args()
    if args.ckpt is None:
        # Tie the default checkpoint to the seed: loading dfa_s42.pt while
        # replaying another seed's RNG would pair weights with the wrong Bs.
        args.ckpt = f"results/confirmatory/checkpoints_A2/dfa_s{args.seed}.pt"

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print(f"Device: {device}")

    x, y = load_eval(n=2048, device=device)

    # NOTE: weights_only=False unpickles arbitrary Python objects; only point
    # --ckpt at checkpoints from a trusted source.
    sd = torch.load(args.ckpt, map_location=device, weights_only=False)
    if isinstance(sd, dict) and "state_dict" in sd:
        # Direction-quality script saves dict with state_dict + Bs
        state_dict = sd["state_dict"]
        Bs = [b.to(device) for b in sd["Bs"]] if "Bs" in sd else None
    else:
        state_dict, Bs = sd, None

    model = ResidualMLP(3072, 256, 10, 4).to(device)
    model.load_state_dict(state_dict)
    # Measure in inference mode; this only matters if the model contains
    # mode-dependent layers such as dropout or batch norm.
    model.eval()

    if Bs is None:
        print(f" Reconstructing training Bs from RNG seed {args.seed}...")
        Bs = reconstruct_training_Bs(args.seed, device=device)

    print(f"\n=== {args.label} (seed {args.seed}) ===")
    print(f" ckpt: {args.ckpt}")
    out = measure(model, Bs, x, y)
    print(" per-layer DFA-credit vs BP-grad cosine:")
    for entry in out:
        print(f" l{entry['layer']}: cos_mean={entry['cos_mean']:+.4f} "
              f"(±{entry['cos_std']:.4f}) ‖g‖={entry['g_l_norm_median']:.2e} "
              f"‖a‖={entry['a_l_norm_median']:.2e}")
    mean_cos = np.mean([e["cos_mean"] for e in out])
    print(f" layer-mean cos: {mean_cos:+.4f}")


if __name__ == "__main__":
    main()