summaryrefslogtreecommitdiff
path: root/ep_run/train_recon.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
committerYuren Hao <yurenh2@illinois.edu>2026-07-03 05:56:50 -0500
commitb83947778e2c776f757a07d4719b7ce961d7ed55 (patch)
treeb9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/train_recon.py
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}), analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints git-ignored (share separately). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com> Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/train_recon.py')
-rw-r--r--ep_run/train_recon.py322
1 files changed, 322 insertions, 0 deletions
diff --git a/ep_run/train_recon.py b/ep_run/train_recon.py
new file mode 100644
index 0000000..d180cb2
--- /dev/null
+++ b/ep_run/train_recon.py
@@ -0,0 +1,322 @@
+"""Reconstruction-based (DTP-style) training for local transformer.
+
+Each transformer block l has:
+ - Forward function f_l: h_l → h_{l+1} (standard transformer block)
+ - Feedback module g_l: h_{l+1} → ĥ_l (learned reconstruction, linear)
+
+Training loop per step:
+ 1. Forward pass: compute h_0, h_1, ..., h_L
+ 2. Top target: target_L = h_L - η_target * ∂L/∂h_L
+ 3. Propagate targets backward via g_l:
+ target_l = h_l + g_l(target_{l+1}) - g_l(h_{l+1}) (difference target prop)
+ 4. Train feedback g_l: minimize reconstruction loss (DRL-style with noise)
+ 5. Train forward f_l: minimize ||f_l(h_l) - target_{l+1}||² (local loss)
+ Within each block, attention uses fused backward, LN uses center_scale, GELU uses STE.
+
+No random matrices. No weight transport. No inter-block chain rule.
+"""
+import argparse
+import json
+import math
+import pickle
+
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model_local import LocalGPT, LocalGPTConfig, SoftmaxValueMixLocalFn
+
+
+def get_batch(split, data_dir, block_size, batch_size, device):
+ fn = "train.bin" if split == "train" else "val.bin"
+ data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r")
+ ix = torch.randint(len(data) - block_size - 1, (batch_size,))
+ x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix])
+ y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix])
+ return x.to(device, non_blocking=True), y.to(device, non_blocking=True)
+
+
+class FeedbackModule(nn.Module):
+ """g_l: h_{l+1} → ĥ_l. Linear reconstruction module."""
+ def __init__(self, d_model):
+ super().__init__()
+ self.linear = nn.Linear(d_model, d_model, bias=False)
+ nn.init.eye_(self.linear.weight) # init as identity (good starting point)
+
+ def forward(self, h):
+ return self.linear(h)
+
+
+class ReconTransformer(nn.Module):
+ """Transformer with per-block feedback modules for reconstruction-based training."""
+
+ def __init__(self, config: LocalGPTConfig):
+ super().__init__()
+ self.config = config
+ # Forward model (standard transformer)
+ self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
+ self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
+ self.drop = nn.Dropout(config.dropout)
+
+ # Import block class from model_local
+ from model_local import LocalBlock
+ self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
+ self.ln_f = nn.LayerNorm(config.n_embd)
+ self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
+
+ # Feedback modules: one per block
+ self.feedbacks = nn.ModuleList([
+ FeedbackModule(config.n_embd) for _ in range(config.n_layer)
+ ])
+
+ self.apply(self._init_weights)
+ # Match LocalGPT: scale down o_proj and mlp.proj for residual stream stability
+ for pn, p in self.named_parameters():
+ if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"):
+ nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Linear, LocalLinear)):
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
+ if getattr(m, "bias", None) is not None:
+ nn.init.zeros_(m.bias)
+ elif isinstance(m, nn.Embedding):
+ nn.init.normal_(m.weight, mean=0.0, std=0.02)
+
+ def forward_activations(self, idx):
+ """Forward pass, returning per-block activations h_0 ... h_L."""
+ B, T = idx.shape
+ pos = torch.arange(T, device=idx.device)
+ h = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
+ activations = [h]
+ for block in self.blocks:
+ h = block(h)
+ activations.append(h)
+ return activations # len = n_layer + 1
+
+ def logits_from_h(self, h_final):
+ """h_L → logits."""
+ return self.head(self.ln_f(h_final))
+
+ def compute_targets(self, activations, logits, targets_y, eta_target=0.1):
+ """Compute per-block targets via difference target propagation.
+
+ target_L = h_L - η * ∂L/∂h_L
+ target_l = h_l + g_l(target_{l+1}) - g_l(h_{l+1})
+ """
+ h_L = activations[-1]
+ # Compute ∂L/∂h_L (only need grad at the top, not full BP)
+ h_L_for_grad = h_L.detach().requires_grad_(True)
+ logits_local = self.head(self.ln_f(h_L_for_grad))
+ loss = F.cross_entropy(logits_local.view(-1, logits_local.size(-1)), targets_y.view(-1))
+ loss.backward()
+ grad_h_L = h_L_for_grad.grad.detach()
+
+ # Top target
+ target = h_L.detach() - eta_target * grad_h_L
+ targets_list = [None] * (self.config.n_layer + 1)
+ targets_list[-1] = target
+
+ # Propagate backward via feedback modules
+ for l in range(self.config.n_layer - 1, -1, -1):
+ h_l = activations[l].detach()
+ h_lp1 = activations[l + 1].detach()
+ target_lp1 = targets_list[l + 1]
+ # Difference target propagation
+ targets_list[l] = h_l + self.feedbacks[l](target_lp1) - self.feedbacks[l](h_lp1)
+
+ return targets_list
+
+ def reconstruction_loss(self, activations, sigma=0.1):
+ """Train feedback modules via reconstruction loss (DRL-style with noise).
+
+ For each block l: corrupt h_l, forward through block, reconstruct via g_l.
+ """
+ total_loss = 0.0
+ for l in range(self.config.n_layer):
+ h_l = activations[l].detach()
+ h_lp1 = activations[l + 1].detach()
+ # Add noise to h_l
+ noise = torch.randn_like(h_l) * sigma
+ h_l_noisy = h_l + noise
+ # Forward through block (detached, just computing)
+ with torch.no_grad():
+ h_lp1_noisy = self.blocks[l](h_l_noisy)
+ # Reconstruct via feedback
+ h_l_recon = self.feedbacks[l](h_lp1_noisy)
+ # Difference correction: reconstruct the NOISE, not absolute position
+ recon_target = h_l_noisy
+ total_loss = total_loss + F.mse_loss(h_l_recon, recon_target)
+ return total_loss / self.config.n_layer
+
+ def local_forward_loss(self, activations, targets_list):
+ """Per-block local loss: ||f_l(h_l) - target_{l+1}||².
+
+ Gradients flow within each block (using fused attention backward etc.)
+ but NOT across blocks (targets are detached).
+ """
+ total_loss = 0.0
+ for l in range(self.config.n_layer):
+ h_l = activations[l].detach() # detach: no inter-block gradient
+ target_lp1 = targets_list[l + 1].detach()
+ # Forward through block (WITH gradient for intra-block params)
+ h_lp1_pred = self.blocks[l](h_l)
+ # Local loss
+ total_loss = total_loss + F.mse_loss(h_lp1_pred, target_lp1)
+ return total_loss / self.config.n_layer
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument("--run_name", type=str, required=True)
+ p.add_argument("--seed", type=int, default=1337)
+ p.add_argument("--data_dir", type=str, default="data/shakespeare_char")
+ p.add_argument("--out_dir", type=str, default="runs_local")
+ p.add_argument("--block_size", type=int, default=256)
+ p.add_argument("--batch_size", type=int, default=64)
+ p.add_argument("--n_layer", type=int, default=6)
+ p.add_argument("--n_head", type=int, default=6)
+ p.add_argument("--n_embd", type=int, default=384)
+ p.add_argument("--dropout", type=float, default=0.2)
+ p.add_argument("--max_iters", type=int, default=5000)
+ p.add_argument("--warmup_iters", type=int, default=100)
+ p.add_argument("--max_lr", type=float, default=1e-3)
+ p.add_argument("--min_lr", type=float, default=1e-4)
+ p.add_argument("--eta_target", type=float, default=0.1, help="target stepsize for top-layer target")
+ p.add_argument("--sigma_recon", type=float, default=0.1, help="noise std for reconstruction loss")
+ p.add_argument("--lr_feedback", type=float, default=1e-3, help="LR for feedback modules")
+ p.add_argument("--eval_interval", type=int, default=250)
+ p.add_argument("--eval_iters", type=int, default=100)
+ p.add_argument("--log_interval", type=int, default=50)
+ p.add_argument("--attn_mode", type=str, default="softmax")
+ args = p.parse_args()
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ data_dir = Path(args.data_dir)
+ with open(data_dir / "meta.pkl", "rb") as f:
+ meta = pickle.load(f)
+ vocab_size = meta["vocab_size"]
+
+ run_dir = Path(args.out_dir) / args.run_name
+ run_dir.mkdir(parents=True, exist_ok=True)
+ log_path = run_dir / "log.jsonl"
+ log_path.write_text("")
+
+ cfg = LocalGPTConfig(
+ block_size=args.block_size, vocab_size=vocab_size,
+ n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
+ dropout=args.dropout, attn_mode=args.attn_mode,
+ method="bp", # intra-block uses standard autograd (with fused attention)
+ fuse_attn_local=True,
+ ste_gelu=True,
+ ln_mode="center_scale",
+ )
+ model = ReconTransformer(cfg).to(device)
+ n_params = sum(p.numel() for p in model.parameters())
+
+ # Separate optimizers for forward and feedback
+ forward_params = list(model.tok_emb.parameters()) + list(model.pos_emb.parameters()) + \
+ list(model.head.parameters()) + list(model.ln_f.parameters())
+ for block in model.blocks:
+ forward_params.extend(block.parameters())
+
+ feedback_params = list(model.feedbacks.parameters())
+
+ opt_fwd = torch.optim.AdamW(forward_params, lr=args.max_lr, weight_decay=0.1)
+ opt_fb = torch.optim.AdamW(feedback_params, lr=args.lr_feedback, weight_decay=0.01)
+
+ t0 = time.time()
+
+ def log(rec):
+ rec["t"] = time.time() - t0
+ with open(log_path, "a") as f:
+ f.write(json.dumps(rec) + "\n")
+
+ log({"event": "start", "method": "reconstruction", "params": n_params, "config": vars(args)})
+ print(f"[{args.run_name}] recon transformer, params={n_params/1e6:.2f}M")
+
+ def lr_schedule(it):
+ if it < args.warmup_iters:
+ return args.max_lr * (it + 1) / (args.warmup_iters + 1)
+ decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) /
+ max(1, args.max_iters - args.warmup_iters)))
+ return args.min_lr + decay * (args.max_lr - args.min_lr)
+
+ @torch.no_grad()
+ def eval_loss():
+ model.eval()
+ losses = torch.zeros(args.eval_iters)
+ for k in range(args.eval_iters):
+ X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device)
+ acts = model.forward_activations(X)
+ logits = model.logits_from_h(acts[-1])
+ loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+ losses[k] = loss.item()
+ model.train()
+ return losses.mean().item()
+
+ model.train()
+ for it in range(args.max_iters + 1):
+ lr = lr_schedule(it)
+ for g in opt_fwd.param_groups:
+ g["lr"] = lr
+
+ if it % args.eval_interval == 0 or it == args.max_iters:
+ val = eval_loss()
+ log({"event": "eval", "iter": it, "val_loss": val, "lr": lr})
+ print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}")
+
+ if it == args.max_iters:
+ break
+
+ X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)
+
+ # Step 1: Forward pass (compute activations)
+ activations = model.forward_activations(X)
+ logits = model.logits_from_h(activations[-1])
+ ce_loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+
+ # Step 2-3: Compute targets via DTP
+ targets = model.compute_targets(activations, logits, Y, eta_target=args.eta_target)
+
+ # Step 4: Train feedback modules (reconstruction loss)
+ opt_fb.zero_grad()
+ recon_loss = model.reconstruction_loss(activations, sigma=args.sigma_recon)
+ recon_loss.backward()
+ opt_fb.step()
+
+ # Step 5: Train forward weights (no inter-block BP)
+ opt_fwd.zero_grad()
+
+ # 5a: Head + ln_f via CE loss on DETACHED h_L (gradient stays at top, no BP into blocks)
+ h_L_det = activations[-1].detach()
+ logits_head = model.logits_from_h(h_L_det)
+ head_loss = F.cross_entropy(logits_head.view(-1, vocab_size), Y.view(-1))
+ head_loss.backward()
+
+ # 5b: Block-local target-matching losses
+ # Block 0: DON'T detach h_0 so embedding gets gradient from block 0's local loss
+ for l in range(cfg.n_layer):
+ h_l = activations[l] if l == 0 else activations[l].detach()
+ target_lp1 = targets[l + 1].detach()
+ h_lp1_pred = model.blocks[l](h_l)
+ block_loss = F.mse_loss(h_lp1_pred, target_lp1)
+ block_loss.backward()
+
+ torch.nn.utils.clip_grad_norm_(forward_params, 1.0)
+ opt_fwd.step()
+
+ if it % args.log_interval == 0:
+ log({"event": "step", "iter": it, "ce_loss": ce_loss.item(),
+ "recon_loss": recon_loss.item(), "head_loss": head_loss.item(), "lr": lr})
+
+
+if __name__ == "__main__":
+ main()