"""Reconstruction-based (DTP-style) training for a local transformer.

Each transformer block l has:
  - Forward function f_l: h_l -> h_{l+1} (standard transformer block)
  - Feedback module g_l: h_{l+1} -> ĥ_l (learned reconstruction, linear)

Training loop per step:
  1. Forward pass: compute h_0, h_1, ..., h_L.
  2. Top target: target_L = h_L - η_target * ∂L/∂h_L.
  3. Propagate targets backward via g_l (difference target propagation):
         target_l = h_l + g_l(target_{l+1}) - g_l(h_{l+1})
  4. Train feedback g_l: minimize a reconstruction loss (DRL-style, with noise).
  5. Train forward f_l: minimize ||f_l(h_l) - target_{l+1}||² (local loss).

Within each block, attention uses the fused backward, LN uses center_scale,
and GELU uses a straight-through estimator (STE). No random matrices, no
weight transport, no inter-block chain rule.
"""

import argparse
import json
import math
import pickle
import time
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from model_local import LocalBlock, LocalGPT, LocalGPTConfig, SoftmaxValueMixLocalFn

try:
    from model_local import LocalLinear
except ImportError:
    # _init_weights type-checks against LocalLinear. It was previously used
    # without being imported, which raised NameError on model construction.
    # If model_local does not provide it, fall back to nn.Linear so the
    # isinstance check stays valid.
    LocalLinear = nn.Linear


def get_batch(split, data_dir, block_size, batch_size, device):
    """Sample a random batch of (input, next-token) windows.

    Args:
        split: "train" reads ``train.bin``; any other value reads ``val.bin``.
        data_dir: ``Path`` containing the token files (uint16 token ids).
        block_size: window length T.
        batch_size: number of windows B.
        device: device the returned tensors are moved to.

    Returns:
        (x, y): int64 tensors of shape (B, T); y is x shifted by one token.
    """
    fn = "train.bin" if split == "train" else "val.bin"
    # The memmap is re-opened on every call rather than cached.
    data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r")
    ix = torch.randint(len(data) - block_size - 1, (batch_size,))
    x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix])
    y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix])
    return x.to(device, non_blocking=True), y.to(device, non_blocking=True)


class FeedbackModule(nn.Module):
    """g_l: h_{l+1} -> ĥ_l. Linear reconstruction module (no bias)."""

    def __init__(self, d_model):
        super().__init__()
        self.linear = nn.Linear(d_model, d_model, bias=False)
        nn.init.eye_(self.linear.weight)  # init as identity (good starting point)

    def forward(self, h):
        return self.linear(h)


