"""Train a tiny char-level GPT on tinyshakespeare with softmax or sigmoid attn. Logs train/val loss to runs//log.jsonl (one JSON per line). No checkpoints are saved (disk-conscious). """ import argparse import json import math import os import pickle import time from pathlib import Path import numpy as np import torch from model import GPT, GPTConfig def get_batch(split: str, data_dir: Path, block_size: int, batch_size: int, device: str): fn = "train.bin" if split == "train" else "val.bin" data = np.memmap(data_dir / fn, dtype=np.uint16, mode="r") ix = torch.randint(len(data) - block_size - 1, (batch_size,)) x = torch.stack([torch.from_numpy(data[i : i + block_size].astype(np.int64)) for i in ix]) y = torch.stack([torch.from_numpy(data[i + 1 : i + 1 + block_size].astype(np.int64)) for i in ix]) return x.to(device, non_blocking=True), y.to(device, non_blocking=True) @torch.no_grad() def estimate_loss(model, data_dir, block_size, batch_size, device, eval_iters): out = {} model.eval() for split in ("train", "val"): losses = torch.zeros(eval_iters) for k in range(eval_iters): X, Y = get_batch(split, data_dir, block_size, batch_size, device) _, loss = model(X, Y) losses[k] = loss.item() out[split] = losses.mean().item() model.train() return out def lr_schedule(it, warmup_iters, lr_decay_iters, max_lr, min_lr): if it < warmup_iters: return max_lr * (it + 1) / (warmup_iters + 1) if it > lr_decay_iters: return min_lr decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters) coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) return min_lr + coeff * (max_lr - min_lr) def main(): p = argparse.ArgumentParser() p.add_argument("--run_name", type=str, required=True) p.add_argument("--attn_mode", type=str, default="softmax", choices=["softmax", "sigmoid"]) p.add_argument("--sigmoid_bias_mode", type=str, default="neg_log_n", choices=["zero", "neg_log_n", "learned"]) p.add_argument("--seed", type=int, default=1337) p.add_argument("--data_dir", type=str, default="data/shakespeare_char") p.add_argument("--out_dir", type=str, default="runs") p.add_argument("--block_size", type=int, default=256) p.add_argument("--batch_size", type=int, default=64) p.add_argument("--n_layer", type=int, default=6) p.add_argument("--n_head", type=int, default=6) p.add_argument("--n_embd", type=int, default=384) p.add_argument("--dropout", type=float, default=0.2) p.add_argument("--max_iters", type=int, default=5000) p.add_argument("--warmup_iters", type=int, default=100) p.add_argument("--lr_decay_iters", type=int, default=5000) p.add_argument("--max_lr", type=float, default=1e-3) p.add_argument("--min_lr", type=float, default=1e-4) p.add_argument("--weight_decay", type=float, default=0.1) p.add_argument("--beta1", type=float, default=0.9) p.add_argument("--beta2", type=float, default=0.99) p.add_argument("--grad_clip", type=float, default=1.0) p.add_argument("--eval_interval", type=int, default=250) p.add_argument("--eval_iters", type=int, default=100) p.add_argument("--log_interval", type=int, default=50) p.add_argument("--dtype", type=str, default="bfloat16", choices=["float32", "bfloat16", "float16"]) args = p.parse_args() torch.manual_seed(args.seed) torch.cuda.manual_seed_all(args.seed) device = "cuda" if torch.cuda.is_available() else "cpu" data_dir = Path(args.data_dir) with open(data_dir / "meta.pkl", "rb") as f: meta = pickle.load(f) vocab_size = meta["vocab_size"] run_dir = Path(args.out_dir) / args.run_name run_dir.mkdir(parents=True, exist_ok=True) log_path = run_dir / "log.jsonl" cfg_path = run_dir / "config.json" with open(cfg_path, "w") as f: json.dump(vars(args) | {"vocab_size": vocab_size}, f, indent=2) cfg = GPTConfig( block_size=args.block_size, vocab_size=vocab_size, n_layer=args.n_layer, n_head=args.n_head, n_embd=args.n_embd, dropout=args.dropout, attn_mode=args.attn_mode, sigmoid_bias_mode=args.sigmoid_bias_mode, ) model = GPT(cfg).to(device) n_params = model.num_params() # AdamW with weight-decay-free for 1D params (ln, embeddings, biases, sig_bias) decay_params, nodecay_params = [], [] for n, pr in model.named_parameters(): if not pr.requires_grad: continue if pr.dim() >= 2: decay_params.append(pr) else: nodecay_params.append(pr) optimizer = torch.optim.AdamW( [ {"params": decay_params, "weight_decay": args.weight_decay}, {"params": nodecay_params, "weight_decay": 0.0}, ], lr=args.max_lr, betas=(args.beta1, args.beta2), fused=(device == "cuda"), ) dtype_map = {"float32": torch.float32, "bfloat16": torch.bfloat16, "float16": torch.float16} amp_dtype = dtype_map[args.dtype] scaler = torch.amp.GradScaler("cuda", enabled=(args.dtype == "float16")) t0 = time.time() def log(record: dict): record["t"] = time.time() - t0 with open(log_path, "a") as f: f.write(json.dumps(record) + "\n") log({"event": "start", "params": n_params, "config": vars(args) | {"vocab_size": vocab_size}}) print(f"[{args.run_name}] params={n_params/1e6:.2f}M device={device} dtype={args.dtype}") model.train() for it in range(args.max_iters + 1): lr = lr_schedule(it, args.warmup_iters, args.lr_decay_iters, args.max_lr, args.min_lr) for g in optimizer.param_groups: g["lr"] = lr if it % args.eval_interval == 0 or it == args.max_iters: losses = estimate_loss(model, data_dir, args.block_size, args.batch_size, device, args.eval_iters) log({"event": "eval", "iter": it, "train_loss": losses["train"], "val_loss": losses["val"], "lr": lr}) print(f"[{args.run_name}] iter {it:5d} train {losses['train']:.4f} val {losses['val']:.4f} lr {lr:.4g}") if it == args.max_iters: break X, Y = get_batch("train", data_dir, args.block_size, args.batch_size, device) with torch.amp.autocast(device_type="cuda", dtype=amp_dtype, enabled=(device == "cuda")): _, loss = model(X, Y) optimizer.zero_grad(set_to_none=True) if args.dtype == "float16": scaler.scale(loss).backward() scaler.unscale_(optimizer) torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) scaler.step(optimizer) scaler.update() else: loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip) optimizer.step() if it % args.log_interval == 0: log({"event": "step", "iter": it, "train_loss": loss.item(), "lr": lr}) log({"event": "done", "iter": args.max_iters}) print(f"[{args.run_name}] done in {time.time()-t0:.1f}s") if __name__ == "__main__": main()