summaryrefslogtreecommitdiff
path: root/ep_run/ra_mlp.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/ra_mlp.py')
-rw-r--r--ep_run/ra_mlp.py287
1 files changed, 287 insertions, 0 deletions
diff --git a/ep_run/ra_mlp.py b/ep_run/ra_mlp.py
new file mode 100644
index 0000000..5e3de91
--- /dev/null
+++ b/ep_run/ra_mlp.py
@@ -0,0 +1,287 @@
+"""
+Reciprocal Alignment (RA) exploration on MLP + MNIST. 5 arms.
+
+arms:
+ bp: W via BP, B ignored
+ fa: W via FA with B=random, B fixed
+ ra_rev: W via FA with B, B via reverse task FA with W (lambda=0)
+ ra_recon: W via FA with B, B via layer-local reconstruction (lambda=1)
+ ra_comb: W via FA with B, B = (1-lam)*rev + lam*recon
+
+model: MLP with LayerNorm between hidden layers (whitening for RA-recon fixpoint); linear by default, ReLU/tanh via --activation, LayerNorm can be disabled with --no_ln.
+task: MNIST classification, cross-entropy.
+diag: test acc, per-layer ||B_l - W_l.T|| / ||W_l|| alignment, losses.
+"""
+import argparse
+import json
+import time
+from pathlib import Path
+
+import torch
+import torch.nn.functional as F
+from torch.utils.data import DataLoader
+from torchvision import datasets, transforms
+
+
def make_dims(in_dim, hidden, depth, out_dim):
    """Return the list of layer widths for an MLP with ``depth`` weight matrices.

    The result has ``depth + 1`` entries: ``in_dim``, then ``depth - 1`` hidden
    layers that all have width ``hidden``, then ``out_dim``.

    Raises:
        ValueError: if ``depth < 2`` or any of the widths is smaller than 1.
    """
    if depth < 2:
        raise ValueError("depth must be >= 2")
    if min(in_dim, hidden, out_dim) < 1:
        # Fail early with a clear message: a zero width would otherwise only
        # surface later as a division by zero when the init std is computed.
        raise ValueError("in_dim, hidden and out_dim must be >= 1")
    return [in_dim, *([hidden] * (depth - 1)), out_dim]
+
+
def init_mats(dims, device, seed, activation="linear"):
    """Create seeded forward (W) and feedback (B) matrices for every layer.

    For each pair of consecutive widths ``(d_in, d_out)`` in ``dims`` the
    forward matrix has shape ``(d_out, d_in)`` and the feedback matrix has the
    transposed shape ``(d_in, d_out)``. Entries are zero-mean Gaussian with
    ``std = gain / sqrt(fan_in)``, where the fan-in is ``d_in`` for W and
    ``d_out`` for B.

    ``gain`` is 1 for "linear" and "tanh" and sqrt(2) (He/Kaiming scaling for
    ReLU) for any other activation name; unknown names are not rejected here.

    All values are drawn from a CPU generator seeded with ``seed`` and only
    then moved to ``device``, so the drawn numbers do not depend on the device.

    Returns:
        ``(W, B)``: two lists with ``len(dims) - 1`` tensors each.
    """
    gen = torch.Generator(device="cpu").manual_seed(seed)
    gain = 1.0 if activation in ("linear", "tanh") else 2.0 ** 0.5
    W, B = [], []
    for d_in, d_out in zip(dims, dims[1:]):
        # Draw order (W then B, layer by layer) fixes the random stream for a
        # given seed; do not reorder these two lines.
        w = torch.empty(d_out, d_in).normal_(std=gain * (1.0 / d_in) ** 0.5, generator=gen)
        b = torch.empty(d_in, d_out).normal_(std=gain * (1.0 / d_out) ** 0.5, generator=gen)
        W.append(w.to(device))
        B.append(b.to(device))
    return W, B
+
+
def apply_act(z, activation):
    """Apply the named elementwise activation to ``z``.

    ``activation`` is one of "linear" (identity, returns ``z`` itself),
    "relu" or "tanh".

    Raises:
        ValueError: for any other activation name.
    """
    if activation == "linear":
        return z
    if activation == "relu":
        return F.relu(z)
    if activation == "tanh":
        return torch.tanh(z)
    raise ValueError(f"unknown activation: {activation!r}")
