summaryrefslogtreecommitdiff
path: root/ep_run/train_stiefel.py
diff options
context:
space:
mode:
Diffstat (limited to 'ep_run/train_stiefel.py')
-rw-r--r--ep_run/train_stiefel.py211
1 files changed, 211 insertions, 0 deletions
diff --git a/ep_run/train_stiefel.py b/ep_run/train_stiefel.py
new file mode 100644
index 0000000..0b218ff
--- /dev/null
+++ b/ep_run/train_stiefel.py
@@ -0,0 +1,211 @@
+"""Stiefel factored feedback training for local transformer.
+
+Replaces FA's random B with: δ_l = α_l · (e_L @ C^T) @ U_l^T
+where C is fixed row-orthonormal, U_l is per-layer learnable on Stiefel.
+
+Each block uses fused attention backward + GELU STE + center_scale LN.
+Head trained via detached CE loss. Embedding frozen.
+g_l reconstruction modules provide local proxy signal for U_l updates.
+"""
+import argparse
+import json
+import math
+import pickle
+import time
+from pathlib import Path
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from model_local import LocalGPTConfig
+from train_recon import ReconTransformer, get_batch, FeedbackModule
+from stiefel_feedback import StiefelFeedbackSystem
+
+
+def main():
+ p = argparse.ArgumentParser()
+ p.add_argument("--run_name", type=str, required=True)
+ p.add_argument("--seed", type=int, default=1337)
+ p.add_argument("--data_dir", type=str, default="data/shakespeare_char")
+ p.add_argument("--out_dir", type=str, default="runs_local")
+ p.add_argument("--block_size", type=int, default=256)
+ p.add_argument("--batch_size", type=int, default=64)
+ p.add_argument("--n_layer", type=int, default=6)
+ p.add_argument("--n_head", type=int, default=6)
+ p.add_argument("--n_embd", type=int, default=384)
+ p.add_argument("--dropout", type=float, default=0.2)
+ p.add_argument("--max_iters", type=int, default=5000)
+ p.add_argument("--warmup_iters", type=int, default=100)
+ p.add_argument("--max_lr", type=float, default=1e-3)
+ p.add_argument("--min_lr", type=float, default=1e-4)
+ p.add_argument("--rank", type=int, default=128)
+ p.add_argument("--eta_B", type=float, default=3e-5)
+ p.add_argument("--freeze_fb_steps", type=int, default=200)
+ p.add_argument("--sigma_recon", type=float, default=0.1)
+ p.add_argument("--eta_target", type=float, default=0.1)
+ p.add_argument("--eval_interval", type=int, default=250)
+ p.add_argument("--eval_iters", type=int, default=100)
+ p.add_argument("--log_interval", type=int, default=50)
+ p.add_argument("--attn_mode", type=str, default="softmax")
+ args = p.parse_args()
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed_all(args.seed)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ data_dir = Path(args.data_dir)
+ with open(data_dir / "meta.pkl", "rb") as f:
+ meta = pickle.load(f)
+ vocab_size = meta["vocab_size"]
+
+ run_dir = Path(args.out_dir) / args.run_name
+ run_dir.mkdir(parents=True, exist_ok=True)
+ log_path = run_dir / "log.jsonl"
+ log_path.write_text("")
+
+ cfg = LocalGPTConfig(
+ block_size=args.block_size, vocab_size=vocab_size,
+ n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd,
+ dropout=args.dropout, attn_mode=args.attn_mode,
+ method="bp", # intra-block standard autograd with fused attention
+ fuse_attn_local=True, ste_gelu=True, ln_mode="center_scale",
+ )
+ model = ReconTransformer(cfg).to(device)
+
+ # Stiefel feedback system
+ layer_dims = [args.n_embd] * args.n_layer # each block output is d_model
+ fb_system = StiefelFeedbackSystem(vocab_size, layer_dims, rank=args.rank).to(device)
+
+ n_params = sum(p.numel() for p in model.parameters())
+ n_fb = sum(p.numel() for p in fb_system.parameters())
+
+ # Optimizers: forward (blocks + head), feedback g_l (reconstruction)
+ forward_params = list(model.head.parameters()) + list(model.ln_f.parameters())
+ for block in model.blocks:
+ forward_params.extend(block.parameters())
+ feedback_g_params = list(model.feedbacks.parameters())
+
+ opt_fwd = torch.optim.AdamW(forward_params, lr=args.max_lr, weight_decay=0.1)
+ opt_fb_g = torch.optim.AdamW(feedback_g_params, lr=args.max_lr, weight_decay=0.01)
+ # U_l and α_l are updated manually via Stiefel retraction, not via optimizer
+
+ t0 = time.time()
+
+ def log(rec):
+ rec["t"] = time.time() - t0
+ with open(log_path, "a") as f:
+ f.write(json.dumps(rec) + "\n")
+
+ log({"event": "start", "method": "stiefel_factored", "params": n_params,
+ "fb_params": n_fb, "rank": args.rank, "config": vars(args)})
+ print(f"[{args.run_name}] stiefel factored, params={n_params/1e6:.2f}M, fb={n_fb/1e3:.1f}K, rank={args.rank}")
+
+ def lr_schedule(it):
+ if it < args.warmup_iters:
+ return args.max_lr * (it + 1) / (args.warmup_iters + 1)
+ decay = 0.5 * (1 + math.cos(math.pi * (it - args.warmup_iters) /
+ max(1, args.max_iters - args.warmup_iters)))
+ return args.min_lr + decay * (args.max_lr - args.min_lr)
+
+ @torch.no_grad()
+ def eval_loss():
+ model.eval()
+ losses = torch.zeros(args.eval_iters)
+ for k in range(args.eval_iters):
+ X, Y = get_batch("val", data_dir, args.block_size, args.batch_size, device)
+ acts = model.forward_activations(X)
+ logits = model.logits_from_h(acts[-1])
+ loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+ losses[k] = loss.item()
+ model.train()
+ return losses.mean().item()
+
+ model.train()
+ for it in range(args.max_iters + 1):
+ lr = lr_schedule(it)
+ for g in opt_fwd.param_groups:
+ g["lr"] = lr
+
+ if it % args.eval_interval == 0 or it == args.max_iters:
+ val = eval_loss()
+ log({"event": "eval", "iter": it, "val_loss": val, "lr": lr})
+ print(f"[{args.run_name}] iter {it:5d} val {val:.4f} lr {lr:.4g}")
+
+ if it == args.max_iters:
+ break
+
+ X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device)
+
+ # 1. Forward pass
+ activations = model.forward_activations(X)
+ logits = model.logits_from_h(activations[-1])
+ ce_loss = F.cross_entropy(logits.view(-1, vocab_size), Y.view(-1))
+
+ # 2. Compute e_L and compress
+ with torch.no_grad():
+ probs = F.softmax(logits.detach(), dim=-1)
+ onehot = F.one_hot(Y, num_classes=vocab_size).float()
+ e_L = (probs - onehot) / Y.numel()
+ c = fb_system.compress_error(e_L)
+
+ # 3. Compute per-layer δ via Stiefel feedback
+ deltas = fb_system.compute_deltas(c)
+
+ # 4. Train g_l (reconstruction feedback modules)
+ opt_fb_g.zero_grad()
+ recon_loss = model.reconstruction_loss(activations, sigma=args.sigma_recon)
+ recon_loss.backward()
+ opt_fb_g.step()
+
+ # 5. Get local proxy signals g_hat_l from reconstruction modules
+ g_hats = []
+ for l in range(cfg.n_layer):
+ with torch.no_grad():
+ h_l = activations[l].detach()
+ h_lp1 = activations[l + 1].detach()
+ g_hat_l = model.feedbacks[l](h_lp1) - h_l # reconstruction error
+ g_hats.append(g_hat_l)
+
+ # 6. Update Stiefel feedback (U_l, α_l)
+ frozen = (it < args.freeze_fb_steps)
+ fb_diags = fb_system.update_all(g_hats, c, frozen=frozen, eta_B=args.eta_B)
+
+ # 7. Train forward weights via block-local loss using Stiefel δ as targets
+ opt_fwd.zero_grad()
+
+ # 7a. Head via detached CE
+ h_L_det = activations[-1].detach()
+ logits_head = model.logits_from_h(h_L_det)
+ head_loss = F.cross_entropy(logits_head.view(-1, vocab_size), Y.view(-1))
+ head_loss.backward()
+
+ # 7b. Each block: local target = h_l + δ_l (feedback signal as target displacement)
+ for l in range(cfg.n_layer):
+ h_l = activations[l] if l == 0 else activations[l].detach()
+ h_lp1 = activations[l + 1].detach()
+ # Target for block l's output: current output + δ_l displacement
+ target_lp1 = h_lp1 - deltas[l].detach() # push toward lower loss
+ h_lp1_pred = model.blocks[l](h_l)
+ block_loss = F.mse_loss(h_lp1_pred, target_lp1)
+ block_loss.backward()
+
+ torch.nn.utils.clip_grad_norm_(forward_params, 1.0)
+ opt_fwd.step()
+
+ if it % args.log_interval == 0:
+ fb_info = {}
+ if not frozen and fb_diags:
+ fb_info = {
+ "alpha_mean": sum(d.get("alpha", 0) for d in fb_diags) / len(fb_diags),
+ "rho_mean": sum(d.get("rho", 0) for d in fb_diags) / len(fb_diags),
+ "Delta_frob_mean": sum(d.get("Delta_frob", 0) for d in fb_diags) / len(fb_diags),
+ }
+ log({"event": "step", "iter": it, "ce_loss": ce_loss.item(),
+ "recon_loss": recon_loss.item(), "head_loss": head_loss.item(),
+ "frozen": frozen, **fb_info, "lr": lr})
+
+
+if __name__ == "__main__":
+ main()