diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/ra_mlp.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/ra_mlp.py')
| -rw-r--r-- | ep_run/ra_mlp.py | 287 |
1 files changed, 287 insertions, 0 deletions
# Recovered from: diff --git a/ep_run/ra_mlp.py b/ep_run/ra_mlp.py (new file, index 0000000..5e3de91)
"""
Reciprocal Alignment (RA) exploration on MLP + MNIST. 5 arms.

arms:
  bp:       W via BP, B ignored
  fa:       W via FA with B=random, B fixed
  ra_rev:   W via FA with B, B via reverse task FA with W (lambda=0)
  ra_recon: W via FA with B, B via layer-local reconstruction (lambda=1)
  ra_comb:  W via FA with B, B = (1-lam)*rev + lam*recon

model: pure linear MLP with LayerNorm between layers (whitening for RA-recon fixpoint).
task: MNIST classification, cross-entropy.
diag: test acc, per-layer ||B_l - W_l.T|| / ||W_l|| alignment, losses.
"""
import argparse
import json
import time
from pathlib import Path

import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader


def make_dims(in_dim, hidden, depth, out_dim):
    """Return the list of layer widths [in_dim, hidden, ..., hidden, out_dim].

    depth = number of weight matrices. hidden count = depth - 1.

    Raises:
        ValueError: if depth < 2.
    """
    if depth < 2:
        raise ValueError("depth must be >= 2")
    return [in_dim] + [hidden] * (depth - 1) + [out_dim]


def init_mats(dims, device, seed, activation="linear"):
    """Create forward matrices W and feedback matrices B for consecutive dims.

    W[i] has shape (dims[i+1], dims[i]); B[i] has the transposed shape
    (dims[i], dims[i+1]). Both are drawn from a CPU generator seeded with
    `seed` (W[i] first, then B[i], layer by layer) and then moved to `device`.

    Returns:
        (W, B): two lists of len(dims) - 1 tensors.
    """
    g = torch.Generator(device="cpu").manual_seed(seed)
    # Kaiming for ReLU, Xavier otherwise
    gain = 1.0 if activation in ("linear", "tanh") else (2.0 ** 0.5)
    W, B = [], []
    for i in range(len(dims) - 1):
        d_in, d_out = dims[i], dims[i + 1]
        # Draw order (w before b) is part of the seeded-reproducibility contract.
        w = torch.empty(d_out, d_in).normal_(std=gain * (1.0 / d_in) ** 0.5, generator=g).to(device)
        b = torch.empty(d_in, d_out).normal_(std=gain * (1.0 / d_out) ** 0.5, generator=g).to(device)
        W.append(w)
        B.append(b)
    return W, B


def apply_act(z, activation):
    """Apply the named activation ("linear", "relu" or "tanh") to z.

    Raises:
        ValueError: for any other activation name.
    """
    if activation == "linear":
        return z
    if activation == "relu":
        return F.relu(z)
    if activation == "tanh":
        return torch.tanh(z)
    raise ValueError(activation)


def act_deriv(z, activation):
    """d(act)/dz evaluated at z.

    Raises:
        ValueError: for an activation name other than "linear", "relu", "tanh".
    """
    if activation == "linear":
        return torch.ones_like(z)
    if activation == "relu":
        return (z > 0).float()
    if activation == "tanh":
        return 1 - torch.tanh(z).pow(2)
    raise ValueError(activation)


def forward_w(x, W, activation="linear", ln=True):
    """MLP forward with activation + optional LN between hidden layers.

    Returns (acts[0..L], pre[0..L]).
    pre[l] for l in 1..L is the pre-activation z[l] = W[l-1] @ a[l-1]. pre[0] = None.
    acts[l] for l in 0..L is the post-activation (or input at l=0, logits at l=L).
    """
    acts = [x]
    pre = [None]
    h = x
    L = len(W)
    for l in range(L):
        z = h @ W[l].T  # (N, d_{l+1})
        pre.append(z)
        if l < L - 1:
            h = apply_act(z, activation)
            if ln:
                h = F.layer_norm(h, h.shape[-1:])
        else:
            h = z  # logits, no activation
        acts.append(h)
    return acts, pre