+
+
def act_deriv(z, activation):
    """Return d(act)/dz evaluated elementwise at the pre-activation ``z``.

    ``activation`` is one of "linear" (derivative 1), "relu" (1 where
    ``z > 0``, else 0) or "tanh" (``1 - tanh(z)**2``). The result always has
    the dtype of ``z``.

    Raises:
        ValueError: for any other activation name.
    """
    if activation == "linear":
        return torch.ones_like(z)
    if activation == "relu":
        # Cast the mask to z's dtype rather than hard-coding float32, so that
        # multiplying it into a half/bfloat16/float64 delta does not change
        # the delta's dtype through type promotion.
        return (z > 0).to(z.dtype)
    if activation == "tanh":
        return 1 - torch.tanh(z).pow(2)
    raise ValueError(f"unknown activation: {activation!r}")
+
+
def forward_w(x, W, activation="linear", ln=True):
    """Run the MLP forward pass and record every layer's values.

    For each weight matrix ``W[l]`` the pre-activation is ``z = h @ W[l].T``.
    On every layer except the last, ``h`` becomes ``apply_act(z, activation)``
    followed (when ``ln`` is true) by a parameter-free LayerNorm over the last
    dimension. The last layer's output is the raw ``z`` (logits).

    Returns:
        ``(acts, pre)``, both of length ``len(W) + 1``:
        ``acts[0]`` is ``x``, ``acts[l]`` is the output of layer ``l``
        (logits for the last entry); ``pre[0]`` is ``None`` and ``pre[l]`` is
        the pre-activation of layer ``l``.
    """
    acts = [x]
    pre = [None]
    h = x
    last = len(W) - 1
    for l, w in enumerate(W):
        z = h @ w.T  # (N, d_{l+1})
        pre.append(z)
        if l < last:
            h = apply_act(z, activation)
            if ln:
                h = F.layer_norm(h, h.shape[-1:])
        else:
            h = z  # logits: no activation, no LayerNorm
        acts.append(h)
    return acts, pre
+
+
def w_grads(acts, pre, grad_top, W, B, use_bp, activation="linear"):
    """Compute the update direction for every forward matrix W[l].

    Starting from ``grad_top`` (the error signal at the logits), the signal is
    carried down one layer at a time: through ``W[l]`` itself when ``use_bp``
    is true (backprop-style), otherwise through the feedback matrix ``B[l]``
    (feedback alignment). At each hidden layer it is multiplied by the
    activation derivative at that layer's pre-activation ``pre[l]``.

    The LayerNorm derivative is ignored in both modes, so the ``use_bp`` path
    is only an approximation of true backprop when LayerNorm is enabled.

    Returns:
        A list with one tensor per layer, ``grads[l] = delta_l.T @ acts[l]``.
    """
    L = len(W)
    grads = [None] * L
    delta = grad_top  # (N, d_L), at the final (logit) layer
    for l in reversed(range(L)):
        grads[l] = delta.T @ acts[l]  # (d_{l+1}, d_l)
        if l > 0:
            if use_bp:
                grad_a = delta @ W[l]  # (N, d_l), exact transpose pathway
            else:
                grad_a = delta @ B[l].T  # (N, d_l), FA feedback via B
            # Move from activation-output space to pre-activation space.
            delta = grad_a * act_deriv(pre[l], activation)
    return grads
+
+
def b_grads_rev(acts, x, W, B, ln_b=True):
    """Reverse-task feedback alignment: update directions for the B matrices.

    The B pathway runs top-down from the network output and tries to
    reconstruct the input ``x``:

        b[L] = acts[L];  b[l-1] = b[l] @ B[l-1].T

    with a parameter-free LayerNorm applied to every ``b[l-1]`` except the
    bottom one (``b[0]`` is compared with ``x`` directly) when ``ln_b`` is
    true. The reconstruction error is then carried upward using the forward
    matrices W as fixed feedback:

        eta[0] = (b[0] - x) / N;  eta[l] = eta[l-1] @ W[l-1].T
        dB[l]  = eta[l].T @ b[l+1]

    The LayerNorm derivative is ignored, and ``eta[0]`` omits the factor 2 of
    the squared-error gradient. Only ``acts[L]`` is read from ``acts``.

    Returns:
        ``(grads_B, rev_loss)``: one tensor per B matrix (same shape as
        ``B[l]``) and the batch-mean squared reconstruction error as a float.
    """
    L = len(B)
    N = x.shape[0]

    # Top-down pass through B.
    b = [None] * (L + 1)
    b[L] = acts[L]
    for l in range(L, 0, -1):
        h = b[l] @ B[l - 1].T  # (N, d_{l-1})
        if ln_b and l > 1:  # the bottom output targets x directly: no LN
            h = F.layer_norm(h, h.shape[-1:])
        b[l - 1] = h

    err = b[0] - x
    # Bottom-up pass through W. Only eta[0..L-1] are needed for the L
    # gradients, so the error is never projected through the top matrix.
    eta = err / N
    grads_B = []
    for l in range(L):
        if l > 0:
            eta = eta @ W[l - 1].T  # (N, d_l)
        grads_B.append(eta.T @ b[l + 1])  # (d_l, d_{l+1})

    rev_loss = err.pow(2).sum(-1).mean().item()
    return grads_B, rev_loss
