diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-07-03 05:56:50 -0500 |
| commit | b83947778e2c776f757a07d4719b7ce961d7ed55 (patch) | |
| tree | b9cc01d7adda691d9156d9d04f4fb2f644674e96 /ep_run/train_stiefel.py | |
Initial commit: ept — backprop-free equilibrium transformer (EP)
Code (ep_run/), organized docs (docs/{method,campaign,hardware,outreach,paper}),
analysis scripts (scripts/), ONBOARDING.md entry point. Large data/checkpoints
git-ignored (share separately).
Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
Claude-Session: https://claude.ai/code/session_014FAPDWQ49M5Ye3NpTndTpn
Diffstat (limited to 'ep_run/train_stiefel.py')
| -rw-r--r-- | ep_run/train_stiefel.py | 211 |
1 files changed, 211 insertions, 0 deletions
diff --git a/ep_run/train_stiefel.py b/ep_run/train_stiefel.py new file mode 100644 index 0000000..0b218ff --- /dev/null +++ b/ep_run/train_stiefel.py @@ -0,0 +1,211 @@ +"""Stiefel factored feedback training for local transformer. + +Replaces FA's random B with: δ_l = α_l · (e_L @ C^T) @ U_l^T +where C is fixed row-orthonormal, U_l is per-layer learnable on Stiefel. + +Each block uses fused attention backward + GELU STE + center_scale LN. +Head trained via detached CE loss. Embedding frozen. +g_l reconstruction modules provide local proxy signal for U_l updates. +""" +import argparse +import json +import math +import pickle +import time +from pathlib import Path + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from model_local import LocalGPTConfig +from train_recon import ReconTransformer, get_batch, FeedbackModule +from stiefel_feedback import StiefelFeedbackSystem + + +def main(): + p = argparse.ArgumentParser() + p.add_argument("--run_name", type=str, required=True) + p.add_argument("--seed", type=int, default=1337) + p.add_argument("--data_dir", type=str, default="data/shakespeare_char") + p.add_argument("--out_dir", type=str, default="runs_local") + p.add_argument("--block_size", type=int, default=256) + p.add_argument("--batch_size", type=int, default=64) + p.add_argument("--n_layer", type=int, default=6) + p.add_argument("--n_head", type=int, default=6) + p.add_argument("--n_embd", type=int, default=384) + p.add_argument("--dropout", type=float, default=0.2) + p.add_argument("--max_iters", type=int, default=5000) + p.add_argument("--warmup_iters", type=int, default=100) + p.add_argument("--max_lr", type=float, default=1e-3) + p.add_argument("--min_lr", type=float, default=1e-4) + p.add_argument("--rank", type=int, default=128) + p.add_argument("--eta_B", type=float, default=3e-5) + p.add_argument("--freeze_fb_steps", type=int, default=200) + p.add_argument("--sigma_recon", type=float, default=0.1) + p.add_argument("--eta_target", type=float, default=0.1) + p.add_argument("--eval_interval", type=int, default=250) + p.add_argument("--eval_iters", type=int, default=100) + p.add_argument("--log_interval", type=int, default=50) + p.add_argument("--attn_mode", type=str, default="softmax") + args = p.parse_args() + + torch.manual_seed(args.seed) + torch.cuda.manual_seed_all(args.seed) + device = "cuda" if torch.cuda.is_available() else "cpu" + + data_dir = Path(args.data_dir) + with open(data_dir / "meta.pkl", "rb") as f: + meta = pickle.load(f) + vocab_size = meta["vocab_size"] + + run_dir = Path(args.out_dir) / args.run_name + run_dir.mkdir(parents=True, exist_ok=True) + log_path = run_dir / "log.jsonl" + log_path.write_text("") + + cfg = LocalGPTConfig( + block_size=args.block_size, vocab_size=vocab_size, + n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, + dropout=args.dropout, attn_mode=args.attn_mode, + method="bp", # intra-block standard autograd with fused attention + fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale", + ) + model = ReconTransformer(cfg).to(device) + + # Stiefel feedback system + layer_dims = [args.n_embd] * args.n_layer # each block output is d_model + fb_system = StiefelFeedbackSystem(vocab_size, layer_dims, rank=args.rank).to(device) + + n_params = sum(p.numel() for p in model.parameters()) + n_fb = sum(p.numel() for p in fb_system.parameters()) + + # Optimizers: forward (blocks + head), feedback g_l (reconstruction) + forward_params = list(model.head.parameters()) + list(model.ln_f.parameters()) + for block in model.blocks: + forward_params.extend(block.parameters()) + feedback_g_params = list(model.feedbacks.parameters()) + + opt_fwd = torch.optim.AdamW(forward_params, lr=args.max_lr, weight_decay=0.1) + opt_fb_g = torch.optim.AdamW(feedback_g_params, lr=args.max_lr, weight_decay=0.01) + # U_l and α_l are updated manually via Stiefel retraction, not via optimizer + + t0 = time.time() + + def log(rec): + rec["t"] = time.time() - t0 + with open(log_path, "a") as f: + f.write(json.dumps(rec) + "\n") + + log({"event": "start", "method": "stiefel_factored", "params": n_params, + "fb_params": n_fb, "rank": args.rank, "config": vars(args)}) + print(f"[{args.run_name}] stiefel factored, params={n_params/1e6:.2f}M, fb={n_fb/1e3:.1f}K, rank={args.rank}") + + def lr_schedule(it): + if it < args.warmup_iters: + return args.max_lr * (it + 1) / (args.warmup_iters + 1) + decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) / + max(1, args.max_iters - args.warmup_iters))) + return args.min_lr + decay * (args.max_lr - args.min_lr) + + @torch.no_grad() + def eval_loss(): + model.eval() + losses = torch.zeros(args.eval_iters) + for k in range(args.eval_iters): + X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device) + acts = model.forward_activations(X) + logits = model.logits_from_h(acts[-1]) + loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) + losses[k] = loss.item() + model.train() + return losses.mean().item() + + model.train() + for it in range(args.max_iters + 1): + lr = lr_schedule(it) + for g in opt_fwd.param_groups: + g["lr"] = lr + + if it % args.eval_interval == 0 or it == args.max_iters: + val = eval_loss() + log({"event": "eval", "iter": it, "val_loss": val, "lr": lr}) + print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}") + + if it == args.max_iters: + break + + X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) + + # 1. Forward pass + activations = model.forward_activations(X) + logits = model.logits_from_h(activations[-1]) + ce_loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1)) + + # 2. Compute e_L and compress + with torch.no_grad(): + probs = F.softmax(logits.detach(), dim=-1) + onehot = F.one_hot(Y, num_classes=vocab_size).float() + e_L = (probs - onehot) / Y.numel() + c = fb_system.compress_error(e_L) + + # 3. Compute per-layer δ via Stiefel feedback + deltas = fb_system.compute_deltas(c) + + # 4. Train g_l (reconstruction feedback modules) + opt_fb_g.zero_grad() + recon_loss = model.reconstruction_loss(activations, sigma=args.sigma_recon) + recon_loss.backward() + opt_fb_g.step() + + # 5. Get local proxy signals g_hat_l from reconstruction modules + g_hats = [] + for l in range(cfg.n_layer): + with torch.no_grad(): + h_l = activations[l].detach() + h_lp1 = activations[l + 1].detach() + g_hat_l = model.feedbacks[l](h_lp1) - h_l # reconstruction error + g_hats.append(g_hat_l) + + # 6. Update Stiefel feedback (U_l, α_l) + frozen = (it < args.freeze_fb_steps) + fb_diags = fb_system.update_all(g_hats, c, frozen=frozen, eta_B=args.eta_B) + + # 7. Train forward weights via block-local loss using Stiefel δ as targets + opt_fwd.zero_grad() + + # 7a. Head via detached CE + h_L_det = activations[-1].detach() + logits_head = model.logits_from_h(h_L_det) + head_loss = F.cross_entropy(logits_head.view(-1, vocab_size), Y.view(-1)) + head_loss.backward() + + # 7b. Each block: local target = h_l + δ_l (feedback signal as target displacement) + for l in range(cfg.n_layer): + h_l = activations[l] if l == 0 else activations[l].detach() + h_lp1 = activations[l + 1].detach() + # Target for block l's output: current output + δ_l displacement + target_lp1 = h_lp1 - deltas[l].detach() # push toward lower loss + h_lp1_pred = model.blocks[l](h_l) + block_loss = F.mse_loss(h_lp1_pred, target_lp1) + block_loss.backward() + + torch.nn.utils.clip_grad_norm_(forward_params, 1.0) + opt_fwd.step() + + if it % args.log_interval == 0: + fb_info = {} + if not frozen and fb_diags: + fb_info = { + "alpha_mean": sum(d.get("alpha", 0) for d in fb_diags) / len(fb_diags), + "rho_mean": sum(d.get("rho", 0) for d in fb_diags) / len(fb_diags), + "Delta_frob_mean": sum(d.get("Delta_frob", 0) for d in fb_diags) / len(fb_diags), + } + log({"event": "step", "iter": it, "ce_loss": ce_loss.item(), + "recon_loss": recon_loss.item(), "head_loss": head_loss.item(), + "frozen": frozen, **fb_info, "lr": lr}) + + +if __name__ == "__main__": + main() |
