summaryrefslogtreecommitdiff
path: root/ep_run/train_local.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/train_local.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/train_local.py')
-rw-r--r--ep_run/train_local.py300
1 files changed, 300 insertions, 0 deletions
diff --git a/ep_run/train_local.py b/ep_run/train_local.py
new file mode 100644
index 0000000..2519cf9
--- /dev/null
+++ b/ep_run/train_local.py
@@ -0,0 +1,300 @@
+"""Local-learning sweep training on Shakespeare char LM with sigmoid transformer.
+
+Supported methods (via --method):
+ bp standard backprop (reference baseline)
+ fa Feedback Alignment: per-LocalLinear random fixed B replaces W.T in backward
+ sign_sym Sign-symmetric: per-LocalLinear sign(W) replaces W.T in backward
+ dfa Direct Feedback Alignment: each LocalLinear's .grad is overwritten with
+ (B_dfa @ e_L) outer (cached input). Embeddings/LN retain BP gradients.
+
+Reuses data/shakespeare_char/*.bin from Phase 1.
+"""
+import argparse
+import json
+import math
+import os
+import pickle
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from local_layers import LocalLinear, apply_dfa_update, initialize_dfa_targets
+from model_local import LocalGPT, LocalGPTConfig
+
+
def get_batch(split, data_dir, block_size, batch_size, device):
    """Sample a random batch of (input, next-token target) windows from a token file.

    Args:
        split: ``"train"`` reads ``train.bin``; any other value reads ``val.bin``.
        data_dir: Directory holding the ``.bin`` files (``str`` or ``Path``).
        block_size: Number of tokens in each sampled window.
        batch_size: Number of windows to sample.
        device: Torch device the returned tensors are moved to.

    Returns:
        ``(x, y)``: two int64 tensors of shape ``(batch_size, block_size)``.
        ``y`` holds the tokens one position after the corresponding ``x`` tokens.

    Raises:
        ValueError: If the file is too short to hold a single window.
    """
    fn = "train.bin" if split == "train" else "val.bin"
    # Path() makes a plain string directory work as well as a Path object.
    data = np.memmap(Path(data_dir) / fn, dtype=np.uint16, mode="r")

    # Number of admissible start offsets; kept identical to the historical
    # sampling range so seeded runs draw the same batches.
    n_starts = len(data) - block_size - 1
    if n_starts <= 0:
        raise ValueError(
            f"{fn} has {len(data)} tokens, too few for block_size={block_size}"
        )

    starts = torch.randint(n_starts, (batch_size,)).tolist()
    x = torch.stack(
        [torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in starts]
    )
    y = torch.stack(
        [torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in starts]
    )
    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+
+
@torch.no_grad()
def estimate_loss(model, data_dir, block_size, batch_size, device, eval_iters):
    """Estimate mean loss on the train and val splits over random batches.

    The model is put in eval mode for the measurement and then returned to
    whatever mode it was in when this function was called, also when a
    forward pass raises. Gradients are disabled for the whole call.

    Args:
        model: Callable as ``model(X, Y)`` returning ``(logits, loss)``.
        data_dir, block_size, batch_size, device: Forwarded to ``get_batch``.
        eval_iters: Number of random batches averaged per split.

    Returns:
        ``{"train": float, "val": float}`` mean losses.
    """
    was_training = model.training
    model.eval()
    out = {}
    try:
        for split in ("train", "val"):
            losses = torch.zeros(eval_iters)
            for k in range(eval_iters):
                X, Y = get_batch(split, data_dir, block_size, batch_size, device)
                _, loss = model(X, Y)
                losses[k] = loss.item()
            out[split] = losses.mean().item()
    finally:
        # Restore the caller's mode instead of forcing train mode.
        model.train(was_training)
    return out