+
+
def b_grads_recon(acts, B):
    """Layer-local reconstruction: update directions for the B matrices.

    Each ``B[l]`` is trained to map the upper activation back to the lower
    one. With residual ``r = acts[l+1] @ B[l].T - acts[l]`` the per-layer loss
    is the batch mean of ``||r||^2`` and the returned direction is

        dB[l] = (r.T @ acts[l+1]) / N

    (gradient descent on that quadratic, without the factor 2).

    Returns:
        ``(grads_B, total_loss)``: one tensor per B matrix (same shape as
        ``B[l]``) and the sum over layers of the per-layer loss as a float.
    """
    N = acts[0].shape[0]
    grads_B = []
    total = 0.0
    for l, b_mat in enumerate(B):
        upper = acts[l + 1]  # (N, d_{l+1})
        lower = acts[l]  # (N, d_l)
        r = upper @ b_mat.T - lower  # (N, d_l)
        grads_B.append((r.T @ upper) / N)  # (d_l, d_{l+1})
        # Accumulate as a tensor so the host reads the value only once per
        # call instead of synchronising with the device for every layer.
        total = total + r.pow(2).sum(-1).mean()
    return grads_B, float(total)
+
+
def alignment(W, B):
    """Return the per-layer relative distance between B[l] and W[l].T.

    For every layer the value is ``||B[l] - W[l].T|| / (||W[l]|| + 1e-9)``
    using the tensor 2-norm over all entries (Frobenius norm); the small
    constant guards against division by zero. 0 means the feedback matrix is
    exactly the transpose of the forward matrix.

    Returns:
        A list of Python floats, one per entry of ``W``.
    """
    out = []
    for l, w in enumerate(W):
        gap = (B[l] - w.T).norm()
        out.append((gap / (w.norm() + 1e-9)).item())
    return out
+
+
def _flatten_image(t):
    """Flatten one image tensor to 1-D.

    Defined at module level (not as a lambda) so the dataset transform can be
    pickled when DataLoader workers are started with the spawn method.
    """
    return t.view(-1)