def w_grads(acts, pre, grad_top, W, B, use_bp, activation="linear"):
    """W update via BP or FA (with B as feedback).

    Applies the activation derivative for hidden layers. The LN derivative is
    ignored (approximate, consistent BP/FA comparison).

    Args:
        acts, pre: outputs of forward_w.
        grad_top: gradient w.r.t. the logits, shape (N, d_L).
        W, B: forward / feedback matrices.
        use_bp: True -> propagate the error through W (backprop);
            False -> propagate through B.T (feedback alignment).
        activation: activation name used in the forward pass.

    Returns:
        List of L tensors; grads[l] has the shape of W[l].
    """
    L = len(W)
    grads = [None] * L
    delta = grad_top  # (N, d_L), at the final (logit) layer
    for l in reversed(range(L)):
        grads[l] = delta.T @ acts[l]  # (d_{l+1}, d_l)
        if l > 0:
            if use_bp:
                grad_a = delta @ W[l]  # (N, d_l) in a-space
            else:
                grad_a = delta @ B[l].T  # FA feedback via B
            # Convert grad_a (in activation output space) to grad_z (in pre-activation space)
            delta = grad_a * act_deriv(pre[l], activation)
    return grads


def b_grads_rev(acts, x, W, B, ln_b=True):
    """Reverse task FA: B pathway reconstructs x from top, W as feedback matrix.

    b[L] = a[L], b[l-1] = LN(B[l-1] @ b[l]) (LN on B pathway mirrors W's LN for stability).
    L_rev = ||b[0] - x||^2 (mean over batch)
    FA for B: eta[0] = (b[0]-x)/N; eta[l] = eta[l-1] @ W[l-1].T
    Update: dB[l] = eta[l].T @ b[l+1]
    Note: LN derivative in B pathway is ignored (approximate, consistent with W pathway treatment).

    Returns:
        (grads_B, rev_loss): list of L tensors shaped like B[l], and the
        reconstruction loss as a Python float.
    """
    L = len(B)
    N = x.shape[0]
    b = [None] * (L + 1)
    b[L] = acts[L]
    for l in range(L, 0, -1):
        h = b[l] @ B[l - 1].T  # (N, d_{l-1})
        if ln_b and l > 1:  # don't LN the last (bottom) output since it targets x directly
            h = F.layer_norm(h, h.shape[-1:])
        b[l - 1] = h
    eta = [None] * (L + 1)
    eta[0] = (b[0] - x) / N  # batch-averaged loss gradient
    for l in range(1, L + 1):
        eta[l] = eta[l - 1] @ W[l - 1].T  # (N, d_l)
    grads_B = [None] * L
    for l in range(L):
        grads_B[l] = eta[l].T @ b[l + 1]  # (d_l, d_{l+1})
    rev_loss = (b[0] - x).pow(2).sum(-1).mean().item()
    return grads_B, rev_loss


def b_grads_recon(acts, B):
    """Layer-local reconstruction: L_B^l = ||a[l+1] @ B[l].T - a[l]||^2 per layer (mean over batch).

    dB[l] = (r.T @ a[l+1]) / N,  r = a[l+1] @ B[l].T - a[l]   (GD on that quadratic)

    Returns:
        (grads_B, total_loss): list of L tensors shaped like B[l], and the
        per-layer losses summed into one Python float.
    """
    L = len(B)
    N = acts[0].shape[0]
    grads_B = [None] * L
    total_loss = 0.0
    for l in range(L):
        X = acts[l + 1]  # (N, d_{l+1})
        Y = acts[l]  # (N, d_l)
        r = X @ B[l].T - Y  # (N, d_l)
        grads_B[l] = (r.T @ X) / N  # (d_l, d_{l+1})
        total_loss += r.pow(2).sum(-1).mean().item()
    return grads_B, total_loss


def alignment(W, B):
    """Per-layer ||B[l] - W[l].T||_F / ||W[l]||_F, returned as a list of floats."""
    out = []
    for w, b in zip(W, B):
        diff = b - w.T
        # 1e-9 guards against division by zero for an all-zero W[l].
        out.append((diff.norm() / (w.norm() + 1e-9)).item())
    return out


def flatten_image(t):
    """Flatten a tensor to 1-D (e.g. a (1, 28, 28) MNIST image to (784,)).

    Defined at module level (instead of an inline lambda) so the dataset
    transform can be pickled for DataLoader worker processes on platforms
    that use the "spawn" start method.
    """
    return t.view(-1)