+
+
def lr_schedule(it, warmup, decay_iters, max_lr, min_lr):
    """Return the learning rate for step ``it``.

    Three regimes, checked in this order:
      * ``it < warmup``: linear ramp ``max_lr * (it + 1) / (warmup + 1)``.
      * ``it > decay_iters``: constant ``min_lr``.
      * otherwise: cosine interpolation from ``max_lr`` down to ``min_lr``.
    """
    if it >= warmup:
        if it > decay_iters:
            return min_lr
        # Guard against a zero-length decay window.
        span = max(1, decay_iters - warmup)
        angle = math.pi * (it - warmup) / span
        coeff = 0.5 * (1.0 + math.cos(angle))
        return min_lr + coeff * (max_lr - min_lr)

    scaled = max_lr * (it + 1)
    return scaled / (warmup + 1)
+
+
def compute_analytical_e_L(logits, targets, vocab_size):
    """Analytic gradient of mean cross-entropy with respect to the logits.

    Args:
        logits: Tensor of shape (B, T, V).
        targets: Integer class indices of shape (B, T).
        vocab_size: Number of classes V used for the one-hot encoding.

    Returns:
        Tensor of shape (B, T, V) equal to
        ``(softmax(logits) - onehot(targets)) / (B * T)``, i.e. dL/dlogits for
        cross-entropy averaged over all B*T positions.
    """
    predicted = F.softmax(logits, dim=-1)
    target_dist = F.one_hot(targets, num_classes=vocab_size).float()
    n_positions = targets.numel()
    residual = predicted - target_dist
    return residual / n_positions
+
+
def compute_alignment_diagnostics(model, batch_x, batch_y, method, vocab_size):
    """Per-LocalLinear gradient cosine to BP + (FA only) ‖B − W‖_F / ‖W‖_F.

    The cosine measures "functional alignment": 1 = method grad matches the BP
    direction, 0 = orthogonal, negative = wrong direction. Results are per-layer
    dicts keyed by module name.

    Two forward-backward passes are run on the same batch: the first in the
    method's own mode (grabs method grads), the second with every LocalLinear
    temporarily switched to 'bp' (grabs BP grads). Runs in eval mode to disable
    dropout so both passes see identical activations (otherwise BP would show
    cosine < 1 vs itself).

    Each layer's ``method`` and the model's train/eval mode are restored in a
    ``finally`` block, so a failure in either pass cannot leave the model stuck
    in 'bp' or eval mode. All gradients are cleared before returning.

    Args:
        model: Module callable as ``model(x, y)`` returning ``(logits, loss)``.
        batch_x, batch_y: One batch of inputs and targets.
        method: Active training method; ``"dfa"`` and ``"fa"`` trigger extra steps.
        vocab_size: Passed to ``compute_analytical_e_L`` for the DFA error signal.

    Returns:
        ``{"grad_cos": {name: float}, "fa_offset": {name: float}}``;
        ``fa_offset`` is only populated when ``method == "fa"``.
    """
    out = {"grad_cos": {}, "fa_offset": {}}
    was_training = model.training

    # Snapshot the LocalLinear layers and their methods up front so the
    # finally-block can always put them back.
    local_layers = [
        (name, m) for name, m in model.named_modules() if isinstance(m, LocalLinear)
    ]
    saved_methods = [m.method for _, m in local_layers]

    model.eval()
    try:
        # --- Pass 1: method's backward ---
        model.zero_grad(set_to_none=True)
        logits, loss = model(batch_x, batch_y)
        loss.backward()
        if method == "dfa":
            # DFA replaces the LocalLinear grads left by backward().
            with torch.no_grad():
                e_L = compute_analytical_e_L(logits.detach(), batch_y, vocab_size)
                apply_dfa_update(model, e_L)

        method_grads = {
            name: m.weight.grad.detach().clone()
            for name, m in local_layers
            if m.weight.grad is not None
        }

        # FA-specific metric: relative distance between fixed B and current W.
        if method == "fa":
            for name, m in local_layers:
                if m.method == "fa":
                    diff = m.B - m.weight
                    out["fa_offset"][name] = (
                        diff.norm() / (m.weight.norm() + 1e-9)
                    ).item()

        # --- Pass 2: BP backward via temporary method switch ---
        for _, m in local_layers:
            m.method = "bp"

        model.zero_grad(set_to_none=True)
        _, loss_bp = model(batch_x, batch_y)
        loss_bp.backward()

        for name, m in local_layers:
            if name not in method_grads:
                continue
            g_bp = m.weight.grad.detach()
            g_method = method_grads[name]
            out["grad_cos"][name] = F.cosine_similarity(
                g_bp.flatten().unsqueeze(0),
                g_method.flatten().unsqueeze(0),
            ).item()
    finally:
        # Restore per-layer methods, drop diagnostic grads, restore mode.
        for (_, m), saved in zip(local_layers, saved_methods):
            m.method = saved
        model.zero_grad(set_to_none=True)
        if was_training:
            model.train()
    return out
