Diffstat (limited to 'src')
-rw-r--r--  src/__init__.py                |   0
-rw-r--r--  src/data/__init__.py           |   0
-rw-r--r--  src/data/dolma.py              | 226
-rw-r--r--  src/model/__init__.py          |   0
-rw-r--r--  src/model/olmo_graph.py        | 397
-rw-r--r--  src/model/pipeline.py          | 144
-rw-r--r--  src/model/predictor.py         | 275
-rw-r--r--  src/training/__init__.py       |   0
-rw-r--r--  src/training/checkpointing.py  |  92
-rw-r--r--  src/training/schedulers.py     |  35
-rw-r--r--  src/training/trainer.py        | 465
-rw-r--r--  src/utils/__init__.py          |   0
-rw-r--r--  src/utils/logging.py           |  63
-rw-r--r--  src/utils/topology.py          |  87
14 files changed, 1784 insertions, 0 deletions
diff --git a/src/__init__.py b/src/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/__init__.py
diff --git a/src/data/__init__.py b/src/data/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/data/__init__.py
diff --git a/src/data/dolma.py b/src/data/dolma.py
new file mode 100644
index 0000000..4e2baaf
--- /dev/null
+++ b/src/data/dolma.py
@@ -0,0 +1,226 @@
+"""Streaming dataloader for Dolma v1.7 with sequence packing.
+
+Produces packed sequences of fixed length for both OLMo and Qwen tokenizers.
+See CLAUDE.md §3.1.1 for sequence packing specification.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Iterator, Optional
+
+import torch
+from datasets import load_dataset
+from torch.utils.data import IterableDataset
+from transformers import AutoTokenizer
+
+
+class DolmaPackedDataset(IterableDataset):
+ """Streaming Dolma dataset with sequence packing.
+
+ Concatenates documents with EOS separators, then chunks into fixed-length
+ sequences. No padding — every token contributes to NLL.
+
+ Each sample yields:
+ olmo_ids: [seq_len] — OLMo input token IDs
+        olmo_labels: [seq_len] — labels pre-shifted by one token; the loss uses them without further shifting
+ raw_text: str — decoded text for Qwen encoder
+ """
+
+ def __init__(
+ self,
+ olmo_tokenizer: AutoTokenizer,
+ seq_len: int = 1024,
+ dataset_name: str = "allenai/dolma",
+ dataset_version: str = "v1_7",
+ rank: int = 0,
+ world_size: int = 1,
+ max_samples: Optional[int] = None,
+ ):
+ super().__init__()
+ self.olmo_tokenizer = olmo_tokenizer
+ self.seq_len = seq_len
+ self.dataset_name = dataset_name
+ self.dataset_version = dataset_version
+ self.rank = rank
+ self.world_size = world_size
+ self.max_samples = max_samples
+
+ self.eos_id = olmo_tokenizer.eos_token_id
+ assert self.eos_id is not None, "OLMo tokenizer must have an EOS token"
+
+ def __iter__(self) -> Iterator[dict]:
+ """Yield packed sequences from Dolma stream."""
+ try:
+ dataset = load_dataset(
+ self.dataset_name,
+ name=self.dataset_version,
+ split="train",
+ streaming=True,
+ trust_remote_code=True,
+ )
+ except Exception:
+ # Fallback if specific version not available
+ dataset = load_dataset(
+ self.dataset_name,
+ split="train",
+ streaming=True,
+ trust_remote_code=True,
+ )
+
+ # Shard for multi-GPU
+ if self.world_size > 1:
+ dataset = dataset.shard(num_shards=self.world_size, index=self.rank)
+
+ buffer: list[int] = []
+ sample_count = 0
+
+ for doc in dataset:
+ if self.max_samples is not None and sample_count >= self.max_samples:
+ break
+
+ text = doc.get("text", "")
+ if not text.strip():
+ continue
+
+ tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+ buffer.extend(tokens)
+ buffer.append(self.eos_id)
+
+ # Yield packed sequences as buffer fills
+ while len(buffer) >= self.seq_len + 1:
+ chunk = buffer[:self.seq_len + 1]
+ buffer = buffer[self.seq_len + 1:]
+
+ olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long)
+ olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long)
+ raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False)
+
+ yield {
+ "olmo_ids": olmo_ids,
+ "olmo_labels": olmo_labels,
+ "raw_text": raw_text,
+ }
+ sample_count += 1
+
+ if self.max_samples is not None and sample_count >= self.max_samples:
+ break
+
+
+def build_train_dataloader(
+ olmo_tokenizer: AutoTokenizer,
+ seq_len: int = 1024,
+ batch_size: int = 4,
+ dataset_name: str = "allenai/dolma",
+ dataset_version: str = "v1_7",
+ rank: int = 0,
+ world_size: int = 1,
+ num_workers: int = 0,
+) -> torch.utils.data.DataLoader:
+ """Build training dataloader with sequence packing."""
+ dataset = DolmaPackedDataset(
+ olmo_tokenizer=olmo_tokenizer,
+ seq_len=seq_len,
+ dataset_name=dataset_name,
+ dataset_version=dataset_version,
+ rank=rank,
+ world_size=world_size,
+ )
+ return torch.utils.data.DataLoader(
+ dataset,
+ batch_size=batch_size,
+ num_workers=num_workers,
+ collate_fn=_collate_packed,
+ )
+
+
+def build_eval_dataloader(
+ olmo_tokenizer: AutoTokenizer,
+ seq_len: int = 1024,
+ batch_size: int = 4,
+ dataset_name: str = "allenai/dolma",
+ dataset_version: str = "v1_7",
+ eval_skip: int = 1_000_000,
+ eval_size: int = 1_000,
+ cache_path: Optional[str] = None,
+) -> list[dict]:
+ """Build eval batches (cached in memory).
+
+ Skips eval_skip examples in the stream, then takes eval_size packed sequences.
+ Caches to disk to avoid repeated skip on restart.
+ """
+ # Try loading from cache
+ if cache_path and os.path.exists(cache_path):
+ print(f"Loading eval cache from {cache_path}")
+ return torch.load(cache_path)
+
+ print(f"Building eval set (skip={eval_skip}, size={eval_size})...")
+
+ try:
+ dataset = load_dataset(
+ dataset_name,
+ name=dataset_version,
+ split="train",
+ streaming=True,
+ trust_remote_code=True,
+ )
+ except Exception:
+ dataset = load_dataset(
+ dataset_name,
+ split="train",
+ streaming=True,
+ trust_remote_code=True,
+ )
+
+ # Skip to held-out region
+ dataset = dataset.skip(eval_skip)
+
+    eos_id = olmo_tokenizer.eos_token_id
+    assert eos_id is not None, "OLMo tokenizer must have an EOS token"
+ buffer: list[int] = []
+ eval_samples: list[dict] = []
+
+ for doc in dataset:
+ if len(eval_samples) >= eval_size:
+ break
+
+ text = doc.get("text", "")
+ if not text.strip():
+ continue
+
+ tokens = olmo_tokenizer(text, add_special_tokens=False)["input_ids"]
+ buffer.extend(tokens)
+ buffer.append(eos_id)
+
+ while len(buffer) >= seq_len + 1 and len(eval_samples) < eval_size:
+ chunk = buffer[:seq_len + 1]
+ buffer = buffer[seq_len + 1:]
+ eval_samples.append({
+ "olmo_ids": torch.tensor(chunk[:seq_len], dtype=torch.long),
+ "olmo_labels": torch.tensor(chunk[1:seq_len + 1], dtype=torch.long),
+ "raw_text": olmo_tokenizer.decode(chunk[:seq_len], skip_special_tokens=False),
+ })
+
+ print(f"Built {len(eval_samples)} eval sequences")
+
+ # Batch the samples
+ eval_batches = []
+ for i in range(0, len(eval_samples), batch_size):
+ batch_items = eval_samples[i:i + batch_size]
+ eval_batches.append(_collate_packed(batch_items))
+
+ # Cache to disk
+ if cache_path:
+ os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True)
+ torch.save(eval_batches, cache_path)
+ print(f"Eval cache saved to {cache_path}")
+
+ return eval_batches
+
+
+def _collate_packed(batch: list[dict]) -> dict:
+ """Collate packed samples into a batch dict."""
+ return {
+ "olmo_ids": torch.stack([s["olmo_ids"] for s in batch]),
+ "olmo_labels": torch.stack([s["olmo_labels"] for s in batch]),
+ "raw_text": [s["raw_text"] for s in batch],
+ }
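+
+
+# A toy sketch of the packing/shift convention used above (illustrative token
+# IDs only, assuming nothing beyond the chunking logic in __iter__):
+def _demo_packing() -> None:
+    seq_len, eos = 3, 0
+    buffer = [5, 9, 2, eos, 7, 8]
+    chunk = buffer[: seq_len + 1]                 # [5, 9, 2, 0]
+    olmo_ids, olmo_labels = chunk[:seq_len], chunk[1:]
+    # Labels are the ids shifted by one, so the loss can compare logits[t]
+    # with labels[t] directly (no second shift at loss time).
+    assert olmo_ids == [5, 9, 2] and olmo_labels == [9, 2, 0]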
diff --git a/src/model/__init__.py b/src/model/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/model/__init__.py
diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py
new file mode 100644
index 0000000..af9f848
--- /dev/null
+++ b/src/model/olmo_graph.py
@@ -0,0 +1,397 @@
+"""Modified OLMo2-1B forward pass with adjacency matrix A injection.
+
+This module implements the core DAGFormer modification: per-head input
+assembly controlled by a 256x256 adjacency matrix A. Each head receives
+its own input (a gated combination of prior heads' outputs), rather than
+the shared residual stream.
+
+Key design decisions:
+- Uses proportional attribution for post_attention_layernorm decomposition
+ (OLMo2 is post-norm, not pre-norm as CLAUDE.md §2.1 assumes)
+- Concatenate→q_norm→split pattern for per-head Q/K normalization
+- Weight slices via .view() (not .clone()) for Phase 2 compatibility
+- When A=all-ones and input_norm="none", output is identical to vanilla OLMo2
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from transformers import AutoModelForCausalLM
+from transformers.models.olmo2.modeling_olmo2 import (
+ apply_rotary_pos_emb,
+)
+
+
+def create_block_upper_triangular_mask(num_nodes: int = 256, heads_per_layer: int = 16) -> torch.Tensor:
+ """Create block-upper-triangular mask based on LAYER indices.
+
+ mask[i,j] = 1 iff layer(j) > layer(i), i.e. j//16 > i//16.
+ Same-layer and backward connections are 0.
+ Do NOT use torch.triu() — it allows same-layer connections.
+
+ Returns:
+ mask: [num_nodes, num_nodes] float tensor with 0s and 1s
+ """
+ layer_idx = torch.arange(num_nodes) // heads_per_layer
+ mask = (layer_idx.unsqueeze(1) < layer_idx.unsqueeze(0)).float() # [256, 256]
+ return mask
+
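+
+# Tiny sanity sketch of the mask semantics above (toy sizes, not the real
+# 256 nodes / 16 heads-per-layer):
+def _demo_dag_mask() -> None:
+    m = create_block_upper_triangular_mask(num_nodes=8, heads_per_layer=4)
+    assert m[0, 4] == 1.0  # layer 0 → layer 1: allowed
+    assert m[0, 1] == 0.0  # same layer: blocked (torch.triu would allow it)
+    assert m[4, 0] == 0.0  # backward edge: blocked
+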
+
+class InputNormalizer(nn.Module):
+ """Normalization methods for gated head output sums (CLAUDE.md §6.1).
+
+ Applied ONLY to the gated_sum component, not the base (embedding + MLPs).
+ """
+
+ def __init__(self, method: str, model_dim: int = 2048, num_nodes: int = 256):
+ super().__init__()
+ self.method = method
+ self.model_dim = model_dim
+
+ if method == "none":
+ pass
+ elif method == "gate_mean":
+ pass # no learnable params
+ elif method == "rms_post":
+ self.norm = nn.RMSNorm(model_dim)
+ elif method == "ln_post":
+ self.norm = nn.LayerNorm(model_dim)
+ elif method == "rms_pre":
+ self.norms = nn.ModuleList([nn.RMSNorm(model_dim) for _ in range(num_nodes)])
+ else:
+ raise ValueError(f"Unknown input_norm method: {method}")
+
+ def forward(
+ self,
+ gated_sum: torch.Tensor,
+ A_slice: Optional[torch.Tensor] = None,
+ prior_head_outs: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """Normalize the gated sum of prior head outputs.
+
+ Args:
+ gated_sum: [batch, num_heads, seq, model_dim] — gated sum for this layer's heads
+ A_slice: [batch, num_prior_nodes, num_heads] — gate values (for gate_mean)
+ prior_head_outs: [batch, num_prior_nodes, seq, model_dim] — for rms_pre
+ Returns:
+ Normalized gated_sum, same shape
+ """
+ if self.method == "none":
+ return gated_sum
+
+ elif self.method == "gate_mean":
+ assert A_slice is not None
+ # Sum of gates per target head: [batch, num_heads]
+ gate_sum = A_slice.sum(dim=1) # [batch, num_heads]
+ # Divide gated_sum by gate_sum (avoid div by zero)
+ divisor = gate_sum.clamp(min=1e-8) # [batch, num_heads]
+ return gated_sum / divisor[:, :, None, None] # broadcast over [seq, model_dim]
+
+ elif self.method == "rms_post":
+ return self.norm(gated_sum)
+
+ elif self.method == "ln_post":
+ return self.norm(gated_sum)
+
+ elif self.method == "rms_pre":
+ # Apply per-source-node RMSNorm before gating, then recompute gated sum
+ # This requires prior_head_outs and A_slice
+ assert prior_head_outs is not None and A_slice is not None
+ num_prior = prior_head_outs.shape[1]
+ # Normalize each source node's output
+ normed_sources = []
+ for i in range(num_prior):
+ normed_sources.append(self.norms[i](prior_head_outs[:, i]))
+ normed_sources = torch.stack(normed_sources, dim=1) # [B, num_prior, S, D]
+ # Recompute gated sum with normed sources
+ return torch.einsum('bih,bisd->bhsd', A_slice, normed_sources)
+
+ raise ValueError(f"Unknown method: {self.method}")
+
+
+class DAGFormerOLMo(nn.Module):
+ """Wraps OLMo2-1B with adjacency matrix A injection for per-head routing.
+
+ When A is all-ones and input_norm is "none", this produces output
+ identical to vanilla OLMo2-1B (baseline reproduction invariant).
+ """
+
+ def __init__(
+ self,
+ model: AutoModelForCausalLM,
+ input_norm: str = "none",
+ num_layers: int = 16,
+ num_heads: int = 16,
+ ):
+ super().__init__()
+ self.olmo = model
+ self.num_layers = num_layers
+ self.num_heads = num_heads
+ self.num_nodes = num_layers * num_heads
+ self.model_dim = model.config.hidden_size
+ self.head_dim = self.model_dim // num_heads
+ self.rms_norm_eps = model.config.rms_norm_eps
+
+ # Runtime assertions
+ assert model.config.num_attention_heads == num_heads, \
+ f"Expected {num_heads} attention heads, got {model.config.num_attention_heads}"
+ assert model.config.num_key_value_heads == num_heads, \
+ f"Expected MHA ({num_heads} KV heads), got {model.config.num_key_value_heads} — GQA detected"
+
+ # Verify no bias
+ layer0_attn = model.model.layers[0].self_attn
+ assert layer0_attn.o_proj.bias is None, \
+ "Expected no bias in o_proj — update per-head splitting if bias exists"
+
+ # Block-upper-triangular mask: [256, 256]
+ self.register_buffer('dag_mask', create_block_upper_triangular_mask(self.num_nodes, num_heads))
+
+ # Input normalization
+ self.input_normalizer = InputNormalizer(input_norm, self.model_dim, self.num_nodes)
+
+ # Attention scaling factor
+ self.scaling = self.head_dim ** -0.5
+
+ def _get_head_weight_views(self, layer_idx: int) -> dict:
+ """Get per-head weight views for a given layer.
+
+ Uses .view() which returns views of the same storage — no copy,
+ gradients flow through for Phase 2 compatibility.
+ """
+ layer = self.olmo.model.layers[layer_idx]
+ attn = layer.self_attn
+
+ # Q, K, V projections: [model_dim, model_dim] → [num_heads, head_dim, model_dim]
+ W_q = attn.q_proj.weight.view(self.num_heads, self.head_dim, self.model_dim)
+ W_k = attn.k_proj.weight.view(self.num_heads, self.head_dim, self.model_dim)
+ W_v = attn.v_proj.weight.view(self.num_heads, self.head_dim, self.model_dim)
+
+ # O projection: [model_dim, model_dim]
+ # Split by INPUT dimension (columns): [model_dim, num_heads, head_dim]
+ # Permute to [num_heads, model_dim, head_dim] for einsum
+ W_o = attn.o_proj.weight.view(self.model_dim, self.num_heads, self.head_dim)
+ W_o = W_o.permute(1, 0, 2) # [num_heads, model_dim, head_dim]
+
+ return {
+ 'W_q': W_q, 'W_k': W_k, 'W_v': W_v, 'W_o': W_o,
+ 'q_norm': attn.q_norm,
+ 'k_norm': attn.k_norm,
+ 'post_attn_norm': layer.post_attention_layernorm,
+ 'post_ff_norm': layer.post_feedforward_layernorm,
+ 'mlp': layer.mlp,
+ }
+
+ def forward(
+ self,
+ olmo_ids: torch.Tensor,
+ A: torch.Tensor,
+ ) -> torch.Tensor:
+ """Modified OLMo2-1B forward pass with per-head routing via A.
+
+ Args:
+ olmo_ids: [batch, seq_len] — tokenized by OLMo's tokenizer
+ A: [batch, 256, 256] — block-upper-triangular gate matrix
+
+ Returns:
+ logits: [batch, seq_len, vocab_size]
+ """
+ batch, seq_len = olmo_ids.shape
+ device = olmo_ids.device
+
+ assert A.shape == (batch, self.num_nodes, self.num_nodes), \
+ f"A shape mismatch: expected ({batch}, {self.num_nodes}, {self.num_nodes}), got {A.shape}"
+
+ # Cast A to model dtype (predictor outputs float32, OLMo uses bfloat16)
+ model_dtype = self.olmo.model.embed_tokens.weight.dtype
+ A = A.to(dtype=model_dtype)
+
+ # Token embedding
+ embedding = self.olmo.model.embed_tokens(olmo_ids) # [B, S, D]
+
+ # Position embeddings (computed once, shared across all layers)
+ position_ids = torch.arange(seq_len, device=device).unsqueeze(0) # [1, S]
+ position_embeddings = self.olmo.model.rotary_emb(embedding, position_ids)
+ cos, sin = position_embeddings
+
+ # Causal attention mask: [1, 1, S, S]
+ causal_mask = torch.zeros(1, 1, seq_len, seq_len, device=device, dtype=embedding.dtype)
+ causal_mask.masked_fill_(
+ torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1),
+ float('-inf'),
+ )
+
+ # Storage for outputs across layers
+ # We accumulate head_outputs as a list of [B, 16, S, D] tensors (one per layer)
+ all_head_outputs: list[torch.Tensor] = [] # each: [B, 16, S, D]
+ mlp_outputs: list[torch.Tensor] = [] # each: [B, S, D]
+
+ # Running base: embedding + accumulated MLP outputs (for per-head assembly)
+ base = embedding.clone() # [B, S, D]
+ # Accumulated ungated attention outputs (for MLP input)
+ attn_accumulated = torch.zeros_like(embedding) # [B, S, D]
+
+ for l in range(self.num_layers):
+ weights = self._get_head_weight_views(l)
+
+ # === ASSEMBLE PER-HEAD INPUTS ===
+ if l == 0:
+ # Layer 0: all heads see only the embedding (no prior heads or MLPs)
+ assembled = embedding.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+ # assembled: [B, 16, S, D]
+ else:
+ # base_l = embedding + Σ_{l'<l} mlp_outputs[l']
+ # (base is updated incrementally after each layer's MLP)
+
+ # Stack all prior head outputs: [B, l*16, S, D]
+ prior_head_outs = torch.cat(all_head_outputs, dim=1)
+
+ # Slice A for connections into this layer's heads
+ # A[:, source_nodes, target_nodes]
+ # source: nodes 0..(l*16-1), target: nodes l*16..(l*16+15)
+ A_slice = A[:, :l * self.num_heads, l * self.num_heads:(l + 1) * self.num_heads]
+ # A_slice: [B, l*16, 16]
+
+ # Batched gated sum via einsum
+ gated_sum = torch.einsum('bih,bisd->bhsd', A_slice, prior_head_outs)
+ # gated_sum: [B, 16, S, D]
+
+ # Apply input normalization (only to gated_sum, not base)
+ if self.input_normalizer.method == "rms_pre":
+ gated_sum = self.input_normalizer(
+ gated_sum, A_slice=A_slice, prior_head_outs=prior_head_outs
+ )
+ elif self.input_normalizer.method == "gate_mean":
+ gated_sum = self.input_normalizer(gated_sum, A_slice=A_slice)
+ else:
+ gated_sum = self.input_normalizer(gated_sum)
+
+ # assembled = base + gated_sum
+ assembled = base.unsqueeze(1) + gated_sum # [B, 16, S, D]
+
+ # === PER-HEAD Q/K/V PROJECTION ===
+ W_q, W_k, W_v, W_o = weights['W_q'], weights['W_k'], weights['W_v'], weights['W_o']
+
+ # Per-head projections via einsum
+ # assembled: [B, H, S, D], W_q: [H, head_dim, D]
+ q_per_head = torch.einsum('bhsd,hod->bhso', assembled, W_q) # [B, H, S, head_dim]
+ k_per_head = torch.einsum('bhsd,hod->bhso', assembled, W_k)
+ v_per_head = torch.einsum('bhsd,hod->bhso', assembled, W_v)
+
+ # === Q_NORM / K_NORM ===
+ # OLMo2 applies RMSNorm to concatenated Q/K (2048-dim) AFTER projection.
+ # Concat all heads → norm → split back.
+ # When A=1 (all heads same input), this equals q_norm(q_proj(shared_input)).
+ q_concat = rearrange(q_per_head, 'b h s d -> b s (h d)') # [B, S, 2048]
+ q_normed = weights['q_norm'](q_concat)
+ q_per_head = rearrange(q_normed, 'b s (h d) -> b h s d', h=self.num_heads)
+
+ k_concat = rearrange(k_per_head, 'b h s d -> b s (h d)')
+ k_normed = weights['k_norm'](k_concat)
+ k_per_head = rearrange(k_normed, 'b s (h d) -> b h s d', h=self.num_heads)
+
+ # V has NO norm in OLMo2
+
+ # === APPLY RoPE ===
+ q_per_head, k_per_head = apply_rotary_pos_emb(q_per_head, k_per_head, cos, sin)
+
+ # === ATTENTION COMPUTATION ===
+ # q,k,v: [B, H, S, head_dim]
+ attn_weights = torch.matmul(q_per_head, k_per_head.transpose(-2, -1)) * self.scaling
+ # attn_weights: [B, H, S, S]
+ attn_weights = attn_weights + causal_mask # [1, 1, S, S] broadcasts
+ attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_per_head.dtype)
+ attn_values = torch.matmul(attn_weights, v_per_head) # [B, H, S, head_dim]
+
+ # === PER-HEAD O_PROJ ===
+ # attn_values: [B, H, S, head_dim], W_o: [H, model_dim, head_dim]
+ raw_head_outs = torch.einsum('bhsd,hod->bhso', attn_values, W_o)
+ # raw_head_outs: [B, H, S, model_dim]
+
+ # === PROPORTIONAL ATTRIBUTION WITH POST_ATTN_NORM ===
+ # OLMo2 applies post_attention_layernorm to the COMBINED attention output.
+ # RMSNorm(Σ_h x_h) = weight * (Σ_h x_h) / RMS(Σ_h x_h)
+ # = Σ_h [weight * x_h / RMS(Σ_h x_h)]
+ # We attribute each head's normed output proportionally.
+ raw_sum = raw_head_outs.sum(dim=1) # [B, S, D]
+ # Compute RMS of the sum
+ variance = raw_sum.to(torch.float32).pow(2).mean(-1, keepdim=True)
+ rms = torch.sqrt(variance + self.rms_norm_eps) # [B, S, 1]
+ # Apply post_attn_norm weight and scale
+ norm_weight = weights['post_attn_norm'].weight # [D]
+ # head_output[h] = norm_weight * raw_head_out[h] / rms
+ scale = (norm_weight / rms).unsqueeze(1) # [B, 1, S, D]
+ head_outputs_l = raw_head_outs.float() * scale # [B, H, S, D]
+ head_outputs_l = head_outputs_l.to(raw_head_outs.dtype)
+
+ # Store for routing to later layers
+ all_head_outputs.append(head_outputs_l)
+
+ # === MLP COMPUTATION (standard, ungated) ===
+ # attn_normed = Σ_h head_output[l,h] = post_attn_norm(raw_sum)
+ attn_normed = head_outputs_l.sum(dim=1) # [B, S, D]
+
+ # MLP input = full residual stream (embedding + all prior MLPs + all attn up to current)
+ # In vanilla OLMo2: mlp_input = residual + post_attn_norm(attn_output)
+ # where residual includes ALL prior components (embedding + prior MLPs + prior attns)
+ mlp_in = base + attn_accumulated + attn_normed
+
+ # Update accumulated attention for next layer
+ attn_accumulated = attn_accumulated + attn_normed
+
+ # MLP forward + post_feedforward_layernorm
+ mlp_raw = weights['mlp'](mlp_in)
+ mlp_output_l = weights['post_ff_norm'](mlp_raw)
+ mlp_outputs.append(mlp_output_l)
+
+ # Update running base for next layer
+ # base_{l+1} = base_l + mlp_output_l = embedding + Σ_{l'<=l} mlp_output[l']
+ base = base + mlp_output_l
+
+ # === FINAL OUTPUT ===
+ # final_state = embedding + Σ_l mlp_output[l] + Σ_l Σ_h head_output[l,h]
+ # = embedding + Σ_l [post_attn_norm(attn_out_l) + post_ff_norm(mlp_out_l)]
+ # 'base' = embedding + Σ_l mlp_output[l]
+ # 'attn_accumulated' = Σ_l attn_output[l] (ungated sum of all attention outputs)
+ final_state = base + attn_accumulated
+
+ # Apply final norm and lm_head
+ final_state = self.olmo.model.norm(final_state)
+ logits = self.olmo.lm_head(final_state)
+
+ return logits
+
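+
+# Numerical sketch of the proportional-attribution identity used above:
+# RMSNorm(Σ_h x_h) = Σ_h weight * x_h / RMS(Σ_h x_h), by linearity once the
+# shared RMS of the sum is fixed. Toy shapes, illustrative only.
+def _demo_proportional_attribution() -> None:
+    torch.manual_seed(0)
+    x = torch.randn(4, 2, 8)                      # [heads, seq, dim]
+    w = torch.randn(8)                            # norm weight
+    s = x.sum(0)
+    rms = torch.sqrt(s.pow(2).mean(-1, keepdim=True) + 1e-6)
+    assert torch.allclose(w * s / rms, (w * x / rms).sum(0), atol=1e-5)
+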
+
+def compute_vanilla_nll(
+ model: AutoModelForCausalLM,
+ input_ids: torch.Tensor,
+ labels: torch.Tensor,
+) -> torch.Tensor:
+ """Compute NLL using vanilla OLMo2 forward pass (no A injection).
+
+ Used for baseline comparison in sanity checks.
+ """
+ with torch.no_grad():
+ outputs = model(input_ids=input_ids)
+ logits = outputs.logits
+ nll = F.cross_entropy(
+ logits[:, :-1].contiguous().view(-1, logits.size(-1)),
+ labels[:, 1:].contiguous().view(-1),
+ )
+ return nll
+
+
+def create_all_ones_A(batch_size: int, num_nodes: int = 256, num_heads: int = 16) -> torch.Tensor:
+ """Create A matrix with 1.0 for all valid (cross-layer) entries.
+
+ When used with input_norm="none", this should reproduce vanilla OLMo2.
+ """
+ A = torch.zeros(batch_size, num_nodes, num_nodes)
+ mask = create_block_upper_triangular_mask(num_nodes, num_heads)
+ A = A + mask.unsqueeze(0) # broadcast mask to batch
+ return A
diff --git a/src/model/pipeline.py b/src/model/pipeline.py
new file mode 100644
index 0000000..bbfcabf
--- /dev/null
+++ b/src/model/pipeline.py
@@ -0,0 +1,144 @@
+"""End-to-end DAGFormer pipeline: raw text → predictor → A → OLMo → NLL.
+
+Glues the structure predictor (Qwen + MLP) with the modified OLMo forward.
+This is what the trainer calls. See CLAUDE.md §5 for file responsibilities.
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM
+
+from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
+from src.model.predictor import StructurePredictor
+
+
+class DAGFormerPipeline(nn.Module):
+ """Combines StructurePredictor + DAGFormerOLMo into a single forward pass.
+
+ Forward: raw_text → predictor → A → modified OLMo → logits → NLL
+
+ Only the predictor's MLP params are trainable. OLMo and Qwen are frozen.
+ """
+
+ def __init__(
+ self,
+ olmo_model_id: str = "allenai/OLMo-2-0425-1B",
+ qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B",
+ predictor_hidden_dim: int = 1024,
+ predictor_rank: int = 32,
+ cascading_gate_k: float = 5.0,
+ input_norm: str = "none",
+ qwen_input_prefix: str = "",
+ device: Optional[torch.device] = None,
+ ):
+ super().__init__()
+
+ # Load frozen OLMo2-1B
+ olmo = AutoModelForCausalLM.from_pretrained(
+ olmo_model_id,
+ torch_dtype=torch.bfloat16,
+ )
+ olmo.eval()
+ for p in olmo.parameters():
+ p.requires_grad_(False)
+
+ # Wrap OLMo with DAGFormer modification
+ self.olmo_wrapper = DAGFormerOLMo(model=olmo, input_norm=input_norm)
+
+ # Structure predictor (Qwen encoder + MLP decoder)
+ self.predictor = StructurePredictor(
+ qwen_model_id=qwen_model_id,
+ hidden_dim=predictor_hidden_dim,
+ rank=predictor_rank,
+ cascading_gate_k=cascading_gate_k,
+ qwen_input_prefix=qwen_input_prefix,
+ device=device,
+ )
+
+ self.vocab_size = olmo.config.vocab_size
+
+ if device is not None:
+ self.to(device)
+
+ def forward(
+ self,
+ raw_texts: list[str],
+ olmo_ids: torch.Tensor,
+ olmo_labels: torch.Tensor,
+ tau: float,
+ lambda_sparsity: float = 0.0,
+ mode: str = "train",
+ ) -> dict[str, torch.Tensor]:
+ """Full forward pass: text → A → logits → loss.
+
+ Args:
+ raw_texts: list of raw text strings (batch)
+ olmo_ids: [batch, seq_len] — OLMo tokenized input
+            olmo_labels: [batch, seq_len] — pre-shifted labels for NLL
+ tau: Gumbel-Sigmoid temperature
+ lambda_sparsity: sparsity coefficient (λ_t)
+ mode: "train", "eval_soft", or "eval_hard"
+
+ Returns:
+ dict with keys:
+ "total_loss": nll + lambda * mean(A) — what the optimizer sees
+ "nll": cross-entropy loss
+ "sparsity_loss": lambda * mean(A)
+ "A": [batch, 256, 256] adjacency matrix
+ """
+ # Step 1: Predict adjacency matrix
+ A = self.predictor(raw_texts, tau=tau, mode=mode)
+ # A: [batch, 256, 256]
+
+ # Step 2: Modified OLMo forward with A
+ logits = self.olmo_wrapper(olmo_ids, A)
+ # logits: [batch, seq_len, vocab_size]
+
+        # Step 3: Compute NLL (next-token prediction)
+        # olmo_labels come pre-shifted from the dataloader (labels[t] is the
+        # token after olmo_ids[t]), so logits[t] lines up with labels[t]
+        # directly; shifting again here would train a skip-one predictor.
+        nll = F.cross_entropy(
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),
+        )
+
+ # Step 4: Sparsity regularization
+ sparsity_loss = lambda_sparsity * A.mean()
+ total_loss = nll + sparsity_loss
+
+ return {
+ "total_loss": total_loss,
+ "nll": nll,
+ "sparsity_loss": sparsity_loss,
+ "A": A,
+ }
+
+ def forward_baseline(
+ self,
+ olmo_ids: torch.Tensor,
+ olmo_labels: torch.Tensor,
+ ) -> torch.Tensor:
+ """Forward with A=all-ones (baseline reproduction).
+
+ Used for eval/nll_baseline metric.
+ """
+ batch = olmo_ids.shape[0]
+ A = create_all_ones_A(batch).to(olmo_ids.device)
+ with torch.no_grad():
+ logits = self.olmo_wrapper(olmo_ids, A)
+        nll = F.cross_entropy(
+            logits.contiguous().view(-1, self.vocab_size),
+            olmo_labels.contiguous().view(-1),  # pre-shifted labels, no extra shift
+        )
+ return nll
+
+ def get_trainable_parameters(self) -> list[nn.Parameter]:
+ """Return only the trainable parameters (predictor MLP + any norm params)."""
+ params = list(self.predictor.get_trainable_parameters())
+ # Also include input normalizer params if they exist
+ params.extend(self.olmo_wrapper.input_normalizer.parameters())
+ return params
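+
+
+# Smoke-test sketch (illustrative only): runs the full text → A → logits → NLL
+# path on random token IDs. Assumes both frozen checkpoints can be downloaded;
+# CPU bfloat16 is slow but should work on recent torch.
+def _demo_pipeline_usage() -> None:
+    pipe = DAGFormerPipeline(device=torch.device("cpu"))
+    ids = torch.randint(0, pipe.vocab_size, (1, 17))
+    out = pipe(
+        raw_texts=["The cat sat on the mat."],
+        olmo_ids=ids[:, :-1],
+        olmo_labels=ids[:, 1:],   # pre-shifted, matching the dataloader
+        tau=1.0,
+        mode="eval_soft",
+    )
+    print(f"nll={out['nll'].item():.3f} mean_A={out['A'].mean().item():.3f}")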
diff --git a/src/model/predictor.py b/src/model/predictor.py
new file mode 100644
index 0000000..0bc0ae3
--- /dev/null
+++ b/src/model/predictor.py
@@ -0,0 +1,275 @@
+"""Structure predictor: Qwen encoder + MLP decoder + Gumbel-Sigmoid + cascading gate.
+
+Takes raw text, produces a 256x256 adjacency matrix A controlling per-head
+routing in OLMo2-1B. See CLAUDE.md §2.3 for full specification.
+
+Components:
+- QwenEncoder: frozen Qwen3-Embedding-0.6B, mean-pooled to single vector
+- PredictorMLP: trainable MLP with low-rank output heads (U, V → Z = UV^T)
+- Gumbel-Sigmoid: differentiable relaxation of binary gates (3 modes)
+- Cascading gate: kill outgoing edges from disconnected nodes
+- Block-upper-triangular mask: enforce DAG constraint (layer(j) > layer(i))
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from transformers import AutoModel, AutoTokenizer
+
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+class QwenEncoder(nn.Module):
+ """Frozen Qwen3-Embedding-0.6B encoder.
+
+ Produces a single fixed-size vector per sequence via mean pooling.
+ Uses its OWN tokenizer (separate from OLMo's).
+ """
+
+ def __init__(self, model_id: str = "Qwen/Qwen3-Embedding-0.6B", device: Optional[torch.device] = None):
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
+ self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True)
+ self.model.eval()
+ for p in self.model.parameters():
+ p.requires_grad_(False)
+
+ self.embed_dim: int = self.model.config.hidden_size # 1024 for Qwen3-Embedding-0.6B
+
+ if device is not None:
+ self.model = self.model.to(device)
+
+ def encode(self, raw_texts: list[str], prefix: str = "") -> torch.Tensor:
+ """Encode raw text strings to pooled embeddings.
+
+ Args:
+ raw_texts: list of raw text strings (one per sequence in batch)
+ prefix: optional prefix for Qwen input (default: "" — no prefix)
+
+ Returns:
+ pooled: [batch, embed_dim] — mean-pooled embedding per sequence
+ """
+ if prefix:
+ raw_texts = [prefix + t for t in raw_texts]
+
+ device = next(self.model.parameters()).device
+ inputs = self.tokenizer(
+ raw_texts,
+ padding=True,
+ truncation=True,
+ max_length=8192,
+ return_tensors="pt",
+ ).to(device)
+
+ with torch.no_grad():
+ outputs = self.model(**inputs)
+
+ # Mean pooling over sequence dimension (masking padding tokens)
+        last_hidden = outputs.last_hidden_state  # [B, S, embed_dim]
+        # Cast the mask to float: clamp() with a float bound on an integer
+        # tensor raises, and the elementwise product needs a float mask anyway.
+        attention_mask = inputs["attention_mask"].unsqueeze(-1).to(last_hidden.dtype)  # [B, S, 1]
+        pooled = (last_hidden * attention_mask).sum(dim=1) / attention_mask.sum(dim=1).clamp(min=1e-8)
+ # pooled: [B, embed_dim]
+
+ return pooled
+
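+
+# Pooling math sketch (illustrative values): masked positions contribute
+# nothing to the mean, so a garbage padding row cannot move the pooled vector.
+def _demo_mean_pooling() -> None:
+    hidden = torch.ones(1, 3, 4)
+    hidden[0, 2] = 99.0                           # padding position
+    mask = torch.tensor([[1.0, 1.0, 0.0]]).unsqueeze(-1)
+    pooled = (hidden * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-8)
+    assert torch.allclose(pooled, torch.ones(1, 4))
+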
+
+class PredictorMLP(nn.Module):
+ """Trainable MLP decoder with low-rank output heads.
+
+ Maps Qwen embedding → logit matrix Z = UV^T ∈ R^{256×256}.
+ See CLAUDE.md §2.3 for architecture.
+ """
+
+ def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256):
+ super().__init__()
+ self.rank = rank
+ self.num_nodes = num_nodes
+
+ self.trunk = nn.Sequential(
+ nn.Linear(input_dim, hidden_dim),
+ nn.GELU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.GELU(),
+ )
+ self.head_U = nn.Linear(hidden_dim, num_nodes * rank)
+ self.head_V = nn.Linear(hidden_dim, num_nodes * rank)
+
+ def forward(self, e: torch.Tensor) -> torch.Tensor:
+ """Map embedding to logit matrix.
+
+ Args:
+ e: [batch, input_dim] — pooled Qwen embedding
+
+ Returns:
+ Z: [batch, 256, 256] — raw logit matrix (before mask/Gumbel)
+ """
+ h = self.trunk(e) # [batch, hidden_dim]
+ U = self.head_U(h).view(-1, self.num_nodes, self.rank) # [B, 256, r]
+ V = self.head_V(h).view(-1, self.num_nodes, self.rank) # [B, 256, r]
+ Z = torch.bmm(U, V.transpose(-1, -2)) # [B, 256, 256]
+ return Z
+
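+
+# Shape/parameter sketch for the low-rank head (illustrative sizes): a dense
+# Linear(hidden_dim, 256*256) would need hidden_dim·65536 weights, while the
+# two factor heads need only 2·hidden_dim·256·rank.
+def _demo_low_rank_head() -> None:
+    B, n, r = 2, 256, 32
+    U, V = torch.randn(B, n, r), torch.randn(B, n, r)
+    Z = torch.bmm(U, V.transpose(-1, -2))         # [B, n, n] from 2·B·n·r numbers
+    assert Z.shape == (B, n, n)
+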
+
+def gumbel_sigmoid(
+ Z_masked: torch.Tensor,
+ tau: float,
+ mode: str = "train",
+) -> torch.Tensor:
+ """Apply Gumbel-Sigmoid relaxation to masked logits.
+
+ Three modes (CLAUDE.md §2.3):
+ - "train": Gumbel noise + temperature → differentiable continuous relaxation
+ - "eval_soft": σ(Z/τ) — deterministic soft gates, no noise
+ - "eval_hard": (Z > 0).float() — deterministic binary 0/1
+
+ Args:
+ Z_masked: [batch, 256, 256] — logits with invalid positions at -1e9
+ tau: temperature (τ > 0 for train/eval_soft)
+ mode: one of "train", "eval_soft", "eval_hard"
+
+ Returns:
+ A: [batch, 256, 256] — gate values in [0, 1] (or {0, 1} for hard mode)
+ """
+ if mode == "train":
+ # Sample from Logistic(0, 1): G = log(U) - log(1-U), U ~ Uniform(0,1)
+ U = torch.rand_like(Z_masked).clamp(1e-8, 1 - 1e-8)
+ G = torch.log(U) - torch.log(1 - U)
+ return torch.sigmoid((Z_masked + G) / tau)
+ elif mode == "eval_soft":
+ return torch.sigmoid(Z_masked / tau)
+ elif mode == "eval_hard":
+ return (Z_masked > 0).float()
+ else:
+ raise ValueError(f"Unknown Gumbel-Sigmoid mode: {mode}. Expected: train, eval_soft, eval_hard")
+
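+
+# Behavior sketch (illustrative): at low temperature, train-mode samples are
+# near-binary with P(gate ≈ 1) = σ(z), so their mean approaches the hard gates.
+def _demo_gumbel_sigmoid() -> None:
+    torch.manual_seed(0)
+    z = torch.tensor([[-3.0, 3.0]]).expand(10_000, 2)
+    soft = gumbel_sigmoid(z, tau=0.1, mode="train").mean(0)
+    hard = gumbel_sigmoid(z, tau=0.1, mode="eval_hard")[0]
+    assert torch.allclose(soft, hard, atol=0.1)   # ≈ [0.05, 0.95] vs [0, 1]
+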
+
+def cascading_gate(
+ A: torch.Tensor,
+ k: float = 5.0,
+ hard: bool = False,
+) -> torch.Tensor:
+ """Apply cascading activation gate: kill outgoing edges from disconnected nodes.
+
+ One-pass computation (not layer-by-layer):
+ 1. Compute incoming sums: inc_j = Σ_i A[i, j]
+ 2. Compute gates: g_j = σ(k * inc_j) (soft) or (inc_j > 0) (hard)
+ 3. Apply: A[j, :] *= g_j
+
+ Uses ORIGINAL A values for incoming sums (before any gates applied).
+ See CLAUDE.md §2.3 cascading gate section.
+
+ Args:
+ A: [batch, 256, 256] — gate matrix
+ k: steepness of sigmoid gate (default: 5.0)
+ hard: if True, use binary gates (for eval_hard mode)
+
+ Returns:
+ A_gated: [batch, 256, 256] — A with cascading gate applied
+ """
+ # Incoming sum per node: [batch, 256]
+ inc = A.sum(dim=1) # sum over source dimension (rows)
+
+ if hard:
+ g = (inc > 0).float() # [batch, 256]
+ else:
+ g = torch.sigmoid(k * inc) # [batch, 256]
+
+ # Gate outgoing edges: A[j, :] *= g[j]
+ # g: [B, 256] → [B, 256, 1] to broadcast with A: [B, 256, 256]
+ return A * g.unsqueeze(2)
+
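+
+# Worked toy example (illustrative graph): node 0 has no incoming edges, so
+# its outgoing row is damped to σ(k·0) = 0.5 in soft mode — only attenuated,
+# not fully killed — and zeroed exactly in hard mode. The one-pass rule means
+# downstream gates still see the ORIGINAL incoming sums.
+def _demo_cascading_gate() -> None:
+    A = torch.zeros(1, 3, 3)
+    A[0, 0, 1] = 1.0                              # node 0 → node 1
+    A[0, 1, 2] = 1.0                              # node 1 → node 2
+    soft = cascading_gate(A, k=5.0)
+    assert abs(soft[0, 0, 1].item() - 0.5) < 1e-6     # inc = 0 → σ(0) = 0.5
+    assert abs(soft[0, 1, 2].item() - 0.9933) < 1e-3  # inc = 1 → σ(5)
+    hard = cascading_gate(A, hard=True)
+    # One-pass: node 1's gate uses the original incoming sum, so 1 → 2
+    # survives even though 0 → 1 was just zeroed.
+    assert hard[0, 0, 1].item() == 0.0 and hard[0, 1, 2].item() == 1.0
+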
+
+class StructurePredictor(nn.Module):
+ """Full structure predictor: raw text → adjacency matrix A.
+
+ Pipeline: raw_text → [Qwen encoder] → e → [MLP] → Z → [mask] → [Gumbel] → [cascade] → A
+
+ The only trainable component is the PredictorMLP. Qwen is frozen.
+ """
+
+ def __init__(
+ self,
+ qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B",
+ hidden_dim: int = 1024,
+ rank: int = 32,
+ cascading_gate_k: float = 5.0,
+ qwen_input_prefix: str = "",
+ num_nodes: int = 256,
+ heads_per_layer: int = 16,
+ device: Optional[torch.device] = None,
+ ):
+ super().__init__()
+ self.cascading_gate_k = cascading_gate_k
+ self.qwen_input_prefix = qwen_input_prefix
+ self.num_nodes = num_nodes
+ self.heads_per_layer = heads_per_layer
+
+ # Frozen Qwen encoder
+ self.qwen_encoder = QwenEncoder(model_id=qwen_model_id, device=device)
+
+ # Trainable MLP decoder
+ self.mlp = PredictorMLP(
+ input_dim=self.qwen_encoder.embed_dim,
+ hidden_dim=hidden_dim,
+ rank=rank,
+ num_nodes=num_nodes,
+ )
+
+ # Block-upper-triangular mask (registered as buffer — moves with .to(device))
+ self.register_buffer(
+ 'dag_mask',
+ create_block_upper_triangular_mask(num_nodes, heads_per_layer),
+ )
+
+ # Move all components to device (buffers + trainable MLP)
+ if device is not None:
+ self.to(device)
+
+ def forward(
+ self,
+ raw_texts: list[str],
+ tau: float,
+ mode: str = "train",
+ ) -> torch.Tensor:
+ """Predict adjacency matrix A from raw text.
+
+ Args:
+ raw_texts: list of raw text strings (batch)
+ tau: Gumbel-Sigmoid temperature
+ mode: "train", "eval_soft", or "eval_hard"
+
+ Returns:
+ A: [batch, 256, 256] — block-upper-triangular gate matrix
+ """
+ # Step 1: Qwen encoding (frozen, no grad)
+ e = self.qwen_encoder.encode(raw_texts, prefix=self.qwen_input_prefix)
+ # e: [batch, qwen_embed_dim]
+
+ # Step 2: MLP decoder → logits
+ Z = self.mlp(e) # [batch, 256, 256]
+ assert Z.shape[1:] == (self.num_nodes, self.num_nodes), \
+ f"Z shape mismatch: expected (*, {self.num_nodes}, {self.num_nodes}), got {Z.shape}"
+
+ # Step 3: Apply block-upper-triangular mask
+        # Force invalid positions to a large negative logit (-1e9) so sigmoid → 0
+ mask = self.dag_mask # [256, 256]
+ Z_masked = Z * mask + (-1e9) * (1 - mask)
+
+ # Step 4: Gumbel-Sigmoid
+ hard = (mode == "eval_hard")
+ A = gumbel_sigmoid(Z_masked, tau=tau, mode=mode)
+
+ # Step 5: Cascading activation gate
+ A = cascading_gate(A, k=self.cascading_gate_k, hard=hard)
+
+ assert A.shape[1:] == (self.num_nodes, self.num_nodes), \
+ f"A shape mismatch: expected (*, {self.num_nodes}, {self.num_nodes}), got {A.shape}"
+
+ return A
+
+ def get_trainable_parameters(self) -> list[nn.Parameter]:
+ """Return only the trainable MLP parameters (not Qwen)."""
+ return list(self.mlp.parameters())
diff --git a/src/training/__init__.py b/src/training/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/training/__init__.py
diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py
new file mode 100644
index 0000000..9ff02df
--- /dev/null
+++ b/src/training/checkpointing.py
@@ -0,0 +1,92 @@
+"""Checkpoint save/load for predictor + optimizer + schedule state.
+
+Only saves trainable components (predictor MLP, optimizer, schedule state).
+Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace.
+"""
+
+from __future__ import annotations
+
+import os
+from typing import Any, Optional
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+
+def save_checkpoint(
+ save_dir: str,
+ step: int,
+ predictor: nn.Module,
+ optimizer: optim.Optimizer,
+ scheduler: Any,
+ best_eval_nll: float,
+ extra: Optional[dict] = None,
+) -> str:
+ """Save training checkpoint.
+
+ Args:
+ save_dir: directory to save checkpoint
+ step: current global step
+ predictor: the structure predictor (only MLP params are saved)
+ optimizer: AdamW optimizer
+ scheduler: LR scheduler
+ best_eval_nll: best eval NLL so far
+ extra: any additional state to save
+
+ Returns:
+ path: path to saved checkpoint
+ """
+ os.makedirs(save_dir, exist_ok=True)
+ path = os.path.join(save_dir, f"checkpoint_step{step}.pt")
+
+ state = {
+ "step": step,
+ "predictor_state_dict": predictor.state_dict(),
+ "optimizer_state_dict": optimizer.state_dict(),
+ "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None,
+ "best_eval_nll": best_eval_nll,
+ }
+ if extra:
+ state.update(extra)
+
+ torch.save(state, path)
+ print(f"Checkpoint saved: {path}")
+ return path
+
+
+def load_checkpoint(
+ path: str,
+ predictor: nn.Module,
+ optimizer: Optional[optim.Optimizer] = None,
+ scheduler: Optional[Any] = None,
+ device: Optional[torch.device] = None,
+) -> dict:
+ """Load training checkpoint.
+
+ Args:
+ path: path to checkpoint file
+ predictor: structure predictor to load weights into
+ optimizer: optimizer to restore state (optional — skip for eval)
+ scheduler: LR scheduler to restore state (optional)
+ device: device to map tensors to
+
+ Returns:
+ state dict with step, best_eval_nll, and any extras
+ """
+ map_location = device if device is not None else "cpu"
+ state = torch.load(path, map_location=map_location)
+
+ predictor.load_state_dict(state["predictor_state_dict"])
+ print(f"Predictor state loaded from {path}")
+
+ if optimizer is not None and "optimizer_state_dict" in state:
+ optimizer.load_state_dict(state["optimizer_state_dict"])
+
+ if scheduler is not None and state.get("scheduler_state_dict") is not None:
+ scheduler.load_state_dict(state["scheduler_state_dict"])
+
+ return {
+ "step": state["step"],
+ "best_eval_nll": state.get("best_eval_nll", float("inf")),
+ }
diff --git a/src/training/schedulers.py b/src/training/schedulers.py
new file mode 100644
index 0000000..7cda3b4
--- /dev/null
+++ b/src/training/schedulers.py
@@ -0,0 +1,35 @@
+"""Schedule functions for temperature (τ), sparsity (λ), and learning rate.
+
+All schedules are deterministic functions of the current step.
+See CLAUDE.md §3.1 for exact formulas.
+"""
+
+from __future__ import annotations
+
+import math
+
+
+def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float:
+ """Cosine annealing for Gumbel-Sigmoid temperature.
+
+ τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T))
+
+ Starts at tau_init, ends at tau_final.
+ """
+ if total_steps <= 0:
+ return tau_final
+ progress = min(step / total_steps, 1.0)
+ return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * progress))
+
+
+def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float:
+ """Linear ramp for sparsity coefficient.
+
+ Ramps linearly from 0 to lambda_max over first warmup_frac of training.
+ """
+ if lambda_max == 0.0:
+ return 0.0
+ warmup_steps = int(total_steps * warmup_frac)
+ if warmup_steps <= 0:
+ return lambda_max
+ return lambda_max * min(step / warmup_steps, 1.0)
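+
+
+# Endpoint sketch (illustrative numbers): tau anneals tau_init → tau_final
+# over T steps; lambda ramps linearly to lambda_max within the warmup fraction.
+def _demo_schedules() -> None:
+    assert abs(tau_schedule(0, 1000, 5.0, 0.2) - 5.0) < 1e-9      # start
+    assert abs(tau_schedule(500, 1000, 5.0, 0.2) - 2.6) < 1e-9    # midpoint
+    assert abs(tau_schedule(1000, 1000, 5.0, 0.2) - 0.2) < 1e-9   # end
+    assert abs(lambda_schedule(100, 1000, 0.01) - 0.005) < 1e-12  # mid-warmup
+    assert abs(lambda_schedule(500, 1000, 0.01) - 0.01) < 1e-12   # post-warmup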
diff --git a/src/training/trainer.py b/src/training/trainer.py
new file mode 100644
index 0000000..6be949e
--- /dev/null
+++ b/src/training/trainer.py
@@ -0,0 +1,465 @@
+"""Training loop for DAGFormer Phase 1.
+
+Pure PyTorch + DDP. Only the predictor MLP is trainable.
+See CLAUDE.md §3.1 for training specification.
+"""
+
+from __future__ import annotations
+
+import os
+import warnings
+from dataclasses import dataclass
+from typing import Any
+
+import torch
+import torch.distributed as dist
+import torch.nn.functional as F
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.optim.lr_scheduler import CosineAnnealingLR
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from src.data.dolma import build_eval_dataloader, build_train_dataloader
+from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
+from src.model.predictor import StructurePredictor
+from src.training.checkpointing import load_checkpoint, save_checkpoint
+from src.training.schedulers import lambda_schedule, tau_schedule
+from src.utils.logging import finish_wandb, init_wandb, log_metrics
+from src.utils.topology import compute_topology_metrics
+
+
+@dataclass
+class TrainConfig:
+ """Training configuration. Parsed from YAML."""
+
+ # Model
+ olmo_model_id: str = "allenai/OLMo-2-0425-1B"
+ qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B"
+
+ # Predictor
+ predictor_hidden_dim: int = 1024
+ predictor_rank: int = 32
+ cascading_gate_k: float = 5.0
+ input_norm: str = "none"
+ qwen_input_prefix: str = ""
+
+ # Data
+ dataset: str = "allenai/dolma"
+ dataset_name: str = "v1_7"
+ seq_len: int = 1024
+ batch_size: int = 4
+ micro_batch_size: int = 4
+
+ # Eval
+ eval_skip: int = 1_000_000
+ eval_size: int = 1_000
+
+ # Training
+ total_steps: int = 1000
+ lr: float = 3e-4
+ weight_decay: float = 0.01
+ optimizer: str = "adamw"
+
+ # Schedules
+ tau_init: float = 5.0
+ tau_final: float = 0.2
+ tau_schedule: str = "cosine"
+ lambda_max: float = 0.0
+ lambda_warmup_frac: float = 0.2
+
+ # Logging
+ wandb_project: str = "dagformer"
+ wandb_run_name: str = "default"
+ log_every: int = 10
+ eval_every: int = 100
+
+ # Checkpointing
+ save_every: int = 500
+ save_dir: str = "checkpoints/"
+ resume_from: str = ""
+
+ # Hardware
+ num_gpus: int = 1
+
+ @classmethod
+ def from_yaml(cls, path: str) -> TrainConfig:
+ import yaml
+ with open(path) as f:
+ data = yaml.safe_load(f)
+
+ known_keys = {f.name for f in cls.__dataclass_fields__.values()}
+ unknown = set(data.keys()) - known_keys
+ if unknown:
+ raise ValueError(f"Unknown config keys: {unknown}")
+
+ # Coerce types to match dataclass field annotations
+ import dataclasses
+ for f in dataclasses.fields(cls):
+ if f.name in data:
+ expected_type = f.type
+ if expected_type == "float" or expected_type is float:
+ data[f.name] = float(data[f.name])
+ elif expected_type == "int" or expected_type is int:
+ data[f.name] = int(data[f.name])
+
+ return cls(**data)
+
+ def to_dict(self) -> dict[str, Any]:
+ from dataclasses import asdict
+ return asdict(self)
+
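+
+# Round-trip sketch for TrainConfig.from_yaml (hypothetical keys/values;
+# unknown keys raise, and scalars are coerced to the annotated types):
+def _demo_config_yaml() -> None:
+    import tempfile
+    import textwrap
+    yaml_text = textwrap.dedent(
+        """\
+        total_steps: 2000
+        lr: 3.0e-4
+        lambda_max: 0.01
+        wandb_run_name: demo
+        """
+    )
+    with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
+        f.write(yaml_text)
+    cfg = TrainConfig.from_yaml(f.name)
+    assert cfg.total_steps == 2000 and abs(cfg.lr - 3e-4) < 1e-12
+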
+
+class Trainer:
+ """DAGFormer Phase 1 training loop."""
+
+ def __init__(self, config: TrainConfig, local_rank: int = 0, world_size: int = 1):
+ self.config = config
+ self.local_rank = local_rank
+ self.world_size = world_size
+ self.is_main = (local_rank == 0)
+ self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu")
+
+ # Gradient accumulation
+ assert config.batch_size % config.micro_batch_size == 0, \
+ f"batch_size ({config.batch_size}) must be divisible by micro_batch_size ({config.micro_batch_size})"
+ self.accum_steps = config.batch_size // config.micro_batch_size
+
+ self._build_models()
+ self._build_optimizer()
+ self._build_data()
+ self._setup_logging()
+
+ self.global_step = 0
+ self.best_eval_nll = float("inf")
+        self.collapse_counter = 0  # consecutive logged steps with collapsed A
+
+ # Resume from checkpoint if specified
+ if config.resume_from:
+ state = load_checkpoint(
+ config.resume_from,
+ self.predictor,
+ self.optimizer,
+ self.lr_scheduler,
+ device=self.device,
+ )
+ self.global_step = state["step"]
+ self.best_eval_nll = state["best_eval_nll"]
+ if self.is_main:
+ print(f"Resumed from step {self.global_step}")
+
+ def _build_models(self) -> None:
+ config = self.config
+
+ # Load frozen OLMo2-1B
+ if self.is_main:
+ print(f"Loading {config.olmo_model_id}...")
+ self.olmo = AutoModelForCausalLM.from_pretrained(
+ config.olmo_model_id,
+ torch_dtype=torch.bfloat16,
+ ).to(self.device)
+ self.olmo.eval()
+ for p in self.olmo.parameters():
+ p.requires_grad_(False)
+
+ # Verify frozen
+ assert all(not p.requires_grad for p in self.olmo.parameters()), \
+ "OLMo parameters should be frozen"
+
+ # OLMo tokenizer
+ self.olmo_tokenizer = AutoTokenizer.from_pretrained(config.olmo_model_id)
+
+ # DAGFormer OLMo wrapper
+ self.olmo_wrapper = DAGFormerOLMo(
+ model=self.olmo,
+ input_norm=config.input_norm,
+ ).to(self.device)
+
+ # Structure predictor (includes frozen Qwen + trainable MLP)
+ if self.is_main:
+ print(f"Loading {config.qwen_model_id}...")
+ self.predictor = StructurePredictor(
+ qwen_model_id=config.qwen_model_id,
+ hidden_dim=config.predictor_hidden_dim,
+ rank=config.predictor_rank,
+ cascading_gate_k=config.cascading_gate_k,
+ qwen_input_prefix=config.qwen_input_prefix,
+ device=self.device,
+ )
+
+ # DDP wrapping — only the predictor MLP (trainable component)
+ if self.world_size > 1:
+ self.predictor.mlp = DDP(
+ self.predictor.mlp,
+ device_ids=[self.local_rank],
+ )
+
+ if self.is_main:
+ trainable = sum(p.numel() for p in self.predictor.get_trainable_parameters())
+ norm_params = sum(p.numel() for p in self.olmo_wrapper.input_normalizer.parameters())
+ print(f"Trainable params: predictor={trainable:,}, norm={norm_params:,}")
+
+ def _build_optimizer(self) -> None:
+ config = self.config
+
+ # Collect all trainable parameters
+ params = list(self.predictor.get_trainable_parameters())
+ params.extend(self.olmo_wrapper.input_normalizer.parameters())
+
+ assert config.optimizer == "adamw", f"Only adamw supported, got {config.optimizer}"
+ self.optimizer = torch.optim.AdamW(
+ params,
+ lr=config.lr,
+ betas=(0.9, 0.999),
+ weight_decay=config.weight_decay,
+ )
+ self.lr_scheduler = CosineAnnealingLR(
+ self.optimizer,
+ T_max=config.total_steps,
+ eta_min=0.0,
+ )
+
+ def _build_data(self) -> None:
+ config = self.config
+
+ self.train_loader = build_train_dataloader(
+ olmo_tokenizer=self.olmo_tokenizer,
+ seq_len=config.seq_len,
+ batch_size=config.micro_batch_size,
+ dataset_name=config.dataset,
+ dataset_version=config.dataset_name,
+ rank=self.local_rank,
+ world_size=self.world_size,
+ )
+
+ # Eval data: only on main rank
+ if self.is_main:
+ cache_path = os.path.join(config.save_dir, "eval_cache.pt")
+ self.eval_batches = build_eval_dataloader(
+ olmo_tokenizer=self.olmo_tokenizer,
+ seq_len=config.seq_len,
+ batch_size=config.micro_batch_size,
+ dataset_name=config.dataset,
+ dataset_version=config.dataset_name,
+ eval_skip=config.eval_skip,
+ eval_size=config.eval_size,
+ cache_path=cache_path,
+ )
+ else:
+ self.eval_batches = []
+
+ def _setup_logging(self) -> None:
+ if self.is_main:
+ self.wandb_run = init_wandb(
+ project=self.config.wandb_project,
+ run_name=self.config.wandb_run_name,
+ config=self.config.to_dict(),
+ )
+ else:
+ self.wandb_run = None
+
+ def train(self) -> None:
+ """Main training loop."""
+ config = self.config
+ train_iter = iter(self.train_loader)
+
+ if self.is_main:
+ print(f"\nStarting training: {config.total_steps} steps")
+ print(f" batch_size={config.batch_size}, micro_batch={config.micro_batch_size}, accum={self.accum_steps}")
+ print(f" tau: {config.tau_init} → {config.tau_final}")
+ print(f" lambda: 0 → {config.lambda_max}")
+ print()
+
+ while self.global_step < config.total_steps:
+ # Schedule values
+ tau = tau_schedule(self.global_step, config.total_steps, config.tau_init, config.tau_final)
+ lam = lambda_schedule(self.global_step, config.total_steps, config.lambda_max, config.lambda_warmup_frac)
+
+ # Gradient accumulation
+ self.optimizer.zero_grad()
+ total_nll = 0.0
+ total_sparsity = 0.0
+ total_mean_A = 0.0
+
+ for micro_step in range(self.accum_steps):
+ try:
+ batch = next(train_iter)
+ except StopIteration:
+ train_iter = iter(self.train_loader)
+ batch = next(train_iter)
+
+ olmo_ids = batch["olmo_ids"].to(self.device)
+ olmo_labels = batch["olmo_labels"].to(self.device)
+ raw_texts = batch["raw_text"]
+
+ # Forward: predictor → A → OLMo → loss
+ A = self.predictor(raw_texts, tau=tau, mode="train")
+ logits = self.olmo_wrapper(olmo_ids, A)
+
+                # NLL loss — olmo_labels are already shifted by the dataloader
+                # (labels[t] is the token after olmo_ids[t]), so no further
+                # shift here; shifting again would train a skip-one predictor.
+                nll = F.cross_entropy(
+                    logits.contiguous().view(-1, self.olmo.config.vocab_size),
+                    olmo_labels.contiguous().view(-1),
+                )
+
+ # Sparsity loss
+ sparsity = lam * A.mean()
+ loss = (nll + sparsity) / self.accum_steps
+
+ loss.backward()
+
+ total_nll += nll.item() / self.accum_steps
+ total_sparsity += sparsity.item() / self.accum_steps
+ total_mean_A += A.mean().item() / self.accum_steps
+
+ # Optimizer step
+ self.optimizer.step()
+ self.lr_scheduler.step()
+
+ # Logging
+ if self.is_main and self.global_step % config.log_every == 0:
+ # Gradient norm
+ grad_norm = 0.0
+ for p in self.predictor.get_trainable_parameters():
+ if p.grad is not None:
+ grad_norm += p.grad.data.norm(2).item() ** 2
+ for p in self.olmo_wrapper.input_normalizer.parameters():
+ if p.grad is not None:
+ grad_norm += p.grad.data.norm(2).item() ** 2
+ grad_norm = grad_norm ** 0.5
+
+ metrics = {
+ "train/nll": total_nll,
+ "train/sparsity_loss": total_sparsity,
+ "train/total_loss": total_nll + total_sparsity,
+ "topology/mean_A": total_mean_A,
+ "schedule/tau": tau,
+ "schedule/lambda": lam,
+ "grad/predictor_norm": grad_norm,
+ }
+ log_metrics(metrics, self.global_step, self.wandb_run)
+
+ # Collapse alarm
+ if total_mean_A < 0.01 or total_mean_A > 0.99:
+ self.collapse_counter += 1
+ if self.collapse_counter >= 100:
+ warnings.warn(
+ f"COLLAPSE ALARM: mean_A={total_mean_A:.4f} for {self.collapse_counter} steps"
+ )
+ else:
+ self.collapse_counter = 0
+
+ # Eval
+ if self.is_main and self.global_step > 0 and self.global_step % config.eval_every == 0:
+ self._run_eval(tau)
+
+ # Checkpoint
+ if self.is_main and self.global_step > 0 and self.global_step % config.save_every == 0:
+ save_checkpoint(
+ config.save_dir,
+ self.global_step,
+ self.predictor,
+ self.optimizer,
+ self.lr_scheduler,
+ self.best_eval_nll,
+ )
+
+ self.global_step += 1
+
+ # Barrier for multi-GPU sync
+ if self.world_size > 1:
+ dist.barrier()
+
+ # Final eval and checkpoint
+ if self.is_main:
+ self._run_eval(tau_schedule(config.total_steps, config.total_steps, config.tau_init, config.tau_final))
+ save_checkpoint(
+ config.save_dir,
+ self.global_step,
+ self.predictor,
+ self.optimizer,
+ self.lr_scheduler,
+ self.best_eval_nll,
+ )
+
+ finish_wandb(self.wandb_run)
+ if self.is_main:
+ print("\nTraining complete.")
+
+ @torch.no_grad()
+ def _run_eval(self, tau: float) -> None:
+ """Run evaluation on held-out data (rank 0 only).
+
+ Reports: eval/nll_soft, eval/nll_hard, eval/nll_baseline
+ """
+ if not self.eval_batches:
+ return
+
+ self.predictor.eval()
+
+ nll_soft_total = 0.0
+ nll_hard_total = 0.0
+ nll_baseline_total = 0.0
+ n_batches = 0
+ topology_metrics_accum: dict[str, float] = {}
+
+ for batch in self.eval_batches:
+ olmo_ids = batch["olmo_ids"].to(self.device)
+ olmo_labels = batch["olmo_labels"].to(self.device)
+ raw_texts = batch["raw_text"]
+
+ vocab_size = self.olmo.config.vocab_size
+
+ # Eval soft
+ A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft")
+ logits_soft = self.olmo_wrapper(olmo_ids, A_soft)
+            # labels are pre-shifted — compare logits[t] with labels[t] directly
+            nll_soft = F.cross_entropy(
+                logits_soft.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
+            )
+ nll_soft_total += nll_soft.item()
+
+ # Eval hard
+ A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard")
+ logits_hard = self.olmo_wrapper(olmo_ids, A_hard)
+            nll_hard = F.cross_entropy(
+                logits_hard.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
+            )
+ nll_hard_total += nll_hard.item()
+
+ # Baseline (A=1)
+ A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device)
+ logits_base = self.olmo_wrapper(olmo_ids, A_ones)
+            nll_base = F.cross_entropy(
+                logits_base.contiguous().view(-1, vocab_size),
+                olmo_labels.contiguous().view(-1),
+            )
+ nll_baseline_total += nll_base.item()
+
+ # Topology metrics (from soft A)
+ topo = compute_topology_metrics(A_soft)
+ for k, v in topo.items():
+ topology_metrics_accum[k] = topology_metrics_accum.get(k, 0.0) + v
+
+ n_batches += 1
+
+ # Average
+ metrics = {
+ "eval/nll_soft": nll_soft_total / n_batches,
+ "eval/nll_hard": nll_hard_total / n_batches,
+ "eval/nll_baseline": nll_baseline_total / n_batches,
+ }
+ for k, v in topology_metrics_accum.items():
+ metrics[k] = v / n_batches
+
+ log_metrics(metrics, self.global_step, self.wandb_run)
+
+ # Track best
+ eval_nll = metrics["eval/nll_soft"]
+ if eval_nll < self.best_eval_nll:
+ self.best_eval_nll = eval_nll
+ print(f" New best eval NLL: {eval_nll:.4f}")
+
+        # Restore train mode on the trainable MLP only; the frozen Qwen
+        # encoder must stay in eval mode (no dropout during encoding).
+        self.predictor.mlp.train()
diff --git a/src/utils/__init__.py b/src/utils/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/utils/__init__.py
diff --git a/src/utils/logging.py b/src/utils/logging.py
new file mode 100644
index 0000000..a9b3b10
--- /dev/null
+++ b/src/utils/logging.py
@@ -0,0 +1,63 @@
+"""Wandb integration for DAGFormer training.
+
+Logs all metrics from CLAUDE.md §7.
+"""
+
+from __future__ import annotations
+
+from typing import Any, Optional
+
+import wandb
+
+
+def init_wandb(
+ project: str,
+ run_name: str,
+ config: dict[str, Any],
+) -> Optional[wandb.sdk.wandb_run.Run]:
+ """Initialize wandb run.
+
+ Args:
+ project: wandb project name
+ run_name: run display name
+ config: full config dict to log
+
+ Returns:
+ wandb run object (or None if wandb init fails)
+ """
+ try:
+ run = wandb.init(
+ project=project,
+ name=run_name,
+ config=config,
+ )
+ return run
+ except Exception as e:
+ print(f"WARNING: wandb init failed: {e}. Continuing without wandb.")
+ return None
+
+
+def log_metrics(
+ metrics: dict[str, float],
+ step: int,
+ run: Optional[wandb.sdk.wandb_run.Run] = None,
+) -> None:
+ """Log metrics to wandb.
+
+ Args:
+ metrics: dict of metric_name → value
+ step: global training step
+ run: wandb run (if None, prints to stdout instead)
+ """
+ if run is not None:
+        run.log(metrics, step=step)
+ else:
+ # Fallback: print metrics
+ parts = [f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}" for k, v in metrics.items()]
+ print(f"[step {step}] {', '.join(parts)}")
+
+
+def finish_wandb(run: Optional[wandb.sdk.wandb_run.Run] = None) -> None:
+ """Finish wandb run."""
+ if run is not None:
+        run.finish()
diff --git a/src/utils/topology.py b/src/utils/topology.py
new file mode 100644
index 0000000..3a2e0a8
--- /dev/null
+++ b/src/utils/topology.py
@@ -0,0 +1,87 @@
+"""A matrix analysis utilities for logging and monitoring.
+
+Computes topology metrics per CLAUDE.md §7.
+"""
+
+from __future__ import annotations
+
+import torch
+
+from src.model.olmo_graph import create_block_upper_triangular_mask
+
+
+def compute_topology_metrics(A: torch.Tensor, num_heads: int = 16) -> dict[str, float]:
+ """Compute topology metrics from adjacency matrix A.
+
+ Args:
+ A: [batch, 256, 256] — gate matrix
+ num_heads: heads per layer (16 for OLMo2-1B)
+
+ Returns:
+ dict with metrics: mean_A, seq_gate_frac, hyp_gate_frac, jaccard_var
+ """
+ batch = A.shape[0]
+ num_nodes = A.shape[1]
+ mask = create_block_upper_triangular_mask(num_nodes, num_heads).to(A.device)
+
+ # Mean A over valid entries
+ valid_entries = A[:, mask.bool()] # [batch, 30720]
+ mean_A = valid_entries.mean().item()
+
+ # Classify connections: adjacent (layer diff == 1) vs skip (layer diff > 1)
+ layer_idx = torch.arange(num_nodes, device=A.device) // num_heads
+ layer_diff = layer_idx.unsqueeze(0) - layer_idx.unsqueeze(1) # [256, 256]
+ # layer_diff[i,j] = layer(j) - layer(i)
+
+ adj_mask = (layer_diff == 1) & mask.bool() # adjacent-layer connections
+ skip_mask = (layer_diff > 1) & mask.bool() # skip connections
+
+ # Fraction of gates > 0.5
+ adj_vals = A[:, adj_mask] # [batch, 3840]
+ skip_vals = A[:, skip_mask] # [batch, 26880]
+
+ seq_gate_frac = (adj_vals > 0.5).float().mean().item() if adj_vals.numel() > 0 else 0.0
+ hyp_gate_frac = (skip_vals > 0.5).float().mean().item() if skip_vals.numel() > 0 else 0.0
+
+ # Jaccard variance across batch
+ jaccard_var = _jaccard_variance(A, mask).item() if batch > 1 else 0.0
+
+ return {
+ "topology/mean_A": mean_A,
+ "topology/seq_gate_frac": seq_gate_frac,
+ "topology/hyp_gate_frac": hyp_gate_frac,
+ "topology/jaccard_var": jaccard_var,
+ }
+
+
+def _jaccard_variance(A: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+ """Compute variance of pairwise Jaccard similarity across batch.
+
+ Measures how context-dependent the topologies are.
+ Higher variance = more context-dependent routing.
+ """
+ batch = A.shape[0]
+ if batch < 2:
+ return torch.tensor(0.0)
+
+ # Binarize at 0.5 threshold for Jaccard
+ binary = (A > 0.5).float()
+ valid = mask.bool()
+
+ # Extract valid entries: [batch, num_valid]
+ entries = binary[:, valid]
+
+ # Pairwise Jaccard
+ jaccards = []
+ for i in range(batch):
+ for j in range(i + 1, batch):
+ intersection = (entries[i] * entries[j]).sum()
+ union = ((entries[i] + entries[j]) > 0).float().sum()
+ jaccard = intersection / union.clamp(min=1.0)
+ jaccards.append(jaccard)
+
+ if not jaccards:
+ return torch.tensor(0.0)
+
+ jaccards = torch.stack(jaccards)
+    # Population variance: with a single pair (batch of 2), the default
+    # unbiased estimator would return NaN.
+    return jaccards.var(unbiased=False)
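+
+
+# Toy check of the pairwise-Jaccard statistic (illustrative 8-node graph):
+# the two graphs share one edge and graph 0 has one extra, so Jaccard = 1/2;
+# a single pair gives zero variance.
+def _demo_jaccard() -> None:
+    mask = create_block_upper_triangular_mask(8, 4)
+    A = torch.zeros(2, 8, 8)
+    A[:, 0, 4] = 1.0          # edge present in both graphs
+    A[0, 1, 5] = 1.0          # edge only in graph 0
+    assert _jaccard_variance(A, mask).item() == 0.0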