From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 11:00:39 -0600 Subject: Initial implementation: DAGFormer Phase 1 - olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 --- src/__init__.py | 0 src/data/__init__.py | 0 src/data/dolma.py | 226 ++++++++++++++++++++ src/model/__init__.py | 0 src/model/olmo_graph.py | 397 ++++++++++++++++++++++++++++++++++++ src/model/pipeline.py | 144 +++++++++++++ src/model/predictor.py | 275 +++++++++++++++++++++++++ src/training/__init__.py | 0 src/training/checkpointing.py | 92 +++++++++ src/training/schedulers.py | 35 ++++ src/training/trainer.py | 465 ++++++++++++++++++++++++++++++++++++++++++ src/utils/__init__.py | 0 src/utils/logging.py | 63 ++++++ src/utils/topology.py | 87 ++++++++ 14 files changed, 1784 insertions(+) create mode 100644 src/__init__.py create mode 100644 src/data/__init__.py create mode 100644 src/data/dolma.py create mode 100644 src/model/__init__.py create mode 100644 src/model/olmo_graph.py create mode 100644 src/model/pipeline.py create mode 100644 src/model/predictor.py create mode 100644 src/training/__init__.py create mode 100644 src/training/checkpointing.py create mode 100644 src/training/schedulers.py create mode 100644 src/training/trainer.py create mode 100644 src/utils/__init__.py create mode 100644 src/utils/logging.py create mode 100644 src/utils/topology.py (limited to 'src') diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/data/dolma.py b/src/data/dolma.py new file mode 100644 index 0000000..4e2baaf --- /dev/null +++ b/src/data/dolma.py @@ -0,0 +1,226 @@ +"""Streaming dataloader for Dolma v1.7 with sequence packing. + +Produces packed sequences of fixed length for both OLMo and Qwen tokenizers. +See CLAUDE.md §3.1.1 for sequence packing specification. +""" + +from __future__ import annotations + +import os +from typing import Iterator, Optional + +import torch +from datasets import load_dataset +from torch.utils.data import IterableDataset +from transformers import AutoTokenizer + + +class DolmaPackedDataset(IterableDataset): + """Streaming Dolma dataset with sequence packing. + + Concatenates documents with EOS separators, then chunks into fixed-length + sequences. No padding — every token contributes to NLL. 
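+
+    The token buffer carries over across documents, so a packed sequence may span
+    several documents (separated by EOS); tokens left in the buffer when the stream
+    ends are dropped rather than padded.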
+ + Each sample yields: + olmo_ids: [seq_len] — OLMo input token IDs + olmo_labels: [seq_len] — shifted labels (next-token prediction) + raw_text: str — decoded text for Qwen encoder + """ + + def __init__( + self, + olmo_tokenizer: AutoTokenizer, + seq_len: int = 1024, + dataset_name: str = "allenai/dolma", + dataset_version: str = "v1_7", + rank: int = 0, + world_size: int = 1, + max_samples: Optional[int] = None, + ): + super().__init__() + self.olmo_tokenizer = olmo_tokenizer + self.seq_len = seq_len + self.dataset_name = dataset_name + self.dataset_version = dataset_version + self.rank = rank + self.world_size = world_size + self.max_samples = max_samples + + self.eos_id = olmo_tokenizer.eos_token_id + assert self.eos_id is not None, "OLMo tokenizer must have an EOS token" + + def __iter__(self) -> Iterator[dict]: + """Yield packed sequences from Dolma stream.""" + try: + dataset = load_dataset( + self.dataset_name, + name=self.dataset_version, + split="train", + streaming=True, + trust_remote_code=True, + ) + except Exception: + # Fallback if specific version not available + dataset = load_dataset( + self.dataset_name, + split="train", + streaming=True, + trust_remote_code=True, + ) + + # Shard for multi-GPU + if self.world_size > 1: + dataset = dataset.shard(num_shards=self.world_size, index=self.rank) + + buffer: list[int] = [] + sample_count = 0 + + for doc in dataset: + if self.max_samples is not None and sample_count >= self.max_samples: + break + + text = doc.get("text", "") + if not text.strip(): + continue + + tokens = self.olmo_tokenizer(text, add_special_tokens=False)["input_ids"] + buffer.extend(tokens) + buffer.append(self.eos_id) + + # Yield packed sequences as buffer fills + while len(buffer) >= self.seq_len + 1: + chunk = buffer[:self.seq_len + 1] + buffer = buffer[self.seq_len + 1:] + + olmo_ids = torch.tensor(chunk[:self.seq_len], dtype=torch.long) + olmo_labels = torch.tensor(chunk[1:self.seq_len + 1], dtype=torch.long) + raw_text = self.olmo_tokenizer.decode(chunk[:self.seq_len], skip_special_tokens=False) + + yield { + "olmo_ids": olmo_ids, + "olmo_labels": olmo_labels, + "raw_text": raw_text, + } + sample_count += 1 + + if self.max_samples is not None and sample_count >= self.max_samples: + break + + +def build_train_dataloader( + olmo_tokenizer: AutoTokenizer, + seq_len: int = 1024, + batch_size: int = 4, + dataset_name: str = "allenai/dolma", + dataset_version: str = "v1_7", + rank: int = 0, + world_size: int = 1, + num_workers: int = 0, +) -> torch.utils.data.DataLoader: + """Build training dataloader with sequence packing.""" + dataset = DolmaPackedDataset( + olmo_tokenizer=olmo_tokenizer, + seq_len=seq_len, + dataset_name=dataset_name, + dataset_version=dataset_version, + rank=rank, + world_size=world_size, + ) + return torch.utils.data.DataLoader( + dataset, + batch_size=batch_size, + num_workers=num_workers, + collate_fn=_collate_packed, + ) + + +def build_eval_dataloader( + olmo_tokenizer: AutoTokenizer, + seq_len: int = 1024, + batch_size: int = 4, + dataset_name: str = "allenai/dolma", + dataset_version: str = "v1_7", + eval_skip: int = 1_000_000, + eval_size: int = 1_000, + cache_path: Optional[str] = None, +) -> list[dict]: + """Build eval batches (cached in memory). + + Skips eval_skip examples in the stream, then takes eval_size packed sequences. + Caches to disk to avoid repeated skip on restart. 
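+
+    Returns a list of pre-collated batch dicts rather than a DataLoader, so the same
+    eval batches can be replayed on every eval without re-streaming Dolma.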
+ """ + # Try loading from cache + if cache_path and os.path.exists(cache_path): + print(f"Loading eval cache from {cache_path}") + return torch.load(cache_path) + + print(f"Building eval set (skip={eval_skip}, size={eval_size})...") + + try: + dataset = load_dataset( + dataset_name, + name=dataset_version, + split="train", + streaming=True, + trust_remote_code=True, + ) + except Exception: + dataset = load_dataset( + dataset_name, + split="train", + streaming=True, + trust_remote_code=True, + ) + + # Skip to held-out region + dataset = dataset.skip(eval_skip) + + eos_id = olmo_tokenizer.eos_token_id + buffer: list[int] = [] + eval_samples: list[dict] = [] + + for doc in dataset: + if len(eval_samples) >= eval_size: + break + + text = doc.get("text", "") + if not text.strip(): + continue + + tokens = olmo_tokenizer(text, add_special_tokens=False)["input_ids"] + buffer.extend(tokens) + buffer.append(eos_id) + + while len(buffer) >= seq_len + 1 and len(eval_samples) < eval_size: + chunk = buffer[:seq_len + 1] + buffer = buffer[seq_len + 1:] + eval_samples.append({ + "olmo_ids": torch.tensor(chunk[:seq_len], dtype=torch.long), + "olmo_labels": torch.tensor(chunk[1:seq_len + 1], dtype=torch.long), + "raw_text": olmo_tokenizer.decode(chunk[:seq_len], skip_special_tokens=False), + }) + + print(f"Built {len(eval_samples)} eval sequences") + + # Batch the samples + eval_batches = [] + for i in range(0, len(eval_samples), batch_size): + batch_items = eval_samples[i:i + batch_size] + eval_batches.append(_collate_packed(batch_items)) + + # Cache to disk + if cache_path: + os.makedirs(os.path.dirname(cache_path) or ".", exist_ok=True) + torch.save(eval_batches, cache_path) + print(f"Eval cache saved to {cache_path}") + + return eval_batches + + +def _collate_packed(batch: list[dict]) -> dict: + """Collate packed samples into a batch dict.""" + return { + "olmo_ids": torch.stack([s["olmo_ids"] for s in batch]), + "olmo_labels": torch.stack([s["olmo_labels"] for s in batch]), + "raw_text": [s["raw_text"] for s in batch], + } diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/model/olmo_graph.py b/src/model/olmo_graph.py new file mode 100644 index 0000000..af9f848 --- /dev/null +++ b/src/model/olmo_graph.py @@ -0,0 +1,397 @@ +"""Modified OLMo2-1B forward pass with adjacency matrix A injection. + +This module implements the core DAGFormer modification: per-head input +assembly controlled by a 256x256 adjacency matrix A. Each head receives +its own input (a gated combination of prior heads' outputs), rather than +the shared residual stream. + +Key design decisions: +- Uses proportional attribution for post_attention_layernorm decomposition + (OLMo2 is post-norm, not pre-norm as CLAUDE.md §2.1 assumes) +- Concatenate→q_norm→split pattern for per-head Q/K normalization +- Weight slices via .view() (not .clone()) for Phase 2 compatibility +- When A=all-ones and input_norm="none", output is identical to vanilla OLMo2 +""" + +from __future__ import annotations + +import math +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from transformers import AutoModelForCausalLM +from transformers.models.olmo2.modeling_olmo2 import ( + apply_rotary_pos_emb, +) + + +def create_block_upper_triangular_mask(num_nodes: int = 256, heads_per_layer: int = 16) -> torch.Tensor: + """Create block-upper-triangular mask based on LAYER indices. 
+ + mask[i,j] = 1 iff layer(j) > layer(i), i.e. j//16 > i//16. + Same-layer and backward connections are 0. + Do NOT use torch.triu() — it allows same-layer connections. + + Returns: + mask: [num_nodes, num_nodes] float tensor with 0s and 1s + """ + layer_idx = torch.arange(num_nodes) // heads_per_layer + mask = (layer_idx.unsqueeze(1) < layer_idx.unsqueeze(0)).float() # [256, 256] + return mask + + +class InputNormalizer(nn.Module): + """Normalization methods for gated head output sums (CLAUDE.md §6.1). + + Applied ONLY to the gated_sum component, not the base (embedding + MLPs). + """ + + def __init__(self, method: str, model_dim: int = 2048, num_nodes: int = 256): + super().__init__() + self.method = method + self.model_dim = model_dim + + if method == "none": + pass + elif method == "gate_mean": + pass # no learnable params + elif method == "rms_post": + self.norm = nn.RMSNorm(model_dim) + elif method == "ln_post": + self.norm = nn.LayerNorm(model_dim) + elif method == "rms_pre": + self.norms = nn.ModuleList([nn.RMSNorm(model_dim) for _ in range(num_nodes)]) + else: + raise ValueError(f"Unknown input_norm method: {method}") + + def forward( + self, + gated_sum: torch.Tensor, + A_slice: Optional[torch.Tensor] = None, + prior_head_outs: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + """Normalize the gated sum of prior head outputs. + + Args: + gated_sum: [batch, num_heads, seq, model_dim] — gated sum for this layer's heads + A_slice: [batch, num_prior_nodes, num_heads] — gate values (for gate_mean) + prior_head_outs: [batch, num_prior_nodes, seq, model_dim] — for rms_pre + Returns: + Normalized gated_sum, same shape + """ + if self.method == "none": + return gated_sum + + elif self.method == "gate_mean": + assert A_slice is not None + # Sum of gates per target head: [batch, num_heads] + gate_sum = A_slice.sum(dim=1) # [batch, num_heads] + # Divide gated_sum by gate_sum (avoid div by zero) + divisor = gate_sum.clamp(min=1e-8) # [batch, num_heads] + return gated_sum / divisor[:, :, None, None] # broadcast over [seq, model_dim] + + elif self.method == "rms_post": + return self.norm(gated_sum) + + elif self.method == "ln_post": + return self.norm(gated_sum) + + elif self.method == "rms_pre": + # Apply per-source-node RMSNorm before gating, then recompute gated sum + # This requires prior_head_outs and A_slice + assert prior_head_outs is not None and A_slice is not None + num_prior = prior_head_outs.shape[1] + # Normalize each source node's output + normed_sources = [] + for i in range(num_prior): + normed_sources.append(self.norms[i](prior_head_outs[:, i])) + normed_sources = torch.stack(normed_sources, dim=1) # [B, num_prior, S, D] + # Recompute gated sum with normed sources + return torch.einsum('bih,bisd->bhsd', A_slice, normed_sources) + + raise ValueError(f"Unknown method: {self.method}") + + +class DAGFormerOLMo(nn.Module): + """Wraps OLMo2-1B with adjacency matrix A injection for per-head routing. + + When A is all-ones and input_norm is "none", this produces output + identical to vanilla OLMo2-1B (baseline reproduction invariant). 
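+
+    Node indexing is layer-major: node k belongs to layer k // 16, with 16 head-nodes
+    per layer (256 nodes total). A[i, j] gates the contribution of source node i's
+    output to target node j's input; the block-upper-triangular mask keeps only edges
+    with layer(j) > layer(i).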
+ """ + + def __init__( + self, + model: AutoModelForCausalLM, + input_norm: str = "none", + num_layers: int = 16, + num_heads: int = 16, + ): + super().__init__() + self.olmo = model + self.num_layers = num_layers + self.num_heads = num_heads + self.num_nodes = num_layers * num_heads + self.model_dim = model.config.hidden_size + self.head_dim = self.model_dim // num_heads + self.rms_norm_eps = model.config.rms_norm_eps + + # Runtime assertions + assert model.config.num_attention_heads == num_heads, \ + f"Expected {num_heads} attention heads, got {model.config.num_attention_heads}" + assert model.config.num_key_value_heads == num_heads, \ + f"Expected MHA ({num_heads} KV heads), got {model.config.num_key_value_heads} — GQA detected" + + # Verify no bias + layer0_attn = model.model.layers[0].self_attn + assert layer0_attn.o_proj.bias is None, \ + "Expected no bias in o_proj — update per-head splitting if bias exists" + + # Block-upper-triangular mask: [256, 256] + self.register_buffer('dag_mask', create_block_upper_triangular_mask(self.num_nodes, num_heads)) + + # Input normalization + self.input_normalizer = InputNormalizer(input_norm, self.model_dim, self.num_nodes) + + # Attention scaling factor + self.scaling = self.head_dim ** -0.5 + + def _get_head_weight_views(self, layer_idx: int) -> dict: + """Get per-head weight views for a given layer. + + Uses .view() which returns views of the same storage — no copy, + gradients flow through for Phase 2 compatibility. + """ + layer = self.olmo.model.layers[layer_idx] + attn = layer.self_attn + + # Q, K, V projections: [model_dim, model_dim] → [num_heads, head_dim, model_dim] + W_q = attn.q_proj.weight.view(self.num_heads, self.head_dim, self.model_dim) + W_k = attn.k_proj.weight.view(self.num_heads, self.head_dim, self.model_dim) + W_v = attn.v_proj.weight.view(self.num_heads, self.head_dim, self.model_dim) + + # O projection: [model_dim, model_dim] + # Split by INPUT dimension (columns): [model_dim, num_heads, head_dim] + # Permute to [num_heads, model_dim, head_dim] for einsum + W_o = attn.o_proj.weight.view(self.model_dim, self.num_heads, self.head_dim) + W_o = W_o.permute(1, 0, 2) # [num_heads, model_dim, head_dim] + + return { + 'W_q': W_q, 'W_k': W_k, 'W_v': W_v, 'W_o': W_o, + 'q_norm': attn.q_norm, + 'k_norm': attn.k_norm, + 'post_attn_norm': layer.post_attention_layernorm, + 'post_ff_norm': layer.post_feedforward_layernorm, + 'mlp': layer.mlp, + } + + def forward( + self, + olmo_ids: torch.Tensor, + A: torch.Tensor, + ) -> torch.Tensor: + """Modified OLMo2-1B forward pass with per-head routing via A. 
+
+        Args:
+            olmo_ids: [batch, seq_len] — tokenized by OLMo's tokenizer
+            A: [batch, 256, 256] — block-upper-triangular gate matrix
+
+        Returns:
+            logits: [batch, seq_len, vocab_size]
+        """
+        batch, seq_len = olmo_ids.shape
+        device = olmo_ids.device
+
+        assert A.shape == (batch, self.num_nodes, self.num_nodes), \
+            f"A shape mismatch: expected ({batch}, {self.num_nodes}, {self.num_nodes}), got {A.shape}"
+
+        # Cast A to model dtype (predictor outputs float32, OLMo uses bfloat16)
+        model_dtype = self.olmo.model.embed_tokens.weight.dtype
+        A = A.to(dtype=model_dtype)
+
+        # Token embedding
+        embedding = self.olmo.model.embed_tokens(olmo_ids)  # [B, S, D]
+
+        # Position embeddings (computed once, shared across all layers)
+        position_ids = torch.arange(seq_len, device=device).unsqueeze(0)  # [1, S]
+        position_embeddings = self.olmo.model.rotary_emb(embedding, position_ids)
+        cos, sin = position_embeddings
+
+        # Causal attention mask: [1, 1, S, S]
+        causal_mask = torch.zeros(1, 1, seq_len, seq_len, device=device, dtype=embedding.dtype)
+        causal_mask.masked_fill_(
+            torch.triu(torch.ones(seq_len, seq_len, device=device, dtype=torch.bool), diagonal=1),
+            float('-inf'),
+        )
+
+        # Storage for outputs across layers
+        # We accumulate head_outputs as a list of [B, 16, S, D] tensors (one per layer)
+        all_head_outputs: list[torch.Tensor] = []  # each: [B, 16, S, D]
+        mlp_outputs: list[torch.Tensor] = []  # each: [B, S, D]
+
+        # Running base: embedding + accumulated MLP outputs (for per-head assembly)
+        base = embedding.clone()  # [B, S, D]
+        # Accumulated ungated attention outputs (for MLP input)
+        attn_accumulated = torch.zeros_like(embedding)  # [B, S, D]
+
+        for l in range(self.num_layers):
+            weights = self._get_head_weight_views(l)
+
+            # === ASSEMBLE PER-HEAD INPUTS ===
+            if l == 0:
+                # Layer 0: all heads see only the embedding (no prior heads or MLPs)
+                assembled = embedding.unsqueeze(1).expand(-1, self.num_heads, -1, -1)
+                # assembled: [B, 16, S, D]
+            else:
+                # base_l = embedding + Σ_{l'<l} mlp_output[l'] (maintained in `base`)
+                # Gather all prior layers' head outputs: [B, 16*l, S, D]
+                prior_head_outs = torch.cat(all_head_outputs, dim=1)
+                # Gates from prior nodes into this layer's heads: [B, 16*l, 16]
+                node_start = l * self.num_heads
+                A_slice = A[:, :node_start, node_start:node_start + self.num_heads]
+                # Gated sum of prior head outputs per target head
+                gated_sum = torch.einsum('bih,bisd->bhsd', A_slice, prior_head_outs)
+                # gated_sum: [B, 16, S, D]
+
+                # Apply input normalization (only to gated_sum, not base)
+                if self.input_normalizer.method == "rms_pre":
+                    gated_sum = self.input_normalizer(
+                        gated_sum, A_slice=A_slice, prior_head_outs=prior_head_outs
+                    )
+                elif self.input_normalizer.method == "gate_mean":
+                    gated_sum = self.input_normalizer(gated_sum, A_slice=A_slice)
+                else:
+                    gated_sum = self.input_normalizer(gated_sum)
+
+                # assembled = base + gated_sum
+                assembled = base.unsqueeze(1) + gated_sum  # [B, 16, S, D]
+
+            # === PER-HEAD Q/K/V PROJECTION ===
+            W_q, W_k, W_v, W_o = weights['W_q'], weights['W_k'], weights['W_v'], weights['W_o']
+
+            # Per-head projections via einsum
+            # assembled: [B, H, S, D], W_q: [H, head_dim, D]
+            q_per_head = torch.einsum('bhsd,hod->bhso', assembled, W_q)  # [B, H, S, head_dim]
+            k_per_head = torch.einsum('bhsd,hod->bhso', assembled, W_k)
+            v_per_head = torch.einsum('bhsd,hod->bhso', assembled, W_v)
+
+            # === Q_NORM / K_NORM ===
+            # OLMo2 applies RMSNorm to concatenated Q/K (2048-dim) AFTER projection.
+            # Concat all heads → norm → split back.
+            # When A=1 (all heads same input), this equals q_norm(q_proj(shared_input)).
+ q_concat = rearrange(q_per_head, 'b h s d -> b s (h d)') # [B, S, 2048] + q_normed = weights['q_norm'](q_concat) + q_per_head = rearrange(q_normed, 'b s (h d) -> b h s d', h=self.num_heads) + + k_concat = rearrange(k_per_head, 'b h s d -> b s (h d)') + k_normed = weights['k_norm'](k_concat) + k_per_head = rearrange(k_normed, 'b s (h d) -> b h s d', h=self.num_heads) + + # V has NO norm in OLMo2 + + # === APPLY RoPE === + q_per_head, k_per_head = apply_rotary_pos_emb(q_per_head, k_per_head, cos, sin) + + # === ATTENTION COMPUTATION === + # q,k,v: [B, H, S, head_dim] + attn_weights = torch.matmul(q_per_head, k_per_head.transpose(-2, -1)) * self.scaling + # attn_weights: [B, H, S, S] + attn_weights = attn_weights + causal_mask # [1, 1, S, S] broadcasts + attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(q_per_head.dtype) + attn_values = torch.matmul(attn_weights, v_per_head) # [B, H, S, head_dim] + + # === PER-HEAD O_PROJ === + # attn_values: [B, H, S, head_dim], W_o: [H, model_dim, head_dim] + raw_head_outs = torch.einsum('bhsd,hod->bhso', attn_values, W_o) + # raw_head_outs: [B, H, S, model_dim] + + # === PROPORTIONAL ATTRIBUTION WITH POST_ATTN_NORM === + # OLMo2 applies post_attention_layernorm to the COMBINED attention output. + # RMSNorm(Σ_h x_h) = weight * (Σ_h x_h) / RMS(Σ_h x_h) + # = Σ_h [weight * x_h / RMS(Σ_h x_h)] + # We attribute each head's normed output proportionally. + raw_sum = raw_head_outs.sum(dim=1) # [B, S, D] + # Compute RMS of the sum + variance = raw_sum.to(torch.float32).pow(2).mean(-1, keepdim=True) + rms = torch.sqrt(variance + self.rms_norm_eps) # [B, S, 1] + # Apply post_attn_norm weight and scale + norm_weight = weights['post_attn_norm'].weight # [D] + # head_output[h] = norm_weight * raw_head_out[h] / rms + scale = (norm_weight / rms).unsqueeze(1) # [B, 1, S, D] + head_outputs_l = raw_head_outs.float() * scale # [B, H, S, D] + head_outputs_l = head_outputs_l.to(raw_head_outs.dtype) + + # Store for routing to later layers + all_head_outputs.append(head_outputs_l) + + # === MLP COMPUTATION (standard, ungated) === + # attn_normed = Σ_h head_output[l,h] = post_attn_norm(raw_sum) + attn_normed = head_outputs_l.sum(dim=1) # [B, S, D] + + # MLP input = full residual stream (embedding + all prior MLPs + all attn up to current) + # In vanilla OLMo2: mlp_input = residual + post_attn_norm(attn_output) + # where residual includes ALL prior components (embedding + prior MLPs + prior attns) + mlp_in = base + attn_accumulated + attn_normed + + # Update accumulated attention for next layer + attn_accumulated = attn_accumulated + attn_normed + + # MLP forward + post_feedforward_layernorm + mlp_raw = weights['mlp'](mlp_in) + mlp_output_l = weights['post_ff_norm'](mlp_raw) + mlp_outputs.append(mlp_output_l) + + # Update running base for next layer + # base_{l+1} = base_l + mlp_output_l = embedding + Σ_{l'<=l} mlp_output[l'] + base = base + mlp_output_l + + # === FINAL OUTPUT === + # final_state = embedding + Σ_l mlp_output[l] + Σ_l Σ_h head_output[l,h] + # = embedding + Σ_l [post_attn_norm(attn_out_l) + post_ff_norm(mlp_out_l)] + # 'base' = embedding + Σ_l mlp_output[l] + # 'attn_accumulated' = Σ_l attn_output[l] (ungated sum of all attention outputs) + final_state = base + attn_accumulated + + # Apply final norm and lm_head + final_state = self.olmo.model.norm(final_state) + logits = self.olmo.lm_head(final_state) + + return logits + + +def compute_vanilla_nll( + model: AutoModelForCausalLM, + input_ids: torch.Tensor, + labels: torch.Tensor, +) -> 
torch.Tensor: + """Compute NLL using vanilla OLMo2 forward pass (no A injection). + + Used for baseline comparison in sanity checks. + """ + with torch.no_grad(): + outputs = model(input_ids=input_ids) + logits = outputs.logits + nll = F.cross_entropy( + logits[:, :-1].contiguous().view(-1, logits.size(-1)), + labels[:, 1:].contiguous().view(-1), + ) + return nll + + +def create_all_ones_A(batch_size: int, num_nodes: int = 256, num_heads: int = 16) -> torch.Tensor: + """Create A matrix with 1.0 for all valid (cross-layer) entries. + + When used with input_norm="none", this should reproduce vanilla OLMo2. + """ + A = torch.zeros(batch_size, num_nodes, num_nodes) + mask = create_block_upper_triangular_mask(num_nodes, num_heads) + A = A + mask.unsqueeze(0) # broadcast mask to batch + return A diff --git a/src/model/pipeline.py b/src/model/pipeline.py new file mode 100644 index 0000000..bbfcabf --- /dev/null +++ b/src/model/pipeline.py @@ -0,0 +1,144 @@ +"""End-to-end DAGFormer pipeline: raw text → predictor → A → OLMo → NLL. + +Glues the structure predictor (Qwen + MLP) with the modified OLMo forward. +This is what the trainer calls. See CLAUDE.md §5 for file responsibilities. +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import AutoModelForCausalLM, AutoTokenizer + +from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A +from src.model.predictor import StructurePredictor + + +class DAGFormerPipeline(nn.Module): + """Combines StructurePredictor + DAGFormerOLMo into a single forward pass. + + Forward: raw_text → predictor → A → modified OLMo → logits → NLL + + Only the predictor's MLP params are trainable. OLMo and Qwen are frozen. + """ + + def __init__( + self, + olmo_model_id: str = "allenai/OLMo-2-0425-1B", + qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B", + predictor_hidden_dim: int = 1024, + predictor_rank: int = 32, + cascading_gate_k: float = 5.0, + input_norm: str = "none", + qwen_input_prefix: str = "", + device: Optional[torch.device] = None, + ): + super().__init__() + + # Load frozen OLMo2-1B + olmo = AutoModelForCausalLM.from_pretrained( + olmo_model_id, + torch_dtype=torch.bfloat16, + ) + olmo.eval() + for p in olmo.parameters(): + p.requires_grad_(False) + + # Wrap OLMo with DAGFormer modification + self.olmo_wrapper = DAGFormerOLMo(model=olmo, input_norm=input_norm) + + # Structure predictor (Qwen encoder + MLP decoder) + self.predictor = StructurePredictor( + qwen_model_id=qwen_model_id, + hidden_dim=predictor_hidden_dim, + rank=predictor_rank, + cascading_gate_k=cascading_gate_k, + qwen_input_prefix=qwen_input_prefix, + device=device, + ) + + self.vocab_size = olmo.config.vocab_size + + if device is not None: + self.to(device) + + def forward( + self, + raw_texts: list[str], + olmo_ids: torch.Tensor, + olmo_labels: torch.Tensor, + tau: float, + lambda_sparsity: float = 0.0, + mode: str = "train", + ) -> dict[str, torch.Tensor]: + """Full forward pass: text → A → logits → loss. 
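+
+        Gradients reach only the predictor MLP (plus any input-norm parameters): the
+        Qwen encoder runs under no_grad, and OLMo's weights are frozen, so the NLL
+        gradient flows back through the frozen OLMo ops into A and then into the MLP.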
+ + Args: + raw_texts: list of raw text strings (batch) + olmo_ids: [batch, seq_len] — OLMo tokenized input + olmo_labels: [batch, seq_len] — shifted labels for NLL + tau: Gumbel-Sigmoid temperature + lambda_sparsity: sparsity coefficient (λ_t) + mode: "train", "eval_soft", or "eval_hard" + + Returns: + dict with keys: + "total_loss": nll + lambda * mean(A) — what the optimizer sees + "nll": cross-entropy loss + "sparsity_loss": lambda * mean(A) + "A": [batch, 256, 256] adjacency matrix + """ + # Step 1: Predict adjacency matrix + A = self.predictor(raw_texts, tau=tau, mode=mode) + # A: [batch, 256, 256] + + # Step 2: Modified OLMo forward with A + logits = self.olmo_wrapper(olmo_ids, A) + # logits: [batch, seq_len, vocab_size] + + # Step 3: Compute NLL (next-token prediction) + # Shift: logits[:, :-1] predicts labels[:, 1:] + nll = F.cross_entropy( + logits[:, :-1].contiguous().view(-1, self.vocab_size), + olmo_labels[:, 1:].contiguous().view(-1), + ) + + # Step 4: Sparsity regularization + sparsity_loss = lambda_sparsity * A.mean() + total_loss = nll + sparsity_loss + + return { + "total_loss": total_loss, + "nll": nll, + "sparsity_loss": sparsity_loss, + "A": A, + } + + def forward_baseline( + self, + olmo_ids: torch.Tensor, + olmo_labels: torch.Tensor, + ) -> torch.Tensor: + """Forward with A=all-ones (baseline reproduction). + + Used for eval/nll_baseline metric. + """ + batch = olmo_ids.shape[0] + A = create_all_ones_A(batch).to(olmo_ids.device) + with torch.no_grad(): + logits = self.olmo_wrapper(olmo_ids, A) + nll = F.cross_entropy( + logits[:, :-1].contiguous().view(-1, self.vocab_size), + olmo_labels[:, 1:].contiguous().view(-1), + ) + return nll + + def get_trainable_parameters(self) -> list[nn.Parameter]: + """Return only the trainable parameters (predictor MLP + any norm params).""" + params = list(self.predictor.get_trainable_parameters()) + # Also include input normalizer params if they exist + params.extend(self.olmo_wrapper.input_normalizer.parameters()) + return params diff --git a/src/model/predictor.py b/src/model/predictor.py new file mode 100644 index 0000000..0bc0ae3 --- /dev/null +++ b/src/model/predictor.py @@ -0,0 +1,275 @@ +"""Structure predictor: Qwen encoder + MLP decoder + Gumbel-Sigmoid + cascading gate. + +Takes raw text, produces a 256x256 adjacency matrix A controlling per-head +routing in OLMo2-1B. See CLAUDE.md §2.3 for full specification. + +Components: +- QwenEncoder: frozen Qwen3-Embedding-0.6B, mean-pooled to single vector +- PredictorMLP: trainable MLP with low-rank output heads (U, V → Z = UV^T) +- Gumbel-Sigmoid: differentiable relaxation of binary gates (3 modes) +- Cascading gate: kill outgoing edges from disconnected nodes +- Block-upper-triangular mask: enforce DAG constraint (layer(j) > layer(i)) +""" + +from __future__ import annotations + +from typing import Optional + +import torch +import torch.nn as nn +from transformers import AutoModel, AutoTokenizer + +from src.model.olmo_graph import create_block_upper_triangular_mask + + +class QwenEncoder(nn.Module): + """Frozen Qwen3-Embedding-0.6B encoder. + + Produces a single fixed-size vector per sequence via mean pooling. + Uses its OWN tokenizer (separate from OLMo's). 
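+
+    Encoding runs under torch.no_grad(); the pooled embedding is the only interface
+    between the frozen encoder and the trainable MLP decoder.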
+ """ + + def __init__(self, model_id: str = "Qwen/Qwen3-Embedding-0.6B", device: Optional[torch.device] = None): + super().__init__() + self.tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True) + self.model = AutoModel.from_pretrained(model_id, trust_remote_code=True) + self.model.eval() + for p in self.model.parameters(): + p.requires_grad_(False) + + self.embed_dim: int = self.model.config.hidden_size # 1024 for Qwen3-Embedding-0.6B + + if device is not None: + self.model = self.model.to(device) + + def encode(self, raw_texts: list[str], prefix: str = "") -> torch.Tensor: + """Encode raw text strings to pooled embeddings. + + Args: + raw_texts: list of raw text strings (one per sequence in batch) + prefix: optional prefix for Qwen input (default: "" — no prefix) + + Returns: + pooled: [batch, embed_dim] — mean-pooled embedding per sequence + """ + if prefix: + raw_texts = [prefix + t for t in raw_texts] + + device = next(self.model.parameters()).device + inputs = self.tokenizer( + raw_texts, + padding=True, + truncation=True, + max_length=8192, + return_tensors="pt", + ).to(device) + + with torch.no_grad(): + outputs = self.model(**inputs) + + # Mean pooling over sequence dimension (masking padding tokens) + attention_mask = inputs["attention_mask"].unsqueeze(-1) # [B, S, 1] + last_hidden = outputs.last_hidden_state # [B, S, embed_dim] + pooled = (last_hidden * attention_mask).sum(dim=1) / attention_mask.sum(dim=1).clamp(min=1e-8) + # pooled: [B, embed_dim] + + return pooled + + +class PredictorMLP(nn.Module): + """Trainable MLP decoder with low-rank output heads. + + Maps Qwen embedding → logit matrix Z = UV^T ∈ R^{256×256}. + See CLAUDE.md §2.3 for architecture. + """ + + def __init__(self, input_dim: int, hidden_dim: int = 1024, rank: int = 32, num_nodes: int = 256): + super().__init__() + self.rank = rank + self.num_nodes = num_nodes + + self.trunk = nn.Sequential( + nn.Linear(input_dim, hidden_dim), + nn.GELU(), + nn.Linear(hidden_dim, hidden_dim), + nn.GELU(), + ) + self.head_U = nn.Linear(hidden_dim, num_nodes * rank) + self.head_V = nn.Linear(hidden_dim, num_nodes * rank) + + def forward(self, e: torch.Tensor) -> torch.Tensor: + """Map embedding to logit matrix. + + Args: + e: [batch, input_dim] — pooled Qwen embedding + + Returns: + Z: [batch, 256, 256] — raw logit matrix (before mask/Gumbel) + """ + h = self.trunk(e) # [batch, hidden_dim] + U = self.head_U(h).view(-1, self.num_nodes, self.rank) # [B, 256, r] + V = self.head_V(h).view(-1, self.num_nodes, self.rank) # [B, 256, r] + Z = torch.bmm(U, V.transpose(-1, -2)) # [B, 256, 256] + return Z + + +def gumbel_sigmoid( + Z_masked: torch.Tensor, + tau: float, + mode: str = "train", +) -> torch.Tensor: + """Apply Gumbel-Sigmoid relaxation to masked logits. 
+ + Three modes (CLAUDE.md §2.3): + - "train": Gumbel noise + temperature → differentiable continuous relaxation + - "eval_soft": σ(Z/τ) — deterministic soft gates, no noise + - "eval_hard": (Z > 0).float() — deterministic binary 0/1 + + Args: + Z_masked: [batch, 256, 256] — logits with invalid positions at -1e9 + tau: temperature (τ > 0 for train/eval_soft) + mode: one of "train", "eval_soft", "eval_hard" + + Returns: + A: [batch, 256, 256] — gate values in [0, 1] (or {0, 1} for hard mode) + """ + if mode == "train": + # Sample from Logistic(0, 1): G = log(U) - log(1-U), U ~ Uniform(0,1) + U = torch.rand_like(Z_masked).clamp(1e-8, 1 - 1e-8) + G = torch.log(U) - torch.log(1 - U) + return torch.sigmoid((Z_masked + G) / tau) + elif mode == "eval_soft": + return torch.sigmoid(Z_masked / tau) + elif mode == "eval_hard": + return (Z_masked > 0).float() + else: + raise ValueError(f"Unknown Gumbel-Sigmoid mode: {mode}. Expected: train, eval_soft, eval_hard") + + +def cascading_gate( + A: torch.Tensor, + k: float = 5.0, + hard: bool = False, +) -> torch.Tensor: + """Apply cascading activation gate: kill outgoing edges from disconnected nodes. + + One-pass computation (not layer-by-layer): + 1. Compute incoming sums: inc_j = Σ_i A[i, j] + 2. Compute gates: g_j = σ(k * inc_j) (soft) or (inc_j > 0) (hard) + 3. Apply: A[j, :] *= g_j + + Uses ORIGINAL A values for incoming sums (before any gates applied). + See CLAUDE.md §2.3 cascading gate section. + + Args: + A: [batch, 256, 256] — gate matrix + k: steepness of sigmoid gate (default: 5.0) + hard: if True, use binary gates (for eval_hard mode) + + Returns: + A_gated: [batch, 256, 256] — A with cascading gate applied + """ + # Incoming sum per node: [batch, 256] + inc = A.sum(dim=1) # sum over source dimension (rows) + + if hard: + g = (inc > 0).float() # [batch, 256] + else: + g = torch.sigmoid(k * inc) # [batch, 256] + + # Gate outgoing edges: A[j, :] *= g[j] + # g: [B, 256] → [B, 256, 1] to broadcast with A: [B, 256, 256] + return A * g.unsqueeze(2) + + +class StructurePredictor(nn.Module): + """Full structure predictor: raw text → adjacency matrix A. + + Pipeline: raw_text → [Qwen encoder] → e → [MLP] → Z → [mask] → [Gumbel] → [cascade] → A + + The only trainable component is the PredictorMLP. Qwen is frozen. + """ + + def __init__( + self, + qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B", + hidden_dim: int = 1024, + rank: int = 32, + cascading_gate_k: float = 5.0, + qwen_input_prefix: str = "", + num_nodes: int = 256, + heads_per_layer: int = 16, + device: Optional[torch.device] = None, + ): + super().__init__() + self.cascading_gate_k = cascading_gate_k + self.qwen_input_prefix = qwen_input_prefix + self.num_nodes = num_nodes + self.heads_per_layer = heads_per_layer + + # Frozen Qwen encoder + self.qwen_encoder = QwenEncoder(model_id=qwen_model_id, device=device) + + # Trainable MLP decoder + self.mlp = PredictorMLP( + input_dim=self.qwen_encoder.embed_dim, + hidden_dim=hidden_dim, + rank=rank, + num_nodes=num_nodes, + ) + + # Block-upper-triangular mask (registered as buffer — moves with .to(device)) + self.register_buffer( + 'dag_mask', + create_block_upper_triangular_mask(num_nodes, heads_per_layer), + ) + + # Move all components to device (buffers + trainable MLP) + if device is not None: + self.to(device) + + def forward( + self, + raw_texts: list[str], + tau: float, + mode: str = "train", + ) -> torch.Tensor: + """Predict adjacency matrix A from raw text. 
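+
+        Invalid positions (same-layer and backward edges) are forced to logit -1e9
+        before the Gumbel-Sigmoid, so their gates are effectively zero in every mode.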
+ + Args: + raw_texts: list of raw text strings (batch) + tau: Gumbel-Sigmoid temperature + mode: "train", "eval_soft", or "eval_hard" + + Returns: + A: [batch, 256, 256] — block-upper-triangular gate matrix + """ + # Step 1: Qwen encoding (frozen, no grad) + e = self.qwen_encoder.encode(raw_texts, prefix=self.qwen_input_prefix) + # e: [batch, qwen_embed_dim] + + # Step 2: MLP decoder → logits + Z = self.mlp(e) # [batch, 256, 256] + assert Z.shape[1:] == (self.num_nodes, self.num_nodes), \ + f"Z shape mismatch: expected (*, {self.num_nodes}, {self.num_nodes}), got {Z.shape}" + + # Step 3: Apply block-upper-triangular mask + # Force invalid positions to -inf so sigmoid → 0 + mask = self.dag_mask # [256, 256] + Z_masked = Z * mask + (-1e9) * (1 - mask) + + # Step 4: Gumbel-Sigmoid + hard = (mode == "eval_hard") + A = gumbel_sigmoid(Z_masked, tau=tau, mode=mode) + + # Step 5: Cascading activation gate + A = cascading_gate(A, k=self.cascading_gate_k, hard=hard) + + assert A.shape[1:] == (self.num_nodes, self.num_nodes), \ + f"A shape mismatch: expected (*, {self.num_nodes}, {self.num_nodes}), got {A.shape}" + + return A + + def get_trainable_parameters(self) -> list[nn.Parameter]: + """Return only the trainable MLP parameters (not Qwen).""" + return list(self.mlp.parameters()) diff --git a/src/training/__init__.py b/src/training/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/training/checkpointing.py b/src/training/checkpointing.py new file mode 100644 index 0000000..9ff02df --- /dev/null +++ b/src/training/checkpointing.py @@ -0,0 +1,92 @@ +"""Checkpoint save/load for predictor + optimizer + schedule state. + +Only saves trainable components (predictor MLP, optimizer, schedule state). +Frozen models (OLMo, Qwen) are not checkpointed — they load from HuggingFace. +""" + +from __future__ import annotations + +import os +from typing import Any, Optional + +import torch +import torch.nn as nn +import torch.optim as optim + + +def save_checkpoint( + save_dir: str, + step: int, + predictor: nn.Module, + optimizer: optim.Optimizer, + scheduler: Any, + best_eval_nll: float, + extra: Optional[dict] = None, +) -> str: + """Save training checkpoint. + + Args: + save_dir: directory to save checkpoint + step: current global step + predictor: the structure predictor (only MLP params are saved) + optimizer: AdamW optimizer + scheduler: LR scheduler + best_eval_nll: best eval NLL so far + extra: any additional state to save + + Returns: + path: path to saved checkpoint + """ + os.makedirs(save_dir, exist_ok=True) + path = os.path.join(save_dir, f"checkpoint_step{step}.pt") + + state = { + "step": step, + "predictor_state_dict": predictor.state_dict(), + "optimizer_state_dict": optimizer.state_dict(), + "scheduler_state_dict": scheduler.state_dict() if scheduler is not None else None, + "best_eval_nll": best_eval_nll, + } + if extra: + state.update(extra) + + torch.save(state, path) + print(f"Checkpoint saved: {path}") + return path + + +def load_checkpoint( + path: str, + predictor: nn.Module, + optimizer: Optional[optim.Optimizer] = None, + scheduler: Optional[Any] = None, + device: Optional[torch.device] = None, +) -> dict: + """Load training checkpoint. 
+ + Args: + path: path to checkpoint file + predictor: structure predictor to load weights into + optimizer: optimizer to restore state (optional — skip for eval) + scheduler: LR scheduler to restore state (optional) + device: device to map tensors to + + Returns: + state dict with step, best_eval_nll, and any extras + """ + map_location = device if device is not None else "cpu" + state = torch.load(path, map_location=map_location) + + predictor.load_state_dict(state["predictor_state_dict"]) + print(f"Predictor state loaded from {path}") + + if optimizer is not None and "optimizer_state_dict" in state: + optimizer.load_state_dict(state["optimizer_state_dict"]) + + if scheduler is not None and state.get("scheduler_state_dict") is not None: + scheduler.load_state_dict(state["scheduler_state_dict"]) + + return { + "step": state["step"], + "best_eval_nll": state.get("best_eval_nll", float("inf")), + } diff --git a/src/training/schedulers.py b/src/training/schedulers.py new file mode 100644 index 0000000..7cda3b4 --- /dev/null +++ b/src/training/schedulers.py @@ -0,0 +1,35 @@ +"""Schedule functions for temperature (τ), sparsity (λ), and learning rate. + +All schedules are deterministic functions of the current step. +See CLAUDE.md §3.1 for exact formulas. +""" + +from __future__ import annotations + +import math + + +def tau_schedule(step: int, total_steps: int, tau_init: float, tau_final: float) -> float: + """Cosine annealing for Gumbel-Sigmoid temperature. + + τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T)) + + Starts at tau_init, ends at tau_final. + """ + if total_steps <= 0: + return tau_final + progress = min(step / total_steps, 1.0) + return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * progress)) + + +def lambda_schedule(step: int, total_steps: int, lambda_max: float, warmup_frac: float = 0.2) -> float: + """Linear ramp for sparsity coefficient. + + Ramps linearly from 0 to lambda_max over first warmup_frac of training. + """ + if lambda_max == 0.0: + return 0.0 + warmup_steps = int(total_steps * warmup_frac) + if warmup_steps <= 0: + return lambda_max + return lambda_max * min(step / warmup_steps, 1.0) diff --git a/src/training/trainer.py b/src/training/trainer.py new file mode 100644 index 0000000..6be949e --- /dev/null +++ b/src/training/trainer.py @@ -0,0 +1,465 @@ +"""Training loop for DAGFormer Phase 1. + +Pure PyTorch + DDP. Only the predictor MLP is trainable. +See CLAUDE.md §3.1 for training specification. +""" + +from __future__ import annotations + +import math +import os +import warnings +from dataclasses import dataclass, field +from typing import Any, Optional + +import torch +import torch.distributed as dist +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel as DDP +from torch.optim.lr_scheduler import CosineAnnealingLR +from transformers import AutoModelForCausalLM, AutoTokenizer + +from src.data.dolma import build_eval_dataloader, build_train_dataloader +from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A +from src.model.predictor import StructurePredictor +from src.training.checkpointing import load_checkpoint, save_checkpoint +from src.training.schedulers import lambda_schedule, tau_schedule +from src.utils.logging import finish_wandb, init_wandb, log_metrics +from src.utils.topology import compute_topology_metrics + +import torch.nn.functional as F + + +@dataclass +class TrainConfig: + """Training configuration. 
Parsed from YAML.""" + + # Model + olmo_model_id: str = "allenai/OLMo-2-0425-1B" + qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B" + + # Predictor + predictor_hidden_dim: int = 1024 + predictor_rank: int = 32 + cascading_gate_k: float = 5.0 + input_norm: str = "none" + qwen_input_prefix: str = "" + + # Data + dataset: str = "allenai/dolma" + dataset_name: str = "v1_7" + seq_len: int = 1024 + batch_size: int = 4 + micro_batch_size: int = 4 + + # Eval + eval_skip: int = 1_000_000 + eval_size: int = 1_000 + + # Training + total_steps: int = 1000 + lr: float = 3e-4 + weight_decay: float = 0.01 + optimizer: str = "adamw" + + # Schedules + tau_init: float = 5.0 + tau_final: float = 0.2 + tau_schedule: str = "cosine" + lambda_max: float = 0.0 + lambda_warmup_frac: float = 0.2 + + # Logging + wandb_project: str = "dagformer" + wandb_run_name: str = "default" + log_every: int = 10 + eval_every: int = 100 + + # Checkpointing + save_every: int = 500 + save_dir: str = "checkpoints/" + resume_from: str = "" + + # Hardware + num_gpus: int = 1 + + @classmethod + def from_yaml(cls, path: str) -> TrainConfig: + import yaml + with open(path) as f: + data = yaml.safe_load(f) + + known_keys = {f.name for f in cls.__dataclass_fields__.values()} + unknown = set(data.keys()) - known_keys + if unknown: + raise ValueError(f"Unknown config keys: {unknown}") + + # Coerce types to match dataclass field annotations + import dataclasses + for f in dataclasses.fields(cls): + if f.name in data: + expected_type = f.type + if expected_type == "float" or expected_type is float: + data[f.name] = float(data[f.name]) + elif expected_type == "int" or expected_type is int: + data[f.name] = int(data[f.name]) + + return cls(**data) + + def to_dict(self) -> dict[str, Any]: + from dataclasses import asdict + return asdict(self) + + +class Trainer: + """DAGFormer Phase 1 training loop.""" + + def __init__(self, config: TrainConfig, local_rank: int = 0, world_size: int = 1): + self.config = config + self.local_rank = local_rank + self.world_size = world_size + self.is_main = (local_rank == 0) + self.device = torch.device(f"cuda:{local_rank}" if torch.cuda.is_available() else "cpu") + + # Gradient accumulation + assert config.batch_size % config.micro_batch_size == 0, \ + f"batch_size ({config.batch_size}) must be divisible by micro_batch_size ({config.micro_batch_size})" + self.accum_steps = config.batch_size // config.micro_batch_size + + self._build_models() + self._build_optimizer() + self._build_data() + self._setup_logging() + + self.global_step = 0 + self.best_eval_nll = float("inf") + self.collapse_counter = 0 # consecutive steps with collapsed A + + # Resume from checkpoint if specified + if config.resume_from: + state = load_checkpoint( + config.resume_from, + self.predictor, + self.optimizer, + self.lr_scheduler, + device=self.device, + ) + self.global_step = state["step"] + self.best_eval_nll = state["best_eval_nll"] + if self.is_main: + print(f"Resumed from step {self.global_step}") + + def _build_models(self) -> None: + config = self.config + + # Load frozen OLMo2-1B + if self.is_main: + print(f"Loading {config.olmo_model_id}...") + self.olmo = AutoModelForCausalLM.from_pretrained( + config.olmo_model_id, + torch_dtype=torch.bfloat16, + ).to(self.device) + self.olmo.eval() + for p in self.olmo.parameters(): + p.requires_grad_(False) + + # Verify frozen + assert all(not p.requires_grad for p in self.olmo.parameters()), \ + "OLMo parameters should be frozen" + + # OLMo tokenizer + self.olmo_tokenizer = 
AutoTokenizer.from_pretrained(config.olmo_model_id) + + # DAGFormer OLMo wrapper + self.olmo_wrapper = DAGFormerOLMo( + model=self.olmo, + input_norm=config.input_norm, + ).to(self.device) + + # Structure predictor (includes frozen Qwen + trainable MLP) + if self.is_main: + print(f"Loading {config.qwen_model_id}...") + self.predictor = StructurePredictor( + qwen_model_id=config.qwen_model_id, + hidden_dim=config.predictor_hidden_dim, + rank=config.predictor_rank, + cascading_gate_k=config.cascading_gate_k, + qwen_input_prefix=config.qwen_input_prefix, + device=self.device, + ) + + # DDP wrapping — only the predictor MLP (trainable component) + if self.world_size > 1: + self.predictor.mlp = DDP( + self.predictor.mlp, + device_ids=[self.local_rank], + ) + + if self.is_main: + trainable = sum(p.numel() for p in self.predictor.get_trainable_parameters()) + norm_params = sum(p.numel() for p in self.olmo_wrapper.input_normalizer.parameters()) + print(f"Trainable params: predictor={trainable:,}, norm={norm_params:,}") + + def _build_optimizer(self) -> None: + config = self.config + + # Collect all trainable parameters + params = list(self.predictor.get_trainable_parameters()) + params.extend(self.olmo_wrapper.input_normalizer.parameters()) + + assert config.optimizer == "adamw", f"Only adamw supported, got {config.optimizer}" + self.optimizer = torch.optim.AdamW( + params, + lr=config.lr, + betas=(0.9, 0.999), + weight_decay=config.weight_decay, + ) + self.lr_scheduler = CosineAnnealingLR( + self.optimizer, + T_max=config.total_steps, + eta_min=0.0, + ) + + def _build_data(self) -> None: + config = self.config + + self.train_loader = build_train_dataloader( + olmo_tokenizer=self.olmo_tokenizer, + seq_len=config.seq_len, + batch_size=config.micro_batch_size, + dataset_name=config.dataset, + dataset_version=config.dataset_name, + rank=self.local_rank, + world_size=self.world_size, + ) + + # Eval data: only on main rank + if self.is_main: + cache_path = os.path.join(config.save_dir, "eval_cache.pt") + self.eval_batches = build_eval_dataloader( + olmo_tokenizer=self.olmo_tokenizer, + seq_len=config.seq_len, + batch_size=config.micro_batch_size, + dataset_name=config.dataset, + dataset_version=config.dataset_name, + eval_skip=config.eval_skip, + eval_size=config.eval_size, + cache_path=cache_path, + ) + else: + self.eval_batches = [] + + def _setup_logging(self) -> None: + if self.is_main: + self.wandb_run = init_wandb( + project=self.config.wandb_project, + run_name=self.config.wandb_run_name, + config=self.config.to_dict(), + ) + else: + self.wandb_run = None + + def train(self) -> None: + """Main training loop.""" + config = self.config + train_iter = iter(self.train_loader) + + if self.is_main: + print(f"\nStarting training: {config.total_steps} steps") + print(f" batch_size={config.batch_size}, micro_batch={config.micro_batch_size}, accum={self.accum_steps}") + print(f" tau: {config.tau_init} → {config.tau_final}") + print(f" lambda: 0 → {config.lambda_max}") + print() + + while self.global_step < config.total_steps: + # Schedule values + tau = tau_schedule(self.global_step, config.total_steps, config.tau_init, config.tau_final) + lam = lambda_schedule(self.global_step, config.total_steps, config.lambda_max, config.lambda_warmup_frac) + + # Gradient accumulation + self.optimizer.zero_grad() + total_nll = 0.0 + total_sparsity = 0.0 + total_mean_A = 0.0 + + for micro_step in range(self.accum_steps): + try: + batch = next(train_iter) + except StopIteration: + train_iter = iter(self.train_loader) + 
batch = next(train_iter) + + olmo_ids = batch["olmo_ids"].to(self.device) + olmo_labels = batch["olmo_labels"].to(self.device) + raw_texts = batch["raw_text"] + + # Forward: predictor → A → OLMo → loss + A = self.predictor(raw_texts, tau=tau, mode="train") + logits = self.olmo_wrapper(olmo_ids, A) + + # NLL loss + nll = F.cross_entropy( + logits[:, :-1].contiguous().view(-1, self.olmo.config.vocab_size), + olmo_labels[:, 1:].contiguous().view(-1), + ) + + # Sparsity loss + sparsity = lam * A.mean() + loss = (nll + sparsity) / self.accum_steps + + loss.backward() + + total_nll += nll.item() / self.accum_steps + total_sparsity += sparsity.item() / self.accum_steps + total_mean_A += A.mean().item() / self.accum_steps + + # Optimizer step + self.optimizer.step() + self.lr_scheduler.step() + + # Logging + if self.is_main and self.global_step % config.log_every == 0: + # Gradient norm + grad_norm = 0.0 + for p in self.predictor.get_trainable_parameters(): + if p.grad is not None: + grad_norm += p.grad.data.norm(2).item() ** 2 + for p in self.olmo_wrapper.input_normalizer.parameters(): + if p.grad is not None: + grad_norm += p.grad.data.norm(2).item() ** 2 + grad_norm = grad_norm ** 0.5 + + metrics = { + "train/nll": total_nll, + "train/sparsity_loss": total_sparsity, + "train/total_loss": total_nll + total_sparsity, + "topology/mean_A": total_mean_A, + "schedule/tau": tau, + "schedule/lambda": lam, + "grad/predictor_norm": grad_norm, + } + log_metrics(metrics, self.global_step, self.wandb_run) + + # Collapse alarm + if total_mean_A < 0.01 or total_mean_A > 0.99: + self.collapse_counter += 1 + if self.collapse_counter >= 100: + warnings.warn( + f"COLLAPSE ALARM: mean_A={total_mean_A:.4f} for {self.collapse_counter} steps" + ) + else: + self.collapse_counter = 0 + + # Eval + if self.is_main and self.global_step > 0 and self.global_step % config.eval_every == 0: + self._run_eval(tau) + + # Checkpoint + if self.is_main and self.global_step > 0 and self.global_step % config.save_every == 0: + save_checkpoint( + config.save_dir, + self.global_step, + self.predictor, + self.optimizer, + self.lr_scheduler, + self.best_eval_nll, + ) + + self.global_step += 1 + + # Barrier for multi-GPU sync + if self.world_size > 1: + dist.barrier() + + # Final eval and checkpoint + if self.is_main: + self._run_eval(tau_schedule(config.total_steps, config.total_steps, config.tau_init, config.tau_final)) + save_checkpoint( + config.save_dir, + self.global_step, + self.predictor, + self.optimizer, + self.lr_scheduler, + self.best_eval_nll, + ) + + finish_wandb(self.wandb_run) + if self.is_main: + print("\nTraining complete.") + + @torch.no_grad() + def _run_eval(self, tau: float) -> None: + """Run evaluation on held-out data (rank 0 only). 
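+
+        Each batch is scored three ways: soft gates sigma(Z / tau), hard gates (Z > 0),
+        and the all-ones baseline A; topology metrics are computed from the soft A.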
+ + Reports: eval/nll_soft, eval/nll_hard, eval/nll_baseline + """ + if not self.eval_batches: + return + + self.predictor.eval() + + nll_soft_total = 0.0 + nll_hard_total = 0.0 + nll_baseline_total = 0.0 + n_batches = 0 + topology_metrics_accum: dict[str, float] = {} + + for batch in self.eval_batches: + olmo_ids = batch["olmo_ids"].to(self.device) + olmo_labels = batch["olmo_labels"].to(self.device) + raw_texts = batch["raw_text"] + + vocab_size = self.olmo.config.vocab_size + + # Eval soft + A_soft = self.predictor(raw_texts, tau=tau, mode="eval_soft") + logits_soft = self.olmo_wrapper(olmo_ids, A_soft) + nll_soft = F.cross_entropy( + logits_soft[:, :-1].contiguous().view(-1, vocab_size), + olmo_labels[:, 1:].contiguous().view(-1), + ) + nll_soft_total += nll_soft.item() + + # Eval hard + A_hard = self.predictor(raw_texts, tau=tau, mode="eval_hard") + logits_hard = self.olmo_wrapper(olmo_ids, A_hard) + nll_hard = F.cross_entropy( + logits_hard[:, :-1].contiguous().view(-1, vocab_size), + olmo_labels[:, 1:].contiguous().view(-1), + ) + nll_hard_total += nll_hard.item() + + # Baseline (A=1) + A_ones = create_all_ones_A(olmo_ids.shape[0]).to(self.device) + logits_base = self.olmo_wrapper(olmo_ids, A_ones) + nll_base = F.cross_entropy( + logits_base[:, :-1].contiguous().view(-1, vocab_size), + olmo_labels[:, 1:].contiguous().view(-1), + ) + nll_baseline_total += nll_base.item() + + # Topology metrics (from soft A) + topo = compute_topology_metrics(A_soft) + for k, v in topo.items(): + topology_metrics_accum[k] = topology_metrics_accum.get(k, 0.0) + v + + n_batches += 1 + + # Average + metrics = { + "eval/nll_soft": nll_soft_total / n_batches, + "eval/nll_hard": nll_hard_total / n_batches, + "eval/nll_baseline": nll_baseline_total / n_batches, + } + for k, v in topology_metrics_accum.items(): + metrics[k] = v / n_batches + + log_metrics(metrics, self.global_step, self.wandb_run) + + # Track best + eval_nll = metrics["eval/nll_soft"] + if eval_nll < self.best_eval_nll: + self.best_eval_nll = eval_nll + print(f" New best eval NLL: {eval_nll:.4f}") + + self.predictor.train() diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/utils/logging.py b/src/utils/logging.py new file mode 100644 index 0000000..a9b3b10 --- /dev/null +++ b/src/utils/logging.py @@ -0,0 +1,63 @@ +"""Wandb integration for DAGFormer training. + +Logs all metrics from CLAUDE.md §7. +""" + +from __future__ import annotations + +from typing import Any, Optional + +import wandb + + +def init_wandb( + project: str, + run_name: str, + config: dict[str, Any], +) -> Optional[wandb.sdk.wandb_run.Run]: + """Initialize wandb run. + + Args: + project: wandb project name + run_name: run display name + config: full config dict to log + + Returns: + wandb run object (or None if wandb init fails) + """ + try: + run = wandb.init( + project=project, + name=run_name, + config=config, + ) + return run + except Exception as e: + print(f"WARNING: wandb init failed: {e}. Continuing without wandb.") + return None + + +def log_metrics( + metrics: dict[str, float], + step: int, + run: Optional[wandb.sdk.wandb_run.Run] = None, +) -> None: + """Log metrics to wandb. 
+ + Args: + metrics: dict of metric_name → value + step: global training step + run: wandb run (if None, prints to stdout instead) + """ + if run is not None: + wandb.log(metrics, step=step) + else: + # Fallback: print metrics + parts = [f"{k}={v:.4f}" if isinstance(v, float) else f"{k}={v}" for k, v in metrics.items()] + print(f"[step {step}] {', '.join(parts)}") + + +def finish_wandb(run: Optional[wandb.sdk.wandb_run.Run] = None) -> None: + """Finish wandb run.""" + if run is not None: + wandb.finish() diff --git a/src/utils/topology.py b/src/utils/topology.py new file mode 100644 index 0000000..3a2e0a8 --- /dev/null +++ b/src/utils/topology.py @@ -0,0 +1,87 @@ +"""A matrix analysis utilities for logging and monitoring. + +Computes topology metrics per CLAUDE.md §7. +""" + +from __future__ import annotations + +import torch + +from src.model.olmo_graph import create_block_upper_triangular_mask + + +def compute_topology_metrics(A: torch.Tensor, num_heads: int = 16) -> dict[str, float]: + """Compute topology metrics from adjacency matrix A. + + Args: + A: [batch, 256, 256] — gate matrix + num_heads: heads per layer (16 for OLMo2-1B) + + Returns: + dict with metrics: mean_A, seq_gate_frac, hyp_gate_frac, jaccard_var + """ + batch = A.shape[0] + num_nodes = A.shape[1] + mask = create_block_upper_triangular_mask(num_nodes, num_heads).to(A.device) + + # Mean A over valid entries + valid_entries = A[:, mask.bool()] # [batch, 30720] + mean_A = valid_entries.mean().item() + + # Classify connections: adjacent (layer diff == 1) vs skip (layer diff > 1) + layer_idx = torch.arange(num_nodes, device=A.device) // num_heads + layer_diff = layer_idx.unsqueeze(0) - layer_idx.unsqueeze(1) # [256, 256] + # layer_diff[i,j] = layer(j) - layer(i) + + adj_mask = (layer_diff == 1) & mask.bool() # adjacent-layer connections + skip_mask = (layer_diff > 1) & mask.bool() # skip connections + + # Fraction of gates > 0.5 + adj_vals = A[:, adj_mask] # [batch, 3840] + skip_vals = A[:, skip_mask] # [batch, 26880] + + seq_gate_frac = (adj_vals > 0.5).float().mean().item() if adj_vals.numel() > 0 else 0.0 + hyp_gate_frac = (skip_vals > 0.5).float().mean().item() if skip_vals.numel() > 0 else 0.0 + + # Jaccard variance across batch + jaccard_var = _jaccard_variance(A, mask).item() if batch > 1 else 0.0 + + return { + "topology/mean_A": mean_A, + "topology/seq_gate_frac": seq_gate_frac, + "topology/hyp_gate_frac": hyp_gate_frac, + "topology/jaccard_var": jaccard_var, + } + + +def _jaccard_variance(A: torch.Tensor, mask: torch.Tensor) -> torch.Tensor: + """Compute variance of pairwise Jaccard similarity across batch. + + Measures how context-dependent the topologies are. + Higher variance = more context-dependent routing. + """ + batch = A.shape[0] + if batch < 2: + return torch.tensor(0.0) + + # Binarize at 0.5 threshold for Jaccard + binary = (A > 0.5).float() + valid = mask.bool() + + # Extract valid entries: [batch, num_valid] + entries = binary[:, valid] + + # Pairwise Jaccard + jaccards = [] + for i in range(batch): + for j in range(i + 1, batch): + intersection = (entries[i] * entries[j]).sum() + union = ((entries[i] + entries[j]) > 0).float().sum() + jaccard = intersection / union.clamp(min=1.0) + jaccards.append(jaccard) + + if not jaccards: + return torch.tensor(0.0) + + jaccards = torch.stack(jaccards) + return jaccards.var() -- cgit v1.2.3