+
+
def _parse_args():
    """Define the command-line options for a local-learning run and parse them."""
    p = argparse.ArgumentParser()
    p.add_argument("--method", required=True, choices=["bp", "fa", "dfa", "sign_sym"])
    p.add_argument("--run_name", type=str, required=True)
    p.add_argument("--seed", type=int, default=1337)
    p.add_argument("--data_dir", type=str, default="data/shakespeare_char")
    p.add_argument("--out_dir", type=str, default="runs_local")
    p.add_argument("--block_size", type=int, default=256)
    p.add_argument("--batch_size", type=int, default=64)
    p.add_argument("--n_layer", type=int, default=6)
    p.add_argument("--n_head", type=int, default=6)
    p.add_argument("--n_embd", type=int, default=384)
    p.add_argument("--dropout", type=float, default=0.2)
    p.add_argument("--max_iters", type=int, default=5000)
    p.add_argument("--warmup_iters", type=int, default=100)
    p.add_argument("--lr_decay_iters", type=int, default=5000)
    p.add_argument("--max_lr", type=float, default=1e-3)
    p.add_argument("--min_lr", type=float, default=1e-4)
    p.add_argument("--weight_decay", type=float, default=0.1)
    p.add_argument("--beta1", type=float, default=0.9)
    p.add_argument("--beta2", type=float, default=0.99)
    p.add_argument("--grad_clip", type=float, default=1.0)
    p.add_argument("--eval_interval", type=int, default=250)
    p.add_argument("--eval_iters", type=int, default=100)
    p.add_argument("--log_interval", type=int, default=50)
    p.add_argument("--attn_mode", type=str, default="sigmoid", choices=["softmax", "sigmoid"])
    p.add_argument("--sigmoid_bias_mode", type=str, default="neg_log_n")
    p.add_argument("--ste_sigmoid", action="store_true", help="STE on sigmoid attention (skip A(1-A) derivative)")
    p.add_argument("--ste_gelu", action="store_true", help="STE on GELU (skip gelu' derivative)")
    p.add_argument("--ln_mode", type=str, default="bp", choices=["bp", "ste", "center_scale", "projected"],
                   help="LN backward: bp=standard, ste=identity, center_scale=mean-center+1/σ, projected=full surrogate")
    p.add_argument("--freeze_emb", action="store_true", help="Freeze token + position embeddings")
    p.add_argument("--fuse_attn_local", action="store_true", help="Fuse softmax+A@V with local backward (no lateral sum)")
    return p.parse_args()


def _build_optimizer(model, args, device):
    """Create AdamW over trainable params; weight decay only on tensors with dim >= 2."""
    trainable = [pr for pr in model.parameters() if pr.requires_grad]
    decay_params = [pr for pr in trainable if pr.dim() >= 2]
    nodecay_params = [pr for pr in trainable if pr.dim() < 2]
    return torch.optim.AdamW(
        [
            {"params": decay_params, "weight_decay": args.weight_decay},
            {"params": nodecay_params, "weight_decay": 0.0},
        ],
        lr=args.max_lr,
        betas=(args.beta1, args.beta2),
        fused=(device == "cuda"),
    )


