"""Local-learning sweep training on Shakespeare char LM with sigmoid transformer. Supported methods (via --method): bp standard backprop (reference baseline) fa Feedback Alignment: per-LocalLinear random fixed B replaces W.T in backward sign_sym Sign-symmetric: per-LocalLinear sign(W) replaces W.T in backward dfa Direct Feedback Alignment: each LocalLinear's .grad is overwritten with (B_dfa @ e_L) outer (cached input). Embeddings/LN retain BP gradients. Reuses data/shakespeare_char/*.bin from Phase 1. """ import argparse import json import math import os import pickle import time from pathlib import Path import numpy as np import torch import torch.nn.functional as F from local_layers import LocalLinear, apply_dfa_update, initialize_dfa_targets from model_local import LocalGPT, LocalGPTConfig def get_batch(split, data_dir, block_size, batch_size, device): fn = "train.bin" if split == "train" else "val.bin" data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") ix = torch.randint(len(data) - block_size - 1, (batch_size,)) x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) return x.to(device, non_blocking=True), y.to(device, non_blocking=True) @torch.no_grad() def estimate_loss(model, data_dir, block_size, batch_size, device, eval_iters): out = {} model.eval() for split in ("train", "val"): losses = torch.zeros(eval_iters) for k in range(eval_iters): X, Y = get_batch(split, data_dir, block_size, batch_size, device) _, loss = model(X, Y) losses[k] = loss.item() out[split] = losses.mean().item() model.train() return out def lr_schedule(it, warmup, decay_iters, max_lr, min_lr): if it < warmup: return max_lr * (it + 1) / (warmup + 1) if it > decay_iters: return min_lr coeff = 0.5 * (1.0 + math.cos(math.pi * (it - warmup) / max(1, decay_iters - warmup))) return min_lr + coeff * (max_lr - min_lr) def compute_analytical_e_L(logits, targets, vocab_size): """Closed-form gradient of mean cross-entropy w.r.t. logits. logits: (B, T, V), targets: (B, T). Returns e_L shape (B, T, V). For CE with reduction='mean' over (B*T): dL/dlogits = (softmax(logits) - onehot(y)) / (B*T) """ probs = F.softmax(logits, dim=-1) onehot = F.one_hot(targets, num_classes=vocab_size).float() N = targets.numel() return (probs - onehot) / N def compute_alignment_diagnostics(model, batch_x, batch_y, method, vocab_size): """Per-LocalLinear gradient cosine to BP + (FA only) ‖B − W‖_F / ‖W‖_F. Cosine signs the "functional alignment": 1 = method grad matches BP direction, 0 = orthogonal, negative = wrong direction. Per-layer dict keyed by module name. Two forward-backward passes: first in the method's own mode (grabs method grads), second with all LocalLinear temporarily switched to 'bp' (grabs BP grads). Restores method before returning. Runs in eval mode to disable dropout so both passes see identical activations (otherwise BP would show cosine < 1 vs itself). """ out = {"grad_cos": {}, "fa_offset": {}} was_training = model.training model.eval() # --- Pass 1: method's backward --- model.zero_grad(set_to_none=True) logits, loss = model(batch_x, batch_y) loss.backward() if method == "dfa": with torch.no_grad(): e_L = compute_analytical_e_L(logits.detach(), batch_y, vocab_size) apply_dfa_update(model, e_L) method_grads = {} for name, m in model.named_modules(): if isinstance(m, LocalLinear) and m.weight.grad is not None: method_grads[name] = m.weight.grad.detach().clone() # FA-specific metric: distance between fixed B and current W if method == "fa": for name, m in model.named_modules(): if isinstance(m, LocalLinear) and m.method == "fa": diff = m.B - m.weight out["fa_offset"][name] = (diff.norm() / (m.weight.norm() + 1e-9)).item() # --- Pass 2: BP backward via temporary method switch --- method_backup = {} for m in model.modules(): if isinstance(m, LocalLinear): method_backup[id(m)] = m.method m.method = "bp" model.zero_grad(set_to_none=True) _, loss_bp = model(batch_x, batch_y) loss_bp.backward() for name, m in model.named_modules(): if isinstance(m, LocalLinear) and name in method_grads: g_bp = m.weight.grad.detach() g_method = method_grads[name] cos = F.cosine_similarity( g_bp.flatten().unsqueeze(0), g_method.flatten().unsqueeze(0), ).item() out["grad_cos"][name] = cos # Restore method for m in model.modules(): if isinstance(m, LocalLinear) and id(m) in method_backup: m.method = method_backup[id(m)] model.zero_grad(set_to_none=True) if was_training: model.train() return out def main(): p = argparse.ArgumentParser() p.add_argument("--method", required=True, choices=["bp", "fa", "dfa", "sign_sym"]) p.add_argument("--run_name", type=str, required=True) p.add_argument("--seed", type=int, default=1337) p.add_argument("--data_dir", type=str, default="data/shakespeare_char") p.add_argument("--out_dir", type=str, default="runs_local") p.add_argument("--block_size", type=int, default=256) p.add_argument("--batch_size", type=int, default=64) p.add_argument("--n_layer", type=int, default=6) p.add_argument("--n_head", type=int, default=6) p.add_argument("--n_embd", type=int, default=384) p.add_argument("--dropout", type=float, default=0.2) p.add_argument("--max_iters", type=int, default=5000) p.add_argument("--warmup_iters", type=int, default=100) p.add_argument("--lr_decay_iters", type=int, default=5000) p.add_argument("--max_lr", type=float, default=1e-3) p.add_argument("--min_lr", type=float, default=1e-4) p.add_argument("--weight_decay", type=float, default=0.1) p.add_argument("--beta1", type=float, default=0.9) p.add_argument("--beta2", type=float, default=0.99) p.add_argument("--grad_clip", type=float, default=1.0) p.add_argument("--eval_interval", type=int, default=250) p.add_argument("--eval_iters", type=int, default=100) p.add_argument("--log_interval", type=int, default=50) p.add_argument("--attn_mode", type=str, default="sigmoid", choices=["softmax", "sigmoid"]) p.add_argument("--sigmoid_bias_mode", type=str, default="neg_log_n") p.add_argument("--ste_sigmoid", action="store_true", help="STE on sigmoid attention (skip A(1-A) derivative)") p.add_argument("--ste_gelu", action="store_true", help="STE on GELU (skip gelu' derivative)") p.add_argument("--ln_mode", type=str, default="bp", choices=["bp", "ste", "center_scale", "projected"], help="LN backward: bp=standard, ste=identity, center_scale=mean-center+1/σ, projected=full surrogate") p.add_argument("--freeze_emb", action="store_true", help="Freeze token + position embeddings") p.add_argument("--fuse_attn_local", action="store_true", help="Fuse softmax+A@V with local backward (no lateral sum)") args = p.parse_args() torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" data_dir = Path(args.data_dir) with open(data_dir / "meta.pkl", "rb") as f: meta = pickle.load(f) vocab_size = meta["vocab_size"] run_dir = Path(args.out_dir) / args.run_name run_dir.mkdir(parents=True, exist_ok=True) log_path = run_dir / "log.jsonl" log_path.write_text("") with open(run_dir / "config.json", "w") as f: json.dump(vars(args) | {"vocab_size": vocab_size}, f, indent=2) cfg = LocalGPTConfig( block_size=args.block_size, vocab_size=vocab_size, n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, attn_mode=args.attn_mode, sigmoid_bias_mode=args.sigmoid_bias_mode, method=args.method, ste_sigmoid=args.ste_sigmoid, ste_gelu=args.ste_gelu, ln_mode=args.ln_mode, freeze_emb=args.freeze_emb, fuse_attn_local=args.fuse_attn_local, ) model = LocalGPT(cfg).to(device) n_params = model.num_params() if args.method == "dfa": initialize_dfa_targets(model, vocab_size) # Build optimizer. For all methods, gather params with weight decay convention. decay_params, nodecay_params = [], [] for n, pr in model.named_parameters(): if not pr.requires_grad: continue if pr.dim() >= 2: decay_params.append(pr) else: nodecay_params.append(pr) optimizer = torch.optim.AdamW( [ {"params": decay_params, "weight_decay": args.weight_decay}, {"params": nodecay_params, "weight_decay": 0.0}, ], lr=args.max_lr, betas=(args.beta1, args.beta2), fused=(device == "cuda"), ) t0 = time.time() def log(rec): rec["t"] = time.time() - t0 with open(log_path, "a") as f: f.write(json.dumps(rec) + "\n") n_localinear = sum(1 for m in model.modules() if isinstance(m, LocalLinear)) log({ "event": "start", "method": args.method, "params": n_params, "n_localinear": n_localinear, "vocab_size": vocab_size, "config": vars(args), }) print(f"[{args.run_name}] method={args.method} params={n_params/1e6:.2f}M LocalLinear={n_localinear}") model.train() for it in range(args.max_iters + 1): lr = lr_schedule(it, args.warmup_iters, args.lr_decay_iters, args.max_lr, args.min_lr) for g in optimizer.param_groups: g["lr"] = lr if it % args.eval_interval == 0 or it == args.max_iters: losses = estimate_loss(model, data_dir, args.block_size, args.batch_size, device, args.eval_iters) # Alignment diagnostic on a fresh training batch X_diag, Y_diag = get_batch("train", data_dir, args.block_size, args.batch_size, device) align = compute_alignment_diagnostics(model, X_diag, Y_diag, args.method, vocab_size) log({ "event": "eval", "iter": it, "train_loss": losses["train"], "val_loss": losses["val"], "lr": lr, "grad_cos": align["grad_cos"], "fa_offset": align["fa_offset"], }) # Summary for print if align["grad_cos"]: cos_vals = list(align["grad_cos"].values()) cos_mean = sum(cos_vals) / len(cos_vals) cos_min = min(cos_vals) print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} " f"grad_cos μ={cos_mean:.3f} min={cos_min:.3f} lr {lr:.4g}") else: print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} lr {lr:.4g}") if it == args.max_iters: break X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) logits, loss = model(X, Y) optimizer.zero_grad(set_to_none=True) loss.backward() if args.method == "dfa": # Overwrite LocalLinear .grad with DFA-computed updates (using cached inputs from forward) with torch.no_grad(): e_L = compute_analytical_e_L(logits.detach(), Y, vocab_size) apply_dfa_update(model, e_L) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() if it % args.log_interval == 0: log({"event": "step", "iter": it, "train_loss": loss.item(), "lr": lr}) log({"event": "done", "iter": args.max_iters}) print(f"[{args.run_name}] done in {time.time()-t0:.1f}s") if __name__ == "__main__": main()