diff options
Diffstat (limited to 'ep_run/train_local.py')
| -rw-r--r-- | ep_run/train_local.py | 300 |
1 files changed, 300 insertions, 0 deletions
diff --git a/ep_run/train_local.py b/ep_run/train_local.py new file mode 100644 index 0000000..2519cf9 --- /dev/null +++ b/ep_run/train_local.py @@ -0,0 +1,300 @@ +"""Local-learning sweep training on Shakespeare char LM with sigmoid transformer. + +Supported methods (via --method): + bp standard backprop (reference baseline) + fa Feedback Alignment: per-LocalLinear random fixed B replaces W.T in backward + sign_sym Sign-symmetric: per-LocalLinear sign(W) replaces W.T in backward + dfa Direct Feedback Alignment: each LocalLinear's .grad is overwritten with + (B_dfa @ e_L) outer (cached input). Embeddings/LN retain BP gradients. + +Reuses data/shakespeare_char/*.bin from Phase 1. +""" +import argparse +import json +import math +import os +import pickle +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn.functional as F + +from local_layers import LocalLinear, apply_dfa_update, initialize_dfa_targets +from model_local import LocalGPT, LocalGPTConfig + + +def get_batch(split, data_dir, block_size, batch_size, device): + fn = "train.bin" if split == "train" else "val.bin" + data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") + ix = torch.randint(len(data) - block_size - 1, (batch_size,)) + x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) + y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) + return x.to(device, non_blocking=True), y.to(device, non_blocking=True) + + +@torch.no_grad() +def estimate_loss(model, data_dir, block_size, batch_size, device, eval_iters): + out = {} + model.eval() + for split in ("train", "val"): + losses = torch.zeros(eval_iters) + for k in range(eval_iters): + X, Y = get_batch(split, data_dir, block_size, batch_size, device) + _, loss = model(X, Y) + losses[k] = loss.item() + out[split] = losses.mean().item() + model.train() + return out + + +def lr_schedule(it, warmup, decay_iters, max_lr, min_lr): + if it < warmup: + return max_lr * (it + 1) / (warmup + 1) + if it > decay_iters: + return min_lr + coeff = 0.5 * (1.0 + math.cos(math.pi * (it - warmup) / max(1, decay_iters - warmup))) + return min_lr + coeff * (max_lr - min_lr) + + +def compute_analytical_e_L(logits, targets, vocab_size): + """Closed-form gradient of mean cross-entropy w.r.t. logits. + + logits: (B, T, V), targets: (B, T). Returns e_L shape (B, T, V). + For CE with reduction='mean' over (B*T): dL/dlogits = (softmax(logits) - onehot(y)) / (B*T) + """ + probs = F.softmax(logits, dim=-1) + onehot = F.one_hot(targets, num_classes=vocab_size).float() + N = targets.numel() + return (probs - onehot) / N + + +def compute_alignment_diagnostics(model, batch_x, batch_y, method, vocab_size): + """Per-LocalLinear gradient cosine to BP + (FA only) ‖B − W‖_F / ‖W‖_F. + + Cosine signs the "functional alignment": 1 = method grad matches BP direction, + 0 = orthogonal, negative = wrong direction. Per-layer dict keyed by module name. + + Two forward-backward passes: first in the method's own mode (grabs method grads), + second with all LocalLinear temporarily switched to 'bp' (grabs BP grads). + Restores method before returning. Runs in eval mode to disable dropout so both + passes see identical activations (otherwise BP would show cosine < 1 vs itself). + """ + out = {"grad_cos": {}, "fa_offset": {}} + was_training = model.training + model.eval() + + # --- Pass 1: method's backward --- + model.zero_grad(set_to_none=True) + logits, loss = model(batch_x, batch_y) + loss.backward() + if method == "dfa": + with torch.no_grad(): + e_L = compute_analytical_e_L(logits.detach(), batch_y, vocab_size) + apply_dfa_update(model, e_L) + + method_grads = {} + for name, m in model.named_modules(): + if isinstance(m, LocalLinear) and m.weight.grad is not None: + method_grads[name] = m.weight.grad.detach().clone() + + # FA-specific metric: distance between fixed B and current W + if method == "fa": + for name, m in model.named_modules(): + if isinstance(m, LocalLinear) and m.method == "fa": + diff = m.B - m.weight + out["fa_offset"][name] = (diff.norm() / (m.weight.norm() + 1e-9)).item() + + # --- Pass 2: BP backward via temporary method switch --- + method_backup = {} + for m in model.modules(): + if isinstance(m, LocalLinear): + method_backup[id(m)] = m.method + m.method = "bp" + + model.zero_grad(set_to_none=True) + _, loss_bp = model(batch_x, batch_y) + loss_bp.backward() + + for name, m in model.named_modules(): + if isinstance(m, LocalLinear) and name in method_grads: + g_bp = m.weight.grad.detach() + g_method = method_grads[name] + cos = F.cosine_similarity( + g_bp.flatten().unsqueeze(0), + g_method.flatten().unsqueeze(0), + ).item() + out["grad_cos"][name] = cos + + # Restore method + for m in model.modules(): + if isinstance(m, LocalLinear) and id(m) in method_backup: + m.method = method_backup[id(m)] + + model.zero_grad(set_to_none=True) + if was_training: + model.train() + return out + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--method", required=True, choices=["bp", "fa", "dfa", "sign_sym"]) + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs_local") + p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--lr_decay_iters", type=int, default=5000) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--weight_decay", type=float, default=0.1) + p.add_argument("--beta1", type=float, default=0.9) + p.add_argument("--beta2", type=float, default=0.99) + p.add_argument("--grad_clip", type=float, default=1.0) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="sigmoid", choices=["softmax", "sigmoid"]) + p.add_argument("--sigmoid_bias_mode", type=str, default="neg_log_n") + p.add_argument("--ste_sigmoid", action="store_true", help="STE on sigmoid attention (skip A(1-A) derivative)") + p.add_argument("--ste_gelu", action="store_true", help="STE on GELU (skip gelu' derivative)") + p.add_argument("--ln_mode", type=str, default="bp", choices=["bp", "ste", "center_scale", "projected"], + help="LN backward: bp=standard, ste=identity, center_scale=mean-center+1/σ, projected=full surrogate") + p.add_argument("--freeze_emb", action="store_true", help="Freeze token + position embeddings") + p.add_argument("--fuse_attn_local", action="store_true", help="Fuse softmax+A@V with local backward (no lateral sum)") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + with open(run_dir / "config.json", "w") as f: + json.dump(vars(args) | {"vocab_size": vocab_size}, f, indent=2) + + cfg = LocalGPTConfig( + block_size=args.block_size, + vocab_size=vocab_size, + n_layer=args.n_layer, + n_head=args.n_head, + n_embd=args.n_embd, + dropout=args.dropout, + attn_mode=args.attn_mode, + sigmoid_bias_mode=args.sigmoid_bias_mode, + method=args.method, + ste_sigmoid=args.ste_sigmoid, + ste_gelu=args.ste_gelu, + ln_mode=args.ln_mode, + freeze_emb=args.freeze_emb, + fuse_attn_local=args.fuse_attn_local, + ) + model = LocalGPT(cfg).to(device) + n_params = model.num_params() + + if args.method == "dfa": + initialize_dfa_targets(model, vocab_size) + + # Build optimizer. For all methods, gather params with weight decay convention. + decay_params, nodecay_params = [], [] + for n, pr in model.named_parameters(): + if not pr.requires_grad: + continue + if pr.dim() >= 2: + decay_params.append(pr) + else: + nodecay_params.append(pr) + optimizer = torch.optim.AdamW( + [ + {"params": decay_params, "weight_decay": args.weight_decay}, + {"params": nodecay_params, "weight_decay": 0.0}, + ], + lr=args.max_lr, + betas=(args.beta1, args.beta2), + fused=(device == "cuda"), + ) + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + n_localinear = sum(1 for m in model.modules() if isinstance(m, LocalLinear)) + log({ + "event": "start", "method": args.method, "params": n_params, + "n_localinear": n_localinear, "vocab_size": vocab_size, + "config": vars(args), + }) + print(f"[{args.run_name}] method={args.method} params={n_params/1e6:.2f}M LocalLinear={n_localinear}") + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it, args.warmup_iters, args.lr_decay_iters, args.max_lr, args.min_lr) + for g in optimizer.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + losses = estimate_loss(model, data_dir, args.block_size, args.batch_size, device, args.eval_iters) + # Alignment diagnostic on a fresh training batch + X_diag, Y_diag = get_batch("train", data_dir, args.block_size, args.batch_size, device) + align = compute_alignment_diagnostics(model, X_diag, Y_diag, args.method, vocab_size) + log({ + "event": "eval", "iter": it, + "train_loss": losses["train"], "val_loss": losses["val"], "lr": lr, + "grad_cos": align["grad_cos"], "fa_offset": align["fa_offset"], + }) + # Summary for print + if align["grad_cos"]: + cos_vals = list(align["grad_cos"].values()) + cos_mean = sum(cos_vals) / len(cos_vals) + cos_min = min(cos_vals) + print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} " + f"grad_cos μ={cos_mean:.3f} min={cos_min:.3f} lr {lr:.4g}") + else: + print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + logits, loss = model(X, Y) + + optimizer.zero_grad(set_to_none=True) + loss.backward() + + if args.method == "dfa": + # Overwrite LocalLinear .grad with DFA-computed updates (using cached inputs from forward) + with torch.no_grad(): + e_L = compute_analytical_e_L(logits.detach(), Y, vocab_size) + apply_dfa_update(model, e_L) + + torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) + optimizer.step() + + if it % args.log_interval == 0: + log({"event": "step", "iter": it, "train_loss": loss.item(), "lr": lr}) + + log({"event": "done", "iter": args.max_iters}) + print(f"[{args.run_name}] done in {time.time()-t0:.1f}s") + + +if __name__ == "__main__": + main() |