def main():
    """Train a LocalGPT with the selected local-learning method and log to JSONL.

    Reads ``meta.pkl`` and the token ``.bin`` files from ``--data_dir``, writes
    ``config.json`` and ``log.jsonl`` (truncated at start) under
    ``<out_dir>/<run_name>``, then runs ``max_iters`` optimisation steps. Every
    ``eval_interval`` steps, and at the final step, it records train/val loss
    plus the gradient-alignment diagnostics.

    Gradient clipping is applied only when ``--grad_clip`` is positive; ``0``
    disables clipping instead of scaling every gradient to zero.
    """
    args = _parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    data_dir = Path(args.data_dir)
    # NOTE(security): pickle.load can execute arbitrary code; only point
    # --data_dir at datasets you produced or trust.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)
    vocab_size = meta["vocab_size"]

    run_dir = Path(args.out_dir) / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    log_path = run_dir / "log.jsonl"
    log_path.write_text("")  # start every run with an empty log
    with open(run_dir / "config.json", "w") as f:
        json.dump(vars(args) | {"vocab_size": vocab_size}, f, indent=2)

    cfg = LocalGPTConfig(
        block_size=args.block_size,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
        attn_mode=args.attn_mode,
        sigmoid_bias_mode=args.sigmoid_bias_mode,
        method=args.method,
        ste_sigmoid=args.ste_sigmoid,
        ste_gelu=args.ste_gelu,
        ln_mode=args.ln_mode,
        freeze_emb=args.freeze_emb,
        fuse_attn_local=args.fuse_attn_local,
    )
    model = LocalGPT(cfg).to(device)
    n_params = model.num_params()

    if args.method == "dfa":
        initialize_dfa_targets(model, vocab_size)

    optimizer = _build_optimizer(model, args, device)

    t0 = time.time()

    def log(rec):
        # Append one JSON record, stamped with seconds since training start.
        rec["t"] = time.time() - t0
        with open(log_path, "a") as f:
            f.write(json.dumps(rec) + "\n")

    n_localinear = sum(1 for m in model.modules() if isinstance(m, LocalLinear))
    log({
        "event": "start", "method": args.method, "params": n_params,
        "n_localinear": n_localinear, "vocab_size": vocab_size,
        "config": vars(args),
    })
    print(f"[{args.run_name}] method={args.method} params={n_params/1e6:.2f}M LocalLinear={n_localinear}")

    model.train()
    for it in range(args.max_iters + 1):
        lr = lr_schedule(it, args.warmup_iters, args.lr_decay_iters, args.max_lr, args.min_lr)
        for g in optimizer.param_groups:
            g["lr"] = lr

        if it % args.eval_interval == 0 or it == args.max_iters:
            losses = estimate_loss(model, data_dir, args.block_size, args.batch_size, device, args.eval_iters)
            # Alignment diagnostic on a fresh training batch
            X_diag, Y_diag = get_batch("train", data_dir, args.block_size, args.batch_size, device)
            align = compute_alignment_diagnostics(model, X_diag, Y_diag, args.method, vocab_size)
            log({
                "event": "eval", "iter": it,
                "train_loss": losses["train"], "val_loss": losses["val"], "lr": lr,
                "grad_cos": align["grad_cos"], "fa_offset": align["fa_offset"],
            })
            # Summary for print
            if align["grad_cos"]:
                cos_vals = list(align["grad_cos"].values())
                cos_mean = sum(cos_vals) / len(cos_vals)
                cos_min = min(cos_vals)
                print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} "
                      f"grad_cos μ={cos_mean:.3f} min={cos_min:.3f} lr {lr:.4g}")
            else:
                print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} lr {lr:.4g}")

        # The final iteration only evaluates; no parameter update follows it.
        if it == args.max_iters:
            break

        X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)
        logits, loss = model(X, Y)

        optimizer.zero_grad(set_to_none=True)
        loss.backward()

        if args.method == "dfa":
            # Overwrite LocalLinear .grad with DFA-computed updates (using cached inputs from forward)
            with torch.no_grad():
                e_L = compute_analytical_e_L(logits.detach(), Y, vocab_size)
                apply_dfa_update(model, e_L)

        # A max_norm of 0 would scale every gradient to zero and a negative
        # one would flip their sign, so only clip for positive thresholds.
        if args.grad_clip > 0.0:
            torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
        optimizer.step()

        if it % args.log_interval == 0:
            log({"event": "step", "iter": it, "train_loss": loss.item(), "lr": lr})

    log({"event": "done", "iter": args.max_iters})
    print(f"[{args.run_name}] done in {time.time()-t0:.1f}s")
+
+
# Run the training entry point only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()