class ReconTransformer(nn.Module):
    """Transformer with per-block feedback modules for reconstruction-based training."""

    def __init__(self, config: LocalGPTConfig):
        super().__init__()
        self.config = config

        # Forward model (standard transformer).
        self.tok_emb = nn.Embedding(config.vocab_size, config.n_embd)
        self.pos_emb = nn.Embedding(config.block_size, config.n_embd)
        self.drop = nn.Dropout(config.dropout)
        self.blocks = nn.ModuleList([LocalBlock(config) for _ in range(config.n_layer)])
        self.ln_f = nn.LayerNorm(config.n_embd)
        self.head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

        # Feedback modules: one per block.
        self.feedbacks = nn.ModuleList(
            [FeedbackModule(config.n_embd) for _ in range(config.n_layer)]
        )

        self.apply(self._init_weights)
        # Match LocalGPT: scale down o_proj and mlp.proj for residual stream stability.
        for pn, p in self.named_parameters():
            if pn.endswith("o_proj.weight") or pn.endswith("mlp.proj.weight"):
                nn.init.normal_(p, mean=0.0, std=0.02 / math.sqrt(2 * config.n_layer))

        # self.apply above re-initialised every nn.Linear with N(0, 0.02),
        # including the feedback layers. That silently discarded
        # FeedbackModule's identity init. Restore it so every g_l starts as
        # the identity map.
        for fb in self.feedbacks:
            nn.init.eye_(fb.linear.weight)

    def _init_weights(self, m):
        if isinstance(m, (nn.Linear, LocalLinear)):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)
            if getattr(m, "bias", None) is not None:
                nn.init.zeros_(m.bias)
        elif isinstance(m, nn.Embedding):
            nn.init.normal_(m.weight, mean=0.0, std=0.02)

    def forward_activations(self, idx):
        """Forward pass, returning per-block activations [h_0, ..., h_L].

        Args:
            idx: token ids of shape (B, T).

        Returns:
            List of length n_layer + 1. h_0 is the (dropout-applied) token +
            position embedding; h_{l+1} = blocks[l](h_l).
        """
        B, T = idx.shape
        pos = torch.arange(T, device=idx.device)
        h = self.drop(self.tok_emb(idx) + self.pos_emb(pos))
        activations = [h]
        for block in self.blocks:
            h = block(h)
            activations.append(h)
        return activations  # len = n_layer + 1

    def logits_from_h(self, h_final):
        """h_L -> logits (final LayerNorm followed by the LM head)."""
        return self.head(self.ln_f(h_final))

    def compute_targets(self, activations, logits, targets_y, eta_target=0.1):
        """Compute per-block targets via difference target propagation.

            target_L = h_L - eta_target * dL/dh_L
            target_l = h_l + g_l(target_{l+1}) - g_l(h_{l+1})

        Args:
            activations: [h_0, ..., h_L] as returned by ``forward_activations``.
            logits: unused; kept for call-site compatibility.
            targets_y: next-token ids used for the top-level cross-entropy.
            eta_target: step size for the top-layer target.

        Returns:
            List of n_layer + 1 target tensors (no autograd history), aligned
            with ``activations``.
        """
        h_L = activations[-1].detach()

        # dL/dh_L through ln_f + head only (no backprop into the blocks).
        # torch.autograd.grad returns the gradient without accumulating into
        # head/ln_f .grad, which loss.backward() would do as a side effect.
        # enable_grad keeps this working if the caller is under no_grad.
        with torch.enable_grad():
            h_L_for_grad = h_L.detach().requires_grad_(True)
            logits_local = self.logits_from_h(h_L_for_grad)
            loss = F.cross_entropy(
                logits_local.view(-1, logits_local.size(-1)), targets_y.view(-1)
            )
            (grad_h_L,) = torch.autograd.grad(loss, h_L_for_grad)

        # NOTE(review): cross_entropy uses its default reduction="mean" over
        # all B*T positions. Each position's gradient is therefore scaled by
        # 1/(B*T), making the effective per-position target step
        # eta_target / (B*T). Confirm eta_target is tuned with this in mind.
        targets_list = [None] * (self.config.n_layer + 1)
        targets_list[-1] = h_L - eta_target * grad_h_L

        # Targets are treated as constants by every consumer, so no graph
        # through the feedback modules is needed here.
        with torch.no_grad():
            for l in range(self.config.n_layer - 1, -1, -1):
                h_l = activations[l].detach()
                h_lp1 = activations[l + 1].detach()
                g_l = self.feedbacks[l]
                # Difference target propagation.
                targets_list[l] = h_l + g_l(targets_list[l + 1]) - g_l(h_lp1)
        return targets_list

    def reconstruction_loss(self, activations, sigma=0.1):
        """Loss for training the feedback modules (DRL-style, with noise).

        For each block l: corrupt h_l with Gaussian noise of std ``sigma``,
        run block l on it without tracking gradients, and train g_l to map
        the block output back to the corrupted input h_l + noise. Only the
        feedback modules receive gradient from this loss.

        Returns:
            Mean over blocks of the per-block MSE reconstruction loss.
        """
        total_loss = 0.0
        for l in range(self.config.n_layer):
            h_l = activations[l].detach()
            noise = torch.randn_like(h_l) * sigma
            h_l_noisy = h_l + noise

            # Forward blocks are not trained by this loss.
            with torch.no_grad():
                h_lp1_noisy = self.blocks[l](h_l_noisy)

            h_l_recon = self.feedbacks[l](h_lp1_noisy)
            # The regression target is the absolute noisy input h_l + noise,
            # not the noise alone.
            total_loss = total_loss + F.mse_loss(h_l_recon, h_l_noisy)
        return total_loss / self.config.n_layer

    def local_forward_loss(self, activations, targets_list):
        """Per-block local loss: mean over blocks of ||f_l(h_l) - target_{l+1}||².

        Gradients flow within each block (using fused attention backward etc.)
        but NOT across blocks: every input h_l and every target is detached.
        """
        total_loss = 0.0
        for l in range(self.config.n_layer):
            h_l = activations[l].detach()  # detach: no inter-block gradient
            target_lp1 = targets_list[l + 1].detach()
            # Forward through block (WITH gradient for intra-block params).
            h_lp1_pred = self.blocks[l](h_l)
            total_loss = total_loss + F.mse_loss(h_lp1_pred, target_lp1)
        return total_loss / self.config.n_layer