def main():
    """CLI entry point: train one arm on MNIST and write config.json / log.jsonl to --out."""
    # Imported here rather than at module top so the numeric helpers above can
    # be imported with only torch installed; torchvision is needed for MNIST only.
    from torchvision import datasets, transforms

    p = argparse.ArgumentParser()
    p.add_argument("--arm", required=True, choices=["bp", "fa", "ra_rev", "ra_recon", "ra_comb"])
    p.add_argument("--lam", type=float, default=0.5, help="mixing coef for ra_comb (0=pure rev, 1=pure recon)")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=0.05)
    p.add_argument("--lr_b", type=float, default=None, help="separate LR for B (defaults to --lr)")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--out", type=str, required=True)
    p.add_argument("--data_dir", type=str, default="data/mnist")
    p.add_argument("--no_ln", action="store_true", help="disable LayerNorm in forward")
    p.add_argument("--log_every", type=int, default=100)
    p.add_argument("--depth", type=int, default=4, help="number of weight matrices (>=2)")
    p.add_argument("--hidden", type=int, default=256)
    p.add_argument("--activation", choices=["linear", "relu", "tanh"], default="linear")
    args = p.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args.seed)

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "config.json").write_text(json.dumps(vars(args), indent=2))
    log_path = out_dir / "log.jsonl"
    log_path.write_text("")  # truncate

    # Data
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(flatten_image),
    ])
    train_ds = datasets.MNIST(args.data_dir, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(args.data_dir, train=False, download=True, transform=tfm)
    # Pinned host memory only helps (and only avoids a warning) when CUDA is used.
    pin = device == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=pin)
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=pin)

    dims = make_dims(784, args.hidden, args.depth, 10)
    W, B = init_mats(dims, device, args.seed, activation=args.activation)
    lr_b = args.lr_b if args.lr_b is not None else args.lr
    ln = not args.no_ln

    t0 = time.time()

    def log(rec):
        """Append one JSON record (stamped with elapsed seconds) to log.jsonl."""
        rec["t"] = time.time() - t0
        with open(log_path, "a") as f:
            f.write(json.dumps(rec) + "\n")

    log({"event": "start", "arm": args.arm, "dims": dims, "activation": args.activation,
         "lr": args.lr, "lr_b": lr_b, "ln": ln, "lam": args.lam})
    print(f"[{args.arm}] device={device} dims={dims} act={args.activation} ln={ln} lr={args.lr} lr_b={lr_b} lam={args.lam}")

    def test_acc():
        """Return top-1 accuracy of the current W on the MNIST test split."""
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                acts, _ = forward_w(x, W, activation=args.activation, ln=ln)
                logits = acts[-1]
                pred = logits.argmax(-1)
                correct += (pred == y).sum().item()
                total += y.shape[0]
        return correct / total

    step = 0
    use_bp_for_W = (args.arm == "bp")
    # Bound up front so the final report below also works when no epoch runs.
    acc = None
    align = None
    for epoch in range(args.epochs):
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            acts, pre = forward_w(x, W, activation=args.activation, ln=ln)
            logits = acts[-1]
            N = x.shape[0]
            probs = F.softmax(logits, dim=-1)
            onehot = F.one_hot(y, num_classes=10).float()
            grad_top = (probs - onehot) / N  # d(mean CE)/d(logits)

            # W grads (uses W, B snapshot)
            gW = w_grads(acts, pre, grad_top, W, B, use_bp_for_W, activation=args.activation)

            # B grads
            # NOTE(review): b_grads_rev keeps its default ln_b=True even when
            # --no_ln is given; confirm whether the B pathway should follow `ln`.
            gB = None
            rev_loss = None
            recon_loss = None
            if args.arm == "ra_rev":
                gB, rev_loss = b_grads_rev(acts, x, W, B)
            elif args.arm == "ra_recon":
                gB, recon_loss = b_grads_recon(acts, B)
            elif args.arm == "ra_comb":
                gB_rev, rev_loss = b_grads_rev(acts, x, W, B)
                gB_rec, recon_loss = b_grads_recon(acts, B)
                gB = [(1 - args.lam) * g_rev + args.lam * g_rec for g_rev, g_rec in zip(gB_rev, gB_rec)]

            # apply updates (in place, so the W / B lists keep the same tensors)
            for w, g in zip(W, gW):
                w -= args.lr * g
            if gB is not None:
                for b, g in zip(B, gB):
                    b -= lr_b * g

            step += 1
            if step % args.log_every == 0:
                log({
                    "event": "step", "step": step, "epoch": epoch,
                    # Computed from the pre-update logits, only when actually
                    # logged, to avoid a host/device sync on every step.
                    "loss_ce": F.cross_entropy(logits, y).item(),
                    "rev_loss": rev_loss, "recon_loss": recon_loss,
                    "alignment": alignment(W, B),
                })

        acc = test_acc()
        align = alignment(W, B)
        log({"event": "eval", "epoch": epoch, "step": step, "test_acc": acc, "alignment": align})
        print(f"[{args.arm}] epoch {epoch:2d} step {step:5d} test_acc {acc:.4f} align {[f'{a:.3f}' for a in align]}")

    if acc is None:
        # --epochs <= 0: no eval ran inside the loop; report the untrained model
        # instead of failing on unbound names.
        acc = test_acc()
        align = alignment(W, B)

    log({"event": "done", "step": step, "final_acc": acc, "final_alignment": align})
    print(f"[{args.arm}] done in {time.time() - t0:.1f}s final_acc={acc:.4f}")


if __name__ == "__main__":
    main()
