""" Reciprocal Alignment (RA) exploration on MLP + MNIST. 5 arms. arms: bp: W via BP, B ignored fa: W via FA with B=random, B fixed ra_rev: W via FA with B, B via reverse task FA with W (lambda=0) ra_recon: W via FA with B, B via layer-local reconstruction (lambda=1) ra_comb: W via FA with B, B = (1-lam)*rev + lam*recon model: pure linear MLP with LayerNorm between layers (whitening for RA-recon fixpoint). task: MNIST classification, cross-entropy. diag: test acc, per-layer ||B_l - W_l.T|| / ||W_l|| alignment, losses. """ import argparse import json import time from pathlib import Path import torch import torch.nn.functional as F from torch.utils.data import DataLoader from torchvision import datasets, transforms def make_dims(in_dim, hidden, depth, out_dim): """depth = number of weight matrices. hidden count = depth - 1.""" if depth < 2: raise ValueError("depth must be >= 2") return [in_dim] + [hidden] * (depth - 1) + [out_dim] def init_mats(dims, device, seed, activation="linear"): g = torch.Generator(device="cpu").manual_seed(seed) gain = 1.0 if activation in ("linear", "tanh") else (2.0 ** 0.5) # Kaiming for ReLU, Xavier otherwise W, B = [], [] for i in range(len(dims) - 1): d_in, d_out = dims[i], dims[i + 1] w = torch.empty(d_out, d_in).normal_(std=gain * (1.0 / d_in) ** 0.5, generator=g).to(device) b = torch.empty(d_in, d_out).normal_(std=gain * (1.0 / d_out) ** 0.5, generator=g).to(device) W.append(w) B.append(b) return W, B def apply_act(z, activation): if activation == "linear": return z if activation == "relu": return F.relu(z) if activation == "tanh": return torch.tanh(z) raise ValueError(activation) def act_deriv(z, activation): """d(act)/dz evaluated at z.""" if activation == "linear": return torch.ones_like(z) if activation == "relu": return (z > 0).float() if activation == "tanh": return 1 - torch.tanh(z).pow(2) raise ValueError(activation) def forward_w(x, W, activation="linear", ln=True): """MLP forward with activation + optional LN between hidden layers. Returns (acts[0..L], pre[0..L]). pre[l] for l in 1..L is the pre-activation z[l] = W[l-1] @ a[l-1]. pre[0] = None. acts[l] for l in 0..L is the post-activation (or input at l=0, logits at l=L). """ acts = [x] pre = [None] h = x L = len(W) for l in range(L): z = h @ W[l].T # (N, d_{l+1}) pre.append(z) if l < L - 1: h = apply_act(z, activation) if ln: h = F.layer_norm(h, h.shape[-1:]) else: h = z # logits, no activation acts.append(h) return acts, pre def w_grads(acts, pre, grad_top, W, B, use_bp, activation="linear"): """W update via BP or FA (with B as feedback). Applies activation derivative for hidden layers. LN derivative is ignored (approximate, consistent BP/FA comparison).""" L = len(W) grads = [None] * L delta = grad_top # (N, d_L), at the final (logit) layer for l in reversed(range(L)): grads[l] = delta.T @ acts[l] # (d_{l+1}, d_l) if l > 0: if use_bp: grad_a = delta @ W[l] # (N, d_l) in a-space else: grad_a = delta @ B[l].T # FA feedback via B # Convert grad_a (in activation output space) to grad_z (in pre-activation space) delta = grad_a * act_deriv(pre[l], activation) return grads def b_grads_rev(acts, x, W, B, ln_b=True): """Reverse task FA: B pathway reconstructs x from top, W as feedback matrix. b[L] = a[L], b[l-1] = LN(B[l-1] @ b[l]) (LN on B pathway mirrors W's LN for stability). L_rev = ||b[0] - x||^2 (mean over batch) FA for B: eta[0] = (b[0]-x)/N; eta[l] = eta[l-1] @ W[l-1].T Update: dB[l] = eta[l].T @ b[l+1] Note: LN derivative in B pathway is ignored (approximate, consistent with W pathway treatment). """ L = len(B) N = x.shape[0] b = [None] * (L + 1) b[L] = acts[L] for l in range(L, 0, -1): h = b[l] @ B[l - 1].T # (N, d_{l-1}) if ln_b and l > 1: # don't LN the last (bottom) output since it targets x directly h = F.layer_norm(h, h.shape[-1:]) b[l - 1] = h eta = [None] * (L + 1) eta[0] = (b[0] - x) / N # batch-averaged loss gradient for l in range(1, L + 1): eta[l] = eta[l - 1] @ W[l - 1].T # (N, d_l) grads_B = [None] * L for l in range(L): grads_B[l] = eta[l].T @ b[l + 1] # (d_l, d_{l+1}) rev_loss = (b[0] - x).pow(2).sum(-1).mean().item() return grads_B, rev_loss def b_grads_recon(acts, B): """Layer-local reconstruction: L_B^l = ||a[l+1] @ B[l].T - a[l]||^2 per layer (mean over batch). dB[l] = (r.T @ a[l+1]) / N, r = a[l+1] @ B[l].T - a[l] (GD on that quadratic) """ L = len(B) N = acts[0].shape[0] grads_B = [None] * L total_loss = 0.0 for l in range(L): X = acts[l + 1] # (N, d_{l+1}) Y = acts[l] # (N, d_l) r = X @ B[l].T - Y # (N, d_l) grads_B[l] = (r.T @ X) / N # (d_l, d_{l+1}) total_loss += r.pow(2).sum(-1).mean().item() return grads_B, total_loss def alignment(W, B): """Per-layer ||B[l] - W[l].T||_F / ||W[l]||_F.""" out = [] for l in range(len(W)): diff = B[l] - W[l].T out.append((diff.norm() / (W[l].norm() + 1e-9)).item()) return out def main(): p = argparse.ArgumentParser() p.add_argument("--arm", required=True, choices=["bp", "fa", "ra_rev", "ra_recon", "ra_comb"]) p.add_argument("--lam", type=float, default=0.5, help="mixing coef for ra_comb (0=pure rev, 1=pure recon)") p.add_argument("--epochs", type=int, default=10) p.add_argument("--batch_size", type=int, default=128) p.add_argument("--lr", type=float, default=0.05) p.add_argument("--lr_b", type=float, default=None, help="separate LR for B (defaults to --lr)") p.add_argument("--seed", type=int, default=42) p.add_argument("--out", type=str, required=True) p.add_argument("--data_dir", type=str, default="data/mnist") p.add_argument("--no_ln", action="store_true", help="disable LayerNorm in forward") p.add_argument("--log_every", type=int, default=100) p.add_argument("--depth", type=int, default=4, help="number of weight matrices (>=2)") p.add_argument("--hidden", type=int, default=256) p.add_argument("--activation", choices=["linear", "relu", "tanh"], default="linear") args = p.parse_args() device = "cuda" if torch.cuda.is_available() else "cpu" torch.manual_seed(args.seed) out_dir = Path(args.out) out_dir.mkdir(parents=True, exist_ok=True) (out_dir / "config.json").write_text(json.dumps(vars(args), indent=2)) log_path = out_dir / "log.jsonl" log_path.write_text("") # truncate # Data tfm = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,)), transforms.Lambda(lambda t: t.view(-1)), ]) train_ds = datasets.MNIST(args.data_dir, train=True, download=True, transform=tfm) test_ds = datasets.MNIST(args.data_dir, train=False, download=True, transform=tfm) train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=True) test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=True) dims = make_dims(784, args.hidden, args.depth, 10) W, B = init_mats(dims, device, args.seed, activation=args.activation) lr_b = args.lr_b if args.lr_b is not None else args.lr ln = not args.no_ln t0 = time.time() def log(rec): rec["t"] = time.time() - t0 with open(log_path, "a") as f: f.write(json.dumps(rec) + "\n") log({"event": "start", "arm": args.arm, "dims": dims, "activation": args.activation, "lr": args.lr, "lr_b": lr_b, "ln": ln, "lam": args.lam}) print(f"[{args.arm}] device={device} dims={dims} act={args.activation} ln={ln} lr={args.lr} lr_b={lr_b} lam={args.lam}") def test_acc(): correct = 0 total = 0 with torch.no_grad(): for x, y in test_loader: x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) acts, _ = forward_w(x, W, activation=args.activation, ln=ln) logits = acts[-1] pred = logits.argmax(-1) correct += (pred == y).sum().item() total += y.shape[0] return correct / total step = 0 use_bp_for_W = (args.arm == "bp") for epoch in range(args.epochs): for x, y in train_loader: x = x.to(device, non_blocking=True) y = y.to(device, non_blocking=True) acts, pre = forward_w(x, W, activation=args.activation, ln=ln) logits = acts[-1] ce = F.cross_entropy(logits, y).item() N = x.shape[0] probs = F.softmax(logits, dim=-1) onehot = F.one_hot(y, num_classes=10).float() grad_top = (probs - onehot) / N # W grads (uses W, B snapshot) gW = w_grads(acts, pre, grad_top, W, B, use_bp_for_W, activation=args.activation) # B grads gB = None rev_loss = None recon_loss = None if args.arm == "ra_rev": gB, rev_loss = b_grads_rev(acts, x, W, B) elif args.arm == "ra_recon": gB, recon_loss = b_grads_recon(acts, B) elif args.arm == "ra_comb": gB_rev, rev_loss = b_grads_rev(acts, x, W, B) gB_rec, recon_loss = b_grads_recon(acts, B) gB = [(1 - args.lam) * gB_rev[l] + args.lam * gB_rec[l] for l in range(len(B))] # apply updates for l in range(len(W)): W[l] -= args.lr * gW[l] if gB is not None: for l in range(len(B)): B[l] -= lr_b * gB[l] step += 1 if step % args.log_every == 0: align = alignment(W, B) log({ "event": "step", "step": step, "epoch": epoch, "loss_ce": ce, "rev_loss": rev_loss, "recon_loss": recon_loss, "alignment": align, }) acc = test_acc() align = alignment(W, B) log({"event": "eval", "epoch": epoch, "step": step, "test_acc": acc, "alignment": align}) print(f"[{args.arm}] epoch {epoch:2d} step {step:5d} test_acc {acc:.4f} align {[f'{a:.3f}' for a in align]}") log({"event": "done", "step": step, "final_acc": acc, "final_alignment": align}) print(f"[{args.arm}] done in {time.time() - t0:.1f}s final_acc={acc:.4f}") if __name__ == "__main__": main()