summaryrefslogtreecommitdiff
path: root/train_gpt.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_gpt.py')
-rw-r--r--train_gpt.py282
1 files changed, 264 insertions, 18 deletions
diff --git a/train_gpt.py b/train_gpt.py
index 651beb2..85e2cc4 100644
--- a/train_gpt.py
+++ b/train_gpt.py
@@ -86,6 +86,13 @@ class Hyperparameters:
adam_eps = float(os.environ.get("ADAM_EPS", 1e-8))
grad_clip_norm = float(os.environ.get("GRAD_CLIP_NORM", 0.0))
+ # Test-time training (LoRA) hyperparameters.
+ ttt_lora_rank = int(os.environ.get("TTT_LORA_RANK", 8))
+ ttt_lora_lr = float(os.environ.get("TTT_LORA_LR", 0.01))
+ ttt_chunk_size = int(os.environ.get("TTT_CHUNK_SIZE", 256))
+ ttt_eval_seq_len = int(os.environ.get("TTT_EVAL_SEQ_LEN", 1024))
+ ttt_batch_size = int(os.environ.get("TTT_BATCH_SIZE", 64))
+
# -----------------------------
# MUON OPTIMIZER
# -----------------------------
@@ -580,11 +587,14 @@ class CausalSelfAttention(nn.Module):
self.q_gain = nn.Parameter(torch.full((num_heads,), qk_gain_init, dtype=torch.float32))
self.rotary = Rotary(self.head_dim, base=rope_base)
- def forward(self, x: Tensor) -> Tensor:
+ def forward(self, x: Tensor, q_delta=None, v_delta=None) -> Tensor:
bsz, seqlen, dim = x.shape
- q = self.c_q(x).reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
- k = self.c_k(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
- v = self.c_v(x).reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+ q = self.c_q(x) + (q_delta if q_delta is not None else 0)
+ k = self.c_k(x)
+ v = self.c_v(x) + (v_delta if v_delta is not None else 0)
+ q = q.reshape(bsz, seqlen, self.num_heads, self.head_dim).transpose(1, 2)
+ k = k.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
+ v = v.reshape(bsz, seqlen, self.num_kv_heads, self.head_dim).transpose(1, 2)
q = F.rms_norm(q, (q.size(-1),))
k = F.rms_norm(k, (k.size(-1),))
cos, sin = self.rotary(seqlen, x.device, q.dtype)
@@ -636,10 +646,13 @@ class Block(nn.Module):
self.mlp_scale = nn.Parameter(torch.ones(dim, dtype=torch.float32))
self.resid_mix = nn.Parameter(torch.stack((torch.ones(dim), torch.zeros(dim))).float())
- def forward(self, x: Tensor, x0: Tensor) -> Tensor:
    def forward(self, x: Tensor, x0: Tensor, q_delta_fn=None, v_delta_fn=None) -> Tensor:
        """Run one transformer block over `x`.

        Args:
            x: Running hidden stream, shape (bsz, seqlen, dim) — inferred from the
               [None, None, :] broadcasting below; confirm against callers.
            x0: The post-embedding stream mixed back in at every block.
            q_delta_fn / v_delta_fn: Optional callables mapping the normed attention
                input to additive deltas for the Q / V projections (used by the
                test-time-training LoRA adapters; None during normal training).
        """
        # Learnable per-channel mix of the running stream with the embedding stream x0.
        mix = self.resid_mix.to(dtype=x.dtype)
        x = mix[0][None, None, :] * x + mix[1][None, None, :] * x0
        n = self.attn_norm(x)
        # LoRA deltas are computed from the same normed input the attention consumes.
        qd = q_delta_fn(n) if q_delta_fn is not None else None
        vd = v_delta_fn(n) if v_delta_fn is not None else None
        attn_out = self.attn(n, qd, vd)
        # Residual additions with learnable per-channel scales for each sub-layer.
        x = x + self.attn_scale.to(dtype=x.dtype)[None, None, :] * attn_out
        x = x + self.mlp_scale.to(dtype=x.dtype)[None, None, :] * self.mlp(self.mlp_norm(x))
        return x
@@ -697,7 +710,7 @@ class GPT(nn.Module):
if isinstance(module, nn.Linear) and getattr(module, "_zero_init", False):
nn.init.zeros_(module.weight)
- def forward(self, input_ids: Tensor, target_ids: Tensor) -> Tensor:
+ def forward(self, input_ids: Tensor, target_ids: Tensor, lora=None) -> Tensor:
x = self.tok_emb(input_ids)
x = F.rms_norm(x, (x.size(-1),))
x0 = x
@@ -705,24 +718,241 @@ class GPT(nn.Module):
# First half stores skips; second half reuses them in reverse order.
for i in range(self.num_encoder_layers):
- x = self.blocks[i](x, x0)
+ qd = lora.q_loras[i] if lora else None
+ vd = lora.v_loras[i] if lora else None
+ x = self.blocks[i](x, x0, qd, vd)
skips.append(x)
for i in range(self.num_decoder_layers):
+ bi = self.num_encoder_layers + i
if skips:
x = x + self.skip_weights[i].to(dtype=x.dtype)[None, None, :] * skips.pop()
- x = self.blocks[self.num_encoder_layers + i](x, x0)
-
- x = self.final_norm(x).reshape(-1, x.size(-1))
- targets = target_ids.reshape(-1)
+ qd = lora.q_loras[bi] if lora else None
+ vd = lora.v_loras[bi] if lora else None
+ x = self.blocks[bi](x, x0, qd, vd)
+ x = self.final_norm(x)
if self.tie_embeddings:
- logits_proj = F.linear(x, self.tok_emb.weight)
+ logits = F.linear(x, self.tok_emb.weight)
+ else:
+ logits = self.lm_head(x)
+ logits = logits + (lora.lm_head_lora(x) if lora else 0)
+ logits = self.logit_softcap * torch.tanh(logits / self.logit_softcap)
+ if lora:
+ bsz, sl, V = logits.shape
+ return F.cross_entropy(
+ logits.float().reshape(-1, V), target_ids.reshape(-1), reduction="none").reshape(bsz, sl)
+ return F.cross_entropy(logits.float().reshape(-1, logits.size(-1)), target_ids.reshape(-1), reduction="mean")
+
+
+# -----------------------------
+# TEST-TIME TRAINING (LoRA)
+# -----------------------------
+#
+# At evaluation time, we adapt per-document low-rank adapters on the validation data.
+# Each document gets its own adapter, so there is no inter-document dependency.
+
BOS_ID = 1

class BatchedLinearLoRA(nn.Module):
    """Per-batch-element LoRA adapter for a linear layer.

    Each of the `bsz` batch elements owns an independent pair of low-rank
    factors (A, B); the adapter output is x @ Aᵀ @ Bᵀ, i.e. the effective
    weight delta is ΔW = B @ A. B starts at zero, so a freshly-constructed
    (or freshly-reset) adapter is a no-op.
    """

    def __init__(self, bsz: int, in_features: int, out_features: int, rank: int):
        super().__init__()
        self.in_features = in_features
        # Down-projection (rank × in) and up-projection (out × rank), one per batch element.
        self.A = nn.Parameter(torch.empty(bsz, rank, in_features))
        self.B = nn.Parameter(torch.zeros(bsz, out_features, rank))
        self.reset()

    def forward(self, x: Tensor) -> Tensor:
        # (bsz, T, in) -> (bsz, T, rank) -> (bsz, T, out)
        down = x @ self.A.transpose(1, 2)
        return down @ self.B.transpose(1, 2)

    def reset(self) -> None:
        """Re-initialize in place: A kaiming-uniform, B zero (identity adapter)."""
        limit = 1.0 / math.sqrt(self.in_features)
        with torch.no_grad():
            self.A.uniform_(-limit, limit)
            self.B.zero_()
+
class BatchedTTTLoRA(nn.Module):
    """All LoRA adapters for one batch: LM head plus Q/V adapters per block."""

    def __init__(self, bsz: int, model: GPT, rank: int):
        super().__init__()
        dim = model.tok_emb.embedding_dim
        vocab = model.tok_emb.num_embeddings
        # One adapter on the projection to vocabulary logits...
        self.lm_head_lora = BatchedLinearLoRA(bsz, dim, vocab, rank)
        # ...and a Q adapter and a V adapter for every transformer block,
        # sized from each block's own projection widths.
        q_adapters, v_adapters = [], []
        for block in model.blocks:
            q_adapters.append(BatchedLinearLoRA(bsz, dim, block.attn.c_q.weight.shape[0], rank))
            v_adapters.append(BatchedLinearLoRA(bsz, dim, block.attn.c_v.weight.shape[0], rank))
        self.q_loras = nn.ModuleList(q_adapters)
        self.v_loras = nn.ModuleList(v_adapters)

    def reset(self) -> None:
        """Re-initialize every contained adapter in place."""
        for module in self.modules():
            if isinstance(module, BatchedLinearLoRA):
                module.reset()
+
+def _reset_ttt_optimizer(opt):
+ for group in opt.param_groups:
+ for p in group['params']:
+ s = opt.state.get(p)
+ if not s: # Fresh state.
+ continue
+ s['exp_avg'].zero_()
+ s['exp_avg_sq'].zero_()
+ s['step'].fill_(0)
+
def _build_ttt_optimizer(lora, args: Hyperparameters):
    """Fresh Adam over the LoRA parameters, using the TTT learning rate.

    Reuses the training betas from `args`; eps is pinned small (1e-10) for the
    adapter updates.
    """
    betas = (args.beta1, args.beta2)
    return torch.optim.Adam(lora.parameters(), lr=args.ttt_lora_lr, betas=betas, eps=1e-10)
+
def _find_docs(all_tokens: Tensor, include_next_bos: bool = True) -> list[tuple[int, int]]:
    """Return (start_offset, length) for each document, identified by BOS boundaries.

    If include_next_bos is True, include next document's BOS (to match
    continuous-stream eval token count exactly).
    """
    bos = (all_tokens == BOS_ID).nonzero(as_tuple=True)[0].tolist()
    total = int(all_tokens.numel())
    docs: list[tuple[int, int]] = []
    for idx, start in enumerate(bos):
        is_last = idx + 1 >= len(bos)
        end = total if is_last else int(bos[idx + 1])
        if include_next_bos and not is_last:
            end += 1  # Fold the next doc's BOS token into this doc's span.
        start = int(start)
        # Every document must have at least BOS plus one token.
        assert end - start >= 2
        docs.append((start, end - start))
    return docs
+
+def _compute_chunk_window(ci: int, pred_len: int, num_chunks: int, chunk_size: int, eval_seq_len: int):
+ """Return (win_start, win_len, chunk_offset, chunk_len) for chunk `ci` of a doc."""
+ chunk_start = ci * chunk_size
+ chunk_end = pred_len if ci == num_chunks - 1 else (ci + 1) * chunk_size
+ win_start = max(0, chunk_end - eval_seq_len)
+ win_len = chunk_end - win_start
+ chunk_offset = chunk_start - win_start
+ chunk_len = chunk_end - chunk_start
+ return win_start, win_len, chunk_offset, chunk_len
+
+def _accumulate_bpb(
+ ptl: Tensor, x: Tensor, y: Tensor,
+ batch_i: int, chunk_offset: int, chunk_len: int,
+ base_bytes_lut: Tensor, has_leading_space_lut: Tensor, is_boundary_token_lut: Tensor,
+ loss_sum: Tensor, byte_sum: Tensor, token_count: Tensor,
+):
+ """Add one doc-chunk's contribution to the running BPB accumulators."""
+ lbl = ptl[batch_i, chunk_offset:chunk_offset + chunk_len].to(torch.float64)
+ prev = x[batch_i, chunk_offset:chunk_offset + chunk_len]
+ tgt = y[batch_i, chunk_offset:chunk_offset + chunk_len]
+ tok_bytes = base_bytes_lut[tgt].to(torch.float64)
+ tok_bytes += has_leading_space_lut[tgt] & ~is_boundary_token_lut[prev]
+ loss_sum += lbl.sum()
+ byte_sum += tok_bytes.sum()
+ token_count += chunk_len
+
def eval_val_ttt_lora(
    args: Hyperparameters,
    base_model: GPT,
    rank: int,
    world_size: int,
    device: torch.device,
    base_bytes_lut: Tensor,
    has_leading_space_lut: Tensor,
    is_boundary_token_lut: Tensor,
) -> tuple[float, float]:
    """Evaluate with batched LoRA test-time training. Returns (val_loss, val_bpb).

    Each validation document gets its own LoRA adapter (one batch element per
    document). Documents are processed chunk by chunk: every chunk is first
    scored with the adapter state learned from the preceding chunks, then the
    adapter takes one Adam step on that chunk's loss, so no token is ever
    scored after being trained on.
    """
    # Load validation tokens and find document boundaries.
    # NOTE(review): all shards are concatenated on the host — assumes the full
    # validation set fits in host memory.
    files = sorted(glob.glob(args.val_files))
    all_tokens = torch.cat([load_data_shard(Path(f)) for f in files])
    docs = _find_docs(all_tokens)

    # Each rank takes a contiguous slice of documents.
    rank_docs = docs[(len(docs) * rank) // world_size : (len(docs) * (rank + 1)) // world_size]
    chunk_size = args.ttt_chunk_size
    eval_seq_len = args.ttt_eval_seq_len
    batch_size = args.ttt_batch_size
    lora_rank = args.ttt_lora_rank

    # Sort by chunk count so documents batched together need similar numbers
    # of chunk iterations (less wasted padding work per batch).
    rank_docs.sort(key=lambda d: (d[1] - 2) // chunk_size)

    # Freeze the base model: only the LoRA parameters receive gradients.
    base_model.eval()
    for p in base_model.parameters():
        p.requires_grad_(False)

    # One full-size adapter+optimizer pair is allocated once and reset per
    # batch; a ragged final batch gets its own smaller pair below.
    lora = BatchedTTTLoRA(batch_size, base_model, lora_rank).to(device)
    opt = _build_ttt_optimizer(lora, args)

    # float64 accumulators (zero-dim tensors, mutated in place by _accumulate_bpb).
    loss_sum = torch.zeros((), device=device, dtype=torch.float64)
    byte_sum = torch.zeros((), device=device, dtype=torch.float64)
    token_count = torch.zeros((), device=device, dtype=torch.float64)

    for bi in range(0, len(rank_docs), batch_size):
        batch = rank_docs[bi:bi + batch_size]
        bsz = len(batch)

        if bsz == batch_size:
            # Reuse the preallocated adapter/optimizer; just reset their state.
            cur_lora, cur_opt = lora, opt
            cur_lora.reset()
            _reset_ttt_optimizer(cur_opt)
        else:
            # Ragged last batch: build a right-sized adapter/optimizer pair.
            cur_lora = BatchedTTTLoRA(bsz, base_model, lora_rank).to(device)
            cur_opt = _build_ttt_optimizer(cur_lora, args)

        # Number of scored predictions per doc (doc_len tokens -> doc_len-1 targets).
        pred_lens = [doc_len - 1 for _, doc_len in batch]
        num_chunks = [(pl + chunk_size - 1) // chunk_size for pl in pred_lens]
        max_nc = max(num_chunks)

        for ci in range(max_nc):
            # Window geometry for a hypothetical full-length doc at this chunk
            # index: gives the common padded context size and chunk offset used
            # to allocate x/y and to slice the training loss below.
            chunk_stats = _compute_chunk_window(ci, (ci + 1) * chunk_size, ci + 1, chunk_size, eval_seq_len)
            context_size, chunk_offset = chunk_stats[1], chunk_stats[2]

            # active: docs that still have a chunk at index ci.
            # needs_train: at least one doc has a later chunk, so this chunk's
            # loss will be used for an adapter update.
            active = [ci < nc for nc in num_chunks]
            needs_train = any(ci < nc - 1 for nc in num_chunks)

            # Zero-padded (inactive/short docs keep token id 0) input/target batch.
            x = torch.zeros(bsz, context_size, dtype=torch.int64, device=device)
            y = torch.zeros(bsz, context_size, dtype=torch.int64, device=device)
            doc_info = []  # (chunk_offset, chunk_len) per doc
            for b in range(bsz):
                if not active[b]:
                    doc_info.append((0, 0))
                    continue
                ds, dl = batch[b]  # dl (doc length) unused here; pred_lens[b] already derived from it
                ws, wl, co, cl = _compute_chunk_window(ci, pred_lens[b], num_chunks[b], chunk_size, eval_seq_len)
                # wl+1 raw tokens yield wl (input, target) pairs via the shift below.
                chunk = all_tokens[ds + ws: ds + ws + wl + 1]
                toks = chunk.to(dtype=torch.int64, device=device)
                x[b, :wl] = toks[:-1]
                y[b, :wl] = toks[1:]
                doc_info.append((co, cl))

            # Forward pass (keep grad graph alive only when we need to train).
            # With lora set, the model returns per-token losses of shape (bsz, seqlen).
            if needs_train:
                with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    ptl = base_model(x, y, lora=cur_lora)
            else:
                with torch.no_grad(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
                    ptl = base_model(x, y, lora=cur_lora)

            # Score: accumulate loss and byte counts for BPB (before training on chunk).
            with torch.no_grad():
                for b in range(bsz):
                    if not active[b]:
                        continue
                    co, cl = doc_info[b]
                    _accumulate_bpb(
                        ptl, x, y, b, co, cl, base_bytes_lut, has_leading_space_lut,
                        is_boundary_token_lut, loss_sum, byte_sum, token_count)

            # Train: one Adam step on the LoRA params using this chunk's loss.
            # The mask keeps only docs with a later chunk; for those docs this
            # chunk is non-final and therefore exactly chunk_size long, so the
            # fixed-size slice below matches their real chunk extent.
            if needs_train:
                mask = torch.tensor([float(ci < num_chunks[b] - 1) for b in range(bsz)], device=device)
                per_doc = ptl[:, chunk_offset:chunk_offset + chunk_size].mean(dim=-1)
                cur_opt.zero_grad()
                (per_doc * mask).sum().backward()
                cur_opt.step()
    # Combine accumulators across ranks before computing global averages.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(byte_sum, op=dist.ReduceOp.SUM)
        dist.all_reduce(token_count, op=dist.ReduceOp.SUM)

    val_loss = float(loss_sum.item() / token_count.item())
    # Bits-per-byte: nats -> bits via log(2), normalized by total byte count.
    val_bpb = float((loss_sum.item() / math.log(2.0)) / byte_sum.item())
    return val_loss, val_bpb
# -----------------------------
# TRAINING
@@ -839,6 +1069,8 @@ def main() -> None:
for module in base_model.modules():
if isinstance(module, CastedLinear):
module.float()
+ if isinstance(module, Rotary):
+ module.inv_freq.data = module.inv_freq.data.float()
restore_low_dim_params_to_fp32(base_model)
compiled_model = torch.compile(base_model, dynamic=False, fullgraph=True)
model: nn.Module = DDP(compiled_model, device_ids=[local_rank], broadcast_buffers=False) if distributed else compiled_model
@@ -1118,6 +1350,20 @@ def main() -> None:
)
log0(f"final_int8_zlib_roundtrip_exact val_loss:{q_val_loss:.8f} val_bpb:{q_val_bpb:.8f}")
+ # LoRA test-time training evaluation (the competition score)
+ torch._dynamo.reset()
+ torch.cuda.synchronize()
+ t_ttt = time.perf_counter()
+ ttt_val_loss, ttt_val_bpb = eval_val_ttt_lora(
+ args, base_model, rank, world_size, device,
+ base_bytes_lut, has_leading_space_lut, is_boundary_token_lut,
+ )
+ torch.cuda.synchronize()
+ log0(
+ f"final_int8_ttt_lora val_loss:{ttt_val_loss:.4f} val_bpb:{ttt_val_bpb:.4f} "
+ f"eval_time:{1000.0 * (time.perf_counter() - t_ttt):.0f}ms"
+ )
+
if distributed:
dist.destroy_process_group()