def main():
    """Train one arm of the experiment on MNIST and log diagnostics.

    Parses the command line, writes ``config.json`` and a fresh ``log.jsonl``
    into ``--out``, builds the W/B matrices, then runs ``--epochs`` epochs of
    manual SGD. W is updated with the BP-style rule for the "bp" arm and with
    feedback through B otherwise; B is left fixed for "bp"/"fa" and updated
    with the reverse-task rule, the layer-local reconstruction rule, or their
    ``--lam`` mixture for the "ra_*" arms. Every ``--log_every`` steps a
    "step" record (CE loss, B losses, alignment) is appended to the log, and
    after every epoch an "eval" record with the test accuracy.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--arm", required=True, choices=["bp", "fa", "ra_rev", "ra_recon", "ra_comb"])
    p.add_argument("--lam", type=float, default=0.5, help="mixing coef for ra_comb (0=pure rev, 1=pure recon)")
    p.add_argument("--epochs", type=int, default=10)
    p.add_argument("--batch_size", type=int, default=128)
    p.add_argument("--lr", type=float, default=0.05)
    p.add_argument("--lr_b", type=float, default=None, help="separate LR for B (defaults to --lr)")
    p.add_argument("--seed", type=int, default=42)
    p.add_argument("--out", type=str, required=True)
    p.add_argument("--data_dir", type=str, default="data/mnist")
    p.add_argument("--no_ln", action="store_true", help="disable LayerNorm in forward")
    p.add_argument("--log_every", type=int, default=100)
    p.add_argument("--depth", type=int, default=4, help="number of weight matrices (>=2)")
    p.add_argument("--hidden", type=int, default=256)
    p.add_argument("--activation", choices=["linear", "relu", "tanh"], default="linear")
    args = p.parse_args()

    device = "cuda" if torch.cuda.is_available() else "cpu"
    torch.manual_seed(args.seed)

    out_dir = Path(args.out)
    out_dir.mkdir(parents=True, exist_ok=True)
    (out_dir / "config.json").write_text(json.dumps(vars(args), indent=2))
    log_path = out_dir / "log.jsonl"
    log_path.write_text("")  # truncate

    # Data
    tfm = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
        transforms.Lambda(_flatten_image),
    ])
    train_ds = datasets.MNIST(args.data_dir, train=True, download=True, transform=tfm)
    test_ds = datasets.MNIST(args.data_dir, train=False, download=True, transform=tfm)
    # Pinned host memory only helps (and is only available) with CUDA.
    pin = device == "cuda"
    train_loader = DataLoader(train_ds, batch_size=args.batch_size, shuffle=True, num_workers=2, pin_memory=pin)
    test_loader = DataLoader(test_ds, batch_size=512, shuffle=False, num_workers=2, pin_memory=pin)

    dims = make_dims(784, args.hidden, args.depth, 10)
    W, B = init_mats(dims, device, args.seed, activation=args.activation)
    lr_b = args.lr_b if args.lr_b is not None else args.lr
    ln = not args.no_ln

    t0 = time.time()

    def log(rec):
        """Append one JSON record, stamped with elapsed seconds, to the log."""
        rec["t"] = time.time() - t0
        with open(log_path, "a") as f:
            f.write(json.dumps(rec) + "\n")

    log({"event": "start", "arm": args.arm, "dims": dims, "activation": args.activation,
         "lr": args.lr, "lr_b": lr_b, "ln": ln, "lam": args.lam})
    print(f"[{args.arm}] device={device} dims={dims} act={args.activation} ln={ln} lr={args.lr} lr_b={lr_b} lam={args.lam}")

    def test_acc():
        """Return the classification accuracy of the current W on the test set."""
        correct = 0
        total = 0
        with torch.no_grad():
            for x, y in test_loader:
                x = x.to(device, non_blocking=True)
                y = y.to(device, non_blocking=True)
                acts, _ = forward_w(x, W, activation=args.activation, ln=ln)
                pred = acts[-1].argmax(-1)
                correct += (pred == y).sum().item()
                total += y.shape[0]
        return correct / total

    step = 0
    use_bp_for_W = (args.arm == "bp")
    # Defined before the loop so the final "done" record also works when
    # --epochs is 0 and the loop body never runs.
    acc = None
    align = alignment(W, B)
    for epoch in range(args.epochs):
        for x, y in train_loader:
            x = x.to(device, non_blocking=True)
            y = y.to(device, non_blocking=True)

            acts, pre = forward_w(x, W, activation=args.activation, ln=ln)
            logits = acts[-1]
            N = x.shape[0]
            probs = F.softmax(logits, dim=-1)
            onehot = F.one_hot(y, num_classes=10).float()
            # Gradient of the batch-mean cross-entropy w.r.t. the logits.
            grad_top = (probs - onehot) / N

            # W grads (uses W, B snapshot from before this step's updates)
            gW = w_grads(acts, pre, grad_top, W, B, use_bp_for_W, activation=args.activation)

            # B grads
            # NOTE(review): b_grads_rev is called with its default ln_b=True,
            # so the B pathway keeps LayerNorm even with --no_ln; confirm
            # that this is intended.
            gB = None
            rev_loss = None
            recon_loss = None
            if args.arm == "ra_rev":
                gB, rev_loss = b_grads_rev(acts, x, W, B)
            elif args.arm == "ra_recon":
                gB, recon_loss = b_grads_recon(acts, B)
            elif args.arm == "ra_comb":
                gB_rev, rev_loss = b_grads_rev(acts, x, W, B)
                gB_rec, recon_loss = b_grads_recon(acts, B)
                gB = [(1 - args.lam) * gB_rev[l] + args.lam * gB_rec[l] for l in range(len(B))]

            # apply updates
            for l in range(len(W)):
                W[l] -= args.lr * gW[l]
            if gB is not None:
                for l in range(len(B)):
                    B[l] -= lr_b * gB[l]

            step += 1
            if step % args.log_every == 0:
                # The CE scalar is only needed for logging; computing it here
                # (from the pre-update logits) avoids a host sync every step.
                ce = F.cross_entropy(logits, y).item()
                align = alignment(W, B)
                log({
                    "event": "step", "step": step, "epoch": epoch,
                    "loss_ce": ce,
                    "rev_loss": rev_loss, "recon_loss": recon_loss,
                    "alignment": align,
                })

        acc = test_acc()
        align = alignment(W, B)
        log({"event": "eval", "epoch": epoch, "step": step, "test_acc": acc, "alignment": align})
        print(f"[{args.arm}] epoch {epoch:2d} step {step:5d} test_acc {acc:.4f} align {[f'{a:.3f}' for a in align]}")

    if acc is None:
        # No epoch ran: report the accuracy of the untrained network.
        acc = test_acc()

    log({"event": "done", "step": step, "final_acc": acc, "final_alignment": align})
    print(f"[{args.arm}] done in {time.time() - t0:.1f}s final_acc={acc:.4f}")
+
+
# Run the experiment only when executed as a script, not on import (this also
# keeps DataLoader worker processes from re-running main when they import
# this module).
if __name__ == "__main__":
    main()