def main():
    """Parse CLI args and run reconstruction-based local training with periodic eval."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--run_name", type=str, required=True)
    parser.add_argument("--seed", type=int, default=1337)
    parser.add_argument("--data_dir", type=str, default="data/shakespeare_char")
    parser.add_argument("--out_dir", type=str, default="runs_local")
    parser.add_argument("--block_size", type=int, default=256)
    parser.add_argument("--batch_size", type=int, default=64)
    parser.add_argument("--n_layer", type=int, default=6)
    parser.add_argument("--n_head", type=int, default=6)
    parser.add_argument("--n_embd", type=int, default=384)
    parser.add_argument("--dropout", type=float, default=0.2)
    parser.add_argument("--max_iters", type=int, default=5000)
    parser.add_argument("--warmup_iters", type=int, default=100)
    parser.add_argument("--max_lr", type=float, default=1e-3)
    parser.add_argument("--min_lr", type=float, default=1e-4)
    parser.add_argument("--eta_target", type=float, default=0.1, help="target stepsize for top-layer target")
    parser.add_argument("--sigma_recon", type=float, default=0.1, help="noise std for reconstruction loss")
    parser.add_argument("--lr_feedback", type=float, default=1e-3, help="LR for feedback modules")
    parser.add_argument("--eval_interval", type=int, default=250)
    parser.add_argument("--eval_iters", type=int, default=100)
    parser.add_argument("--log_interval", type=int, default=50)
    parser.add_argument("--attn_mode", type=str, default="softmax")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    torch.cuda.manual_seed_all(args.seed)
    device = "cuda" if torch.cuda.is_available() else "cpu"

    data_dir = Path(args.data_dir)
    # meta.pkl is a local dataset artifact; pickle must only be loaded from trusted files.
    with open(data_dir / "meta.pkl", "rb") as f:
        meta = pickle.load(f)
    vocab_size = meta["vocab_size"]

    run_dir = Path(args.out_dir) / args.run_name
    run_dir.mkdir(parents=True, exist_ok=True)
    log_path = run_dir / "log.jsonl"
    log_path.write_text("")  # truncate any previous log for this run

    cfg = LocalGPTConfig(
        block_size=args.block_size,
        vocab_size=vocab_size,
        n_layer=args.n_layer,
        n_head=args.n_head,
        n_embd=args.n_embd,
        dropout=args.dropout,
        attn_mode=args.attn_mode,
        method="bp",  # intra-block uses standard autograd (with fused attention)
        fuse_attn_local=True,
        ste_gelu=True,
        ln_mode="center_scale",
    )
    model = ReconTransformer(cfg).to(device)
    # Counts forward AND feedback parameters.
    n_params = sum(p.numel() for p in model.parameters())

    # Separate optimizers for forward and feedback parameters.
    forward_params = (
        list(model.tok_emb.parameters())
        + list(model.pos_emb.parameters())
        + list(model.head.parameters())
        + list(model.ln_f.parameters())
    )
    for block in model.blocks:
        forward_params.extend(block.parameters())
    feedback_params = list(model.feedbacks.parameters())

    opt_fwd = torch.optim.AdamW(forward_params, lr=args.max_lr, weight_decay=0.1)
    # The feedback optimizer keeps a constant LR (the schedule below only touches opt_fwd).
    opt_fb = torch.optim.AdamW(feedback_params, lr=args.lr_feedback, weight_decay=0.01)

    t0 = time.time()

    def log(rec):
        rec["t"] = time.time() - t0
        with open(log_path, "a") as f:
            f.write(json.dumps(rec) + "\n")

    log({"event": "start", "method": "reconstruction", "params": n_params, "config": vars(args)})
    print(f"[{args.run_name}] recon transformer, params={n_params/1e6:.2f}M")

    def lr_schedule(it):
        """Linear warmup followed by cosine decay from max_lr to min_lr."""
        if it < args.warmup_iters:
            return args.max_lr * (it + 1) / (args.warmup_iters + 1)
        decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) / max(1, args.max_iters - args.warmup_iters)))
        return args.min_lr + decay * (args.max_lr - args.min_lr)

    @torch.no_grad()
    def eval_loss():
        """Mean validation cross-entropy over eval_iters random batches."""
        model.eval()
        losses = torch.zeros(args.eval_iters)
        for k in range(args.eval_iters):
            X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device)
            acts = model.forward_activations(X)
            logits = model.logits_from_h(acts[-1])
            loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
            losses[k] = loss.item()
        model.train()
        return losses.mean().item()

    model.train()
    for it in range(args.max_iters + 1):
        lr = lr_schedule(it)
        for g in opt_fwd.param_groups:
            g["lr"] = lr

        if it % args.eval_interval == 0 or it == args.max_iters:
            val = eval_loss()
            log({"event": "eval", "iter": it, "val_loss": val, "lr": lr})
            print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}")
        if it == args.max_iters:
            break

        X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)

        # Step 1: Forward pass (compute activations). ce_loss is only logged.
        activations = model.forward_activations(X)
        logits = model.logits_from_h(activations[-1])
        ce_loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))

        # Steps 2-3: Compute targets via DTP.
        targets = model.compute_targets(activations, logits, Y, eta_target=args.eta_target)

        # Step 4: Train feedback modules (reconstruction loss).
        opt_fb.zero_grad()
        recon_loss = model.reconstruction_loss(activations, sigma=args.sigma_recon)
        recon_loss.backward()
        opt_fb.step()

        # Step 5: Train forward weights (no inter-block BP).
        opt_fwd.zero_grad()

        # 5a: Head + ln_f via CE loss on DETACHED h_L (gradient stays at top, no BP into blocks).
        h_L_det = activations[-1].detach()
        logits_head = model.logits_from_h(h_L_det)
        head_loss = F.cross_entropy(logits_head.view(-1, vocab_size), Y.view(-1))
        head_loss.backward()

        # 5b: Block-local target-matching losses, one backward per block.
        # Unlike ReconTransformer.local_forward_loss, h_0 is NOT detached, so
        # the embeddings receive gradient from block 0's local loss.
        for l in range(cfg.n_layer):
            h_l = activations[l] if l == 0 else activations[l].detach()
            target_lp1 = targets[l + 1].detach()
            h_lp1_pred = model.blocks[l](h_l)
            block_loss = F.mse_loss(h_lp1_pred, target_lp1)
            block_loss.backward()

        torch.nn.utils.clip_grad_norm_(forward_params, 1.0)
        opt_fwd.step()

        if it % args.log_interval == 0:
            log({"event": "step", "iter": it, "ce_loss": ce_loss.item(), "recon_loss": recon_loss.item(), "head_loss": head_loss.item(), "lr": lr})


if __name__ == "__main__":
    main()