From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 11:00:39 -0600 Subject: Initial implementation: DAGFormer Phase 1 - olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 --- CLAUDE.md | 1381 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 1381 insertions(+) create mode 100644 CLAUDE.md (limited to 'CLAUDE.md') diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..1d83dec --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,1381 @@ +# CLAUDE.md — DAGFormer Project Specification + +**Read this file in full before writing any code.** This document is the +single source of truth for all design decisions. If something contradicts +the README or any other file, this document wins. + +--- + +## 1. What Is This Project? + +**DAGFormer** trains a small neural network (the "structure predictor") to +predict, for each token, the optimal wiring diagram (a DAG) of a frozen +1B-parameter language model (OLMo2-1B). The predicted DAG controls which +attention heads talk to which other heads. The entire system is trained +end-to-end with language modeling loss — no labeled topology data needed. + +### Why? + +Standard transformers have a fixed sequential computation graph: layer 0 → +layer 1 → ... → layer 15. Every input sees the same wiring. We showed via +expensive oracle search that **context-dependent topologies** can reduce +next-token prediction loss (NLL) from 2.58 to 0.12 (median across 50 +evaluation windows, 100% of windows improved). The oracle search is too +slow to use at scale (500 gradient steps per window), so we need a learned +predictor that produces the topology in a single forward pass. + +### What is NOT this project? + +- This is NOT the oracle search codebase (that exists separately) +- This is NOT a Mixture-of-Experts project (despite the repo history) +- This does NOT modify the OLMo2-1B weights in Phase 1 +- This does NOT implement Phase 2 (joint training) yet — only the + infrastructure to support it later + +--- + +## 2. Architecture — Exact Specification + +### 2.1 The Computation Graph of OLMo2-1B + +OLMo2-1B (HuggingFace ID: `allenai/OLMo-2-0425-1B`) has: + +- **16 transformer layers**, each with **16 attention heads** +- This gives **256 "nodes"** total: node `(l, h)` = layer `l`, head `h` +- We flatten to a single index: `node_id = l * 16 + h` (0-indexed) + +**Standard forward pass** — each layer does: +```python +# Input: residual (shared across all heads in this layer) +normed = RMSNorm(residual) +attn_out = self_attn(normed) # all 16 heads compute in parallel, + # outputs concatenated and projected by o_proj +residual = residual + attn_out # attention residual connection +normed2 = RMSNorm(residual) +mlp_out = MLP(normed2) +residual = residual + mlp_out # MLP residual connection +``` + +The **residual stream** at the start of layer `l` is therefore: +``` +residual_l = embedding + Σ_{l' layer(i) # i.e. 
j//16 > i//16 +mask[i, j] = 0 if layer(j) <= layer(i) # same layer or backward + +# WRONG: do NOT use torch.triu() — it would allow same-layer connections +# e.g. triu would set mask[0,15]=1, but both are in layer 0 +``` + +Heads in the same layer execute in parallel and cannot see each other's +outputs. Only cross-layer forward connections are meaningful. + +#### Connection Count + +| Type | Definition | Count | Role | +|------|-----------|-------|------| +| **Adjacent-layer** | `layer(j) = layer(i) + 1`, all head pairs | 15 × 16 × 16 = **3,840** | These exist in standard transformer (via shared residual). When gated to 1, behavior matches baseline. | +| **Skip** | `layer(j) > layer(i) + 1`, all head pairs | 105 × 16 × 16 = **26,880** | These do NOT exist in standard transformer. They are additional direct routes that bypass intermediate layers. | +| **Total** | All entries where `layer(j) > layer(i)` | **30,720** | | + +For logging and analysis, label connections as "adjacent" or "skip", but +the forward pass treats all 30,720 entries identically. + +> **Note on "31K" count**: The oracle search reported "256 sequential + +> 30,720 hyperconnection ≈ 31K" using a different parameterization +> (separate per-head activity gates + routing gates). In our unified +> 256×256 matrix, there are exactly 30,720 free entries. Both represent +> the same underlying structure. + +#### What "head output" means (resolves shape ambiguity) + +In HuggingFace OLMo2, the `o_proj` layer concatenates all 16 heads and +projects back to model_dim: + +```python +# Inside self_attn: +# Each head computes: attn_weights @ V → [batch, seq, head_dim] (head_dim = 128) +# All heads concatenated: [batch, seq, 16 * 128] = [batch, seq, 2048] +# Then: o_proj([batch, seq, 2048]) → [batch, seq, 2048] +``` + +The `o_proj` weight matrix `W_o ∈ R^{2048 × 2048}` can be viewed as 16 +column blocks: `W_o = [W_o^0 | W_o^1 | ... | W_o^15]`, each `W_o^h ∈ +R^{2048 × 128}`. The full attention output is: + +``` +attn_output = Σ_h W_o^h @ head_value_h = Σ_h head_output[l, h] +``` + +So **`head_output[l, h]` is `W_o^h @ head_value_h`, shape `[batch, seq, +model_dim]` (= 2048)**. This is each head's contribution to the attention +output in MODEL DIM space. Heads can be summed directly because they're +already in the same 2048-dim space. + +#### Per-Head Input Assembly (the core modification) + +In DAGFormer, each head j = (l_j, h_j) has its **own input**, assembled +from three sources: + +```python +input_j = embedding # (1) always present + + Σ_{l' < l_j} mlp_output[l'] # (2) all prior MLP outputs, NOT gated + + Σ_{i: layer(i) < l_j} A[i,j] * head_output[i] # (3) gated head outputs +``` + +**Source (1) — Token embedding**: The output of the token embedding layer +(before any transformer layer). Always included for every head. This is +NOT part of A. Think of it as the "seed" of the computation. + +**Source (2) — MLP outputs**: Each layer's MLP computes on a shared +aggregated state (see below) and its output is added to ALL downstream +head inputs equally. MLP outputs are **never gated by A**. This keeps +the MLP's contribution identical to the standard forward pass. + +**Source (3) — Gated head outputs**: The only part controlled by A. +Each entry A[i,j] scales how much of head i's output reaches head j's +input. Different heads h within the same layer l_j can receive different +weighted combinations of prior head outputs — this is what makes the +256×256 per-head routing meaningful. 
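+
+For illustration, a minimal sketch (not part of the required implementation)
+that builds the layer-based mask and checks the connection counts above;
+`build_cross_layer_mask` is a hypothetical helper name:
+
+```python
+import torch
+
+def build_cross_layer_mask(num_layers: int = 16, num_heads: int = 16) -> torch.Tensor:
+    """mask[i, j] = 1 iff layer(j) > layer(i). Never use torch.triu for this."""
+    n = num_layers * num_heads                      # 256 nodes
+    layer_of = torch.arange(n) // num_heads         # node_id -> layer index
+    return (layer_of[None, :] > layer_of[:, None]).float()  # [256, 256]
+
+mask = build_cross_layer_mask()
+layer_of = torch.arange(256) // 16
+adjacent = (layer_of[None, :] == layer_of[:, None] + 1).float()
+assert int(mask.sum()) == 30_720                    # total free entries
+assert int((mask * adjacent).sum()) == 3_840        # adjacent-layer entries
+assert int((mask * (1 - adjacent)).sum()) == 26_880 # skip entries
+assert int(mask[0, 15]) == 0                        # same-layer pair (the triu counter-example)
+```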
+ +#### Baseline Reproduction Proof (resolves the "double-counting" concern) + +When A[i,j] = 1 for ALL 30,720 valid entries: + +``` +input_j (where j is in layer l) + = embedding + Σ_{l' i//16 │ + │ (layer(j) > layer(i)) │ + │ │ + │ ⚠ Do NOT use torch.triu() — │ + │ it allows same-layer connections│ + │ │ + │ Z_masked = Z * mask + (-1e9) * (1 - mask) │ + │ (force invalid positions to │ + │ -inf so sigmoid → 0) │ + └─────────────────────────────────┘ + │ + ▼ + ┌─────────────────────────────────┐ + │ Gumbel-Sigmoid (3 modes) │ + │ │ + │ MODE 1 — TRAINING: │ + │ G ~ Logistic(0, 1) │ + │ (= log(U) - log(1-U), │ + │ U ~ Uniform(0,1)) │ + │ A = σ((Z_masked + G) / τ) │ + │ Gumbel noise G adds stochastic│ + │ exploration. Gradients flow │ + │ through σ naturally (this is │ + │ continuous relaxation, NOT │ + │ STE — no straight-through │ + │ estimator needed here). │ + │ │ + │ MODE 2 — EVAL SOFT: │ + │ A = σ(Z_masked / τ) │ + │ NO Gumbel noise. Deterministic│ + │ soft gates. Values in (0,1). │ + │ Used for eval/nll_soft metric.│ + │ │ + │ MODE 3 — EVAL HARD: │ + │ A = (Z_masked > 0).float() │ + │ NO Gumbel noise. Binary 0/1. │ + │ Threshold at logit=0 (= prob │ + │ 0.5). Used for eval/nll_hard │ + │ and for final inference. │ + │ │ + │ τ = temperature, annealed │ + │ during training (see §3) │ + └─────────────────────────────────┘ + +**WHY Gumbel-Sigmoid — the gradient problem and its solution:** + +At inference we want binary A ∈ {0, 1} (discrete routing decisions). +But discrete decisions have zero gradient — you can't backprop through +`(Z > 0).float()`. The predictor MLP would receive no learning signal. + +Gumbel-Sigmoid is a **continuous relaxation** (also called the "concrete +distribution" / "surrogate gradient" technique): + +``` +Training: A = σ((Z + G) / τ) ← continuous, differentiable + ∂A/∂Z = A(1-A)/τ ← well-defined gradient + backprop: ∂Loss/∂Z = ∂Loss/∂A · ∂A/∂Z + +Inference: A = (Z > 0).float() ← discrete, but no gradient needed +``` + +The complete gradient chain during training: +``` +Loss (NLL from OLMo) + → ∂Loss/∂logits (OLMo's output layer) + → ∂logits/∂head_inputs (through OLMo's frozen layers — computed + but NOT used to update OLMo weights) + → ∂head_inputs/∂A (the gate multiplication: input_j = Σ A[i,j] * head_out[i]) + → ∂A/∂Z (sigmoid derivative: A(1-A)/τ — THIS is the surrogate) + → ∂Z/∂(U,V) (low-rank matmul: Z = UV^T) + → ∂(U,V)/∂MLP_params (the trainable predictor MLP) + → optimizer updates MLP_params +``` + +Key points: +- OLMo is frozen but its forward computation is ON the gradient tape + (it's a differentiable function of A). Gradients flow THROUGH OLMo + back to A, even though OLMo's own parameters don't get updated. +- The sigmoid σ acts as a differentiable surrogate for the discrete + threshold. As τ → 0, σ((Z+G)/τ) approaches a step function, but + during training τ is always > 0 so gradients are always nonzero. +- Gumbel noise G provides stochastic exploration: even if Z is small, + the noise can push A above or below 0.5, letting the model discover + which connections matter. +- NO straight-through estimator (STE) is needed. The sigmoid itself + provides the gradient. STE would be needed if we hard-thresholded + during training, but we don't. + │ + ▼ + ┌─────────────────────────────────┐ + │ Cascading Gate │ + │ │ + │ Purpose: if node j has no │ + │ incoming edges, it has no info │ + │ to propagate, so kill outgoing. │ + │ │ + │ ONE-PASS computation (not │ + │ sequential by layer): │ + │ │ + │ 1. 
Compute ALL incoming sums: │ + │ inc_j = Σ_i A[i, j] ∀j │ + │ 2. Compute ALL gates: │ + │ g_j = σ(k * inc_j) ∀j │ + │ 3. Apply ALL gates at once: │ + │ A[j, :] *= g_j ∀j │ + │ │ + │ This is a single vectorized op, │ + │ NOT a layer-by-layer cascade. │ + │ All incoming sums use the │ + │ ORIGINAL A values (before any │ + │ gates are applied). │ + │ │ + │ k = 5.0 (fixed scalar) │ + │ k can be made learnable later. │ + │ This is fully differentiable. │ + │ │ + │ EVAL HARD MODE: After hard- │ + │ thresholding A to binary 0/1, │ + │ the cascading gate is also │ + │ hard: g_j = 1 if inc_j > 0, │ + │ else g_j = 0. (Because σ(5*0) │ + │ = 0.5 would be wrong for │ + │ binary gates — a disconnected │ + │ node should be fully killed, │ + │ not half-alive.) │ + │ │ + │ SOFT MODE NOTE: In training and │ + │ eval-soft mode, the cascading │ + │ gate uses σ(k * inc_j) with the│ + │ ORIGINAL (pre-gate) A values. │ + │ If all incoming gates are small │ + │ (e.g. 0.01 each), inc_j can │ + │ still be > 0, giving g_j > 0.5.│ + │ This is INTENTIONAL: in soft │ + │ mode, "weakly connected" is a │ + │ valid state (the gradient can │ + │ still flow). The cascading gate │ + │ is a soft penalty, not a hard │ + │ kill, during training. │ + └─────────────────────────────────┘ + │ + ▼ + Output: A ∈ [0,1]^{batch × 256 × 256}, block-upper-triangular + (30,720 free entries, rest forced to 0) +``` + +### 2.4 End-to-End Pipeline + +``` +raw text ──→ [Qwen tokenizer] ──→ qwen_ids ──→ [Qwen encoder] ──→ e ──→ [Predictor MLP] ──→ A + │ │ + │ ▼ + └──→ [OLMo tokenizer] ──→ olmo_ids ──→ [OLMo2-1B modified forward with A] ──→ logits ──→ NLL + │ + ∇ backprop to predictor MLP +``` + +**Tokenization**: Qwen and OLMo have DIFFERENT tokenizers and vocabularies. +The dataloader produces raw text strings. Each model tokenizes independently: +```python +# In the dataloader or pipeline: +raw_text = batch["text"] # list of strings +qwen_ids = qwen_tokenizer(raw_text, ...)["input_ids"] # Qwen's token IDs +olmo_ids = olmo_tokenizer(raw_text, ...)["input_ids"] # OLMo's token IDs +# These are DIFFERENT tensors with DIFFERENT lengths and vocabularies. +# Qwen produces a pooled embedding (one vector per sequence). +# OLMo produces per-token logits for NLL computation. +``` + +- Qwen and OLMo are FROZEN (Phase 1). Only the MLP decoder trains. +- The same **raw text** goes to both Qwen and OLMo, but they use + **separate tokenizers**. Qwen tokenizes independently to produce an + embedding; OLMo tokenizes independently for language modeling. + Token IDs are NOT shared between the two models. +- Qwen's output (a single pooled vector per sequence) goes to the + predictor MLP → A matrix. OLMo uses A in its modified forward pass. +- Loss = NLL (from OLMo) + λ · mean(A) (sparsity regularization) + +--- + +## 3. 
Training — Exact Specification + +### 3.1 Phase 1: Learn Topology (IMPLEMENT THIS) + +| What | Frozen/Trainable | +|------|-----------------| +| OLMo2-1B (`allenai/OLMo-2-0425-1B`) | ❄ FROZEN, no grad | +| Qwen-3-Embedding (`Qwen/Qwen3-Embedding-0.6B`) | ❄ FROZEN, no grad | +| Structure Predictor (MLP decoder) | 🔥 TRAINABLE | + +| Hyperparameter | Value | Notes | +|----------------|-------|-------| +| Dataset | Dolma v1.7 (`allenai/dolma`, `name="v1_7"`, streamed) | Specify version explicitly | +| Token budget | 5–10B | Configurable | +| Sequence length | 1024 | OLMo token count, matches oracle search | +| Batch size | 32 (start) | Reduce if OOM | +| Learning rate | 3e-4 | For predictor MLP only | +| Optimizer | AdamW | β1=0.9, β2=0.999, weight_decay=0.01 | +| LR schedule | Cosine decay to 0 | Standard | +| τ (temperature) init | 5.0 | | +| τ final | 0.2 | | +| τ schedule | Cosine annealing | τ(t) = τ_f + 0.5(τ_i - τ_f)(1 + cos(πt/T)) | +| Sparsity λ max | 0.01 | | +| Sparsity ramp | Linear 0 → λ_max over first 20% of steps | | +| Hardware | 4× A40 (48GB each) | Use DDP | + +**Loss function:** +```python +total_loss = nll_loss + lambda_t * A.mean() +``` +where `lambda_t` ramps linearly from 0 to `lambda_max` over the first 20% +of training steps. + +**τ schedule (exact formula):** +```python +tau_t = tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * step / total_steps)) +``` + +### 3.1.1 Data Processing Pipeline + +**Dolma version**: Use `allenai/dolma` with `name="v1_7"`. If v1_7 is not +available on HuggingFace, fall back to whatever default version loads. +Verify by printing the dataset info at startup and logging to wandb. + +**Sequence packing** (critical for training efficiency): + +Dolma contains documents of varying lengths. Do NOT pad short documents +or discard them. Instead, **pack multiple documents into fixed-length +sequences**: + +```python +# Pseudocode for sequence packing: +buffer = [] +for doc in dolma_stream: + olmo_tokens = olmo_tokenizer(doc["text"], add_special_tokens=False)["input_ids"] + buffer.extend(olmo_tokens) + buffer.append(olmo_tokenizer.eos_token_id) # document separator + + while len(buffer) >= seq_len + 1: # +1 for labels (next-token prediction) + chunk = buffer[:seq_len + 1] + buffer = buffer[seq_len + 1:] + yield { + "olmo_ids": chunk[:seq_len], # input + "olmo_labels": chunk[1:seq_len+1], # shifted target + "raw_text": olmo_tokenizer.decode(chunk[:seq_len]) # for Qwen + } +``` + +This means: +- No padding ever. Every token contributes to NLL. +- A single training sequence may contain parts of multiple documents, + separated by EOS tokens. The causal mask handles this correctly (each + token can only attend to prior tokens, so cross-document leakage is + minimal and standard practice). +- `raw_text` is decoded from OLMo tokens to feed to Qwen. This is a + slight approximation (Qwen re-tokenizes the decoded text), but Qwen + only needs to understand the gist of the context to produce a good + embedding — exact token alignment doesn't matter. +- **Qwen sees the entire packed sequence** (which may contain multiple + document fragments separated by EOS). This is intentional: the + predictor should condition on exactly what OLMo will process. Qwen's + mean-pooling produces a single summary vector of the full 1024-token + window, which is the right granularity — one A matrix per window. 
+- Qwen's `seq_len` will differ from OLMo's 1024 (different tokenizer + granularity), but this is fine because Qwen's output is mean-pooled + to a single vector regardless of input length. + +**Qwen input format**: Qwen3-Embedding-0.6B is an embedding model. +**Decision**: Use raw text directly (no prefix). Rationale: our input is +a packed sequence of document fragments, not a search query or passage +in the retrieval sense. Prefixes like "query:" or "passage:" are designed +for retrieval tasks and would be semantically misleading here. The Qwen +encoder just needs a general-purpose text representation. + +Log `qwen_input_prefix: ""` to wandb config for reproducibility. If +future experiments show that a prefix helps, it can be changed via the +`qwen_input_prefix` config field. + +### 3.2 Evaluation Data + +Use a **fixed held-out subset of Dolma** for eval. Implementation: + +```python +# At startup, skip the first N examples to get to a held-out region +eval_dataset = load_dataset("allenai/dolma", name="v1_7", split="train", streaming=True) +eval_dataset = eval_dataset.skip(1_000_000) # skip 1M examples +eval_dataset = eval_dataset.take(1_000) # use next 1K as eval set + +# Pack these into fixed-length sequences using the same packing logic +# as training (§3.1.1), then cache in memory at startup +eval_batches = list(eval_dataloader) # ~1K sequences, fits in RAM +``` + +**Eval caching**: The `skip(1_000_000)` on a streaming dataset is slow +(~minutes). To avoid this cost on every restart, **cache the packed eval +sequences to disk** after the first run: +```python +eval_cache_path = os.path.join(config.save_dir, "eval_cache.pt") +if os.path.exists(eval_cache_path): + eval_batches = torch.load(eval_cache_path) +else: + eval_batches = list(build_eval_dataloader(...)) + torch.save(eval_batches, eval_cache_path) +``` + +**Multi-GPU eval**: Run eval on **rank 0 only**. Other ranks skip eval +and wait at a barrier. This avoids redundant computation and ensures +consistent eval metrics (no need to reduce across GPUs). + +Eval runs in **two modes** (both deterministic, no Gumbel noise): +- **Soft**: A = σ(Z / τ) at current temperature. Reports `eval/nll_soft`. +- **Hard**: A = (Z > 0).float(), binary 0/1. Reports `eval/nll_hard`. +Also report `eval/nll_baseline` with A = all-ones (should be constant). + +### 3.3 Multi-GPU Strategy + +**Standard DDP — each GPU holds a full replica of all three models.** + +Memory budget per GPU (fp16/bf16): +``` +OLMo2-1B parameters: ~2.0 GB +Qwen-0.6B parameters: ~1.2 GB +Predictor MLP: ~0.05 GB +Optimizer states: ~0.1 GB (only predictor params) +───────────────────────────── +Model total: ~3.4 GB + +Activations (seq_len=1024, batch=8): + OLMo per-head forward: ~12-16 GB (per-head inputs means 16 separate + attention computations per layer instead of 1 batched MHA — this + increases intermediate activation storage significantly vs standard) + Qwen forward: ~3 GB + Per-head A storage: ~0.5 GB (256×256 × batch × float32) +───────────────────────────── +Activations total: ~16-20 GB + +Grand total: ~20-24 GB per GPU (with batch=8) +Headroom on 48GB A40: ~24-28 GB +``` + +This fits but is tighter than a standard OLMo forward pass. If OOM, +reduce batch_size to 4 first (halves activation memory). 
If still OOM, +use **gradient accumulation** to maintain effective batch size: +```python +# Example: effective_batch=8, micro_batch=2, accumulation_steps=4 +accumulation_steps = config.batch_size // config.micro_batch_size +for micro_step in range(accumulation_steps): + loss = forward(micro_batch) / accumulation_steps + loss.backward() +optimizer.step() +optimizer.zero_grad() +``` +Add `micro_batch_size` to config (default: same as `batch_size`, i.e. no +accumulation). The per-head computation path creates ~2-3x more +intermediates than standard batched MHA because we cannot fuse across +heads when each has a different input. + +**Gradient checkpointing**: Since OLMo is frozen (no gradients computed +for its parameters), we do NOT need to store OLMo's intermediate +activations for backprop through OLMo's weights. However, we DO need +gradients to flow through OLMo's forward pass back to A (and then to +the predictor). This means OLMo's forward activations are needed for +the chain rule through A's gate multiplications, but NOT for OLMo's +own parameter gradients. + +If memory is still tight, apply `torch.utils.checkpoint` to the OLMo +layer loop — recompute each layer's forward pass during backward instead +of storing all intermediates. This trades compute for memory. +Gradient checkpointing is optional; try without it first. + +**Storing head_outputs**: The forward pass accumulates `head_outputs` for +all 256 nodes (each `[batch, seq, 2048]`). At fp16 with batch=8, +seq=1024: 256 × 8 × 1024 × 2048 × 2 bytes ≈ 8.6 GB. This is the +dominant memory cost. Gradient checkpointing can reduce this by +recomputing head outputs layer-by-layer during backward. + +Use `DistributedDataParallel` wrapping ONLY on the predictor MLP (the +only trainable component). The frozen Qwen and OLMo models do not need +DDP wrapping — just load them on each GPU. + +```python +# DDP setup +predictor = StructurePredictor(config).to(device) +predictor = DDP(predictor, device_ids=[local_rank]) + +# Frozen models — no DDP needed +olmo = AutoModelForCausalLM.from_pretrained(...).to(device).eval() +qwen = AutoModel.from_pretrained(...).to(device).eval() + +# Gradient disabled for frozen models +for p in olmo.parameters(): p.requires_grad_(False) +for p in qwen.parameters(): p.requires_grad_(False) +``` + +Data parallelism: Dolma is a streaming `IterableDataset` — do NOT use +`DistributedSampler` (which requires map-style datasets). Instead, shard +the stream manually: +```python +dataset = load_dataset("allenai/dolma", name="v1_7", split="train", streaming=True) +dataset = dataset.shard(num_shards=world_size, index=rank) +``` +Each GPU processes a disjoint shard. Gradients are synchronized by DDP. + +### 3.4 o_proj Bias and Weight Layout + +OLMo2-1B uses **no bias** in its linear layers (`bias=False` for q_proj, +k_proj, v_proj, o_proj, and MLP layers). This is standard for modern LLMs. + +**Verify this at runtime:** +```python +assert not model.layers[0].self_attn.o_proj.bias, \ + "Expected no bias in o_proj — update per-head splitting if bias exists" +``` + +If a future model version adds bias, the per-head split must also split +the bias: `bias_h = bias[h * head_dim : (h+1) * head_dim]` for input +projections, or `bias / num_heads` for o_proj (since it's additive). +But for OLMo2-1B, this is not needed. 
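+
+One caveat on the assert above: `not tensor` raises for multi-element tensors,
+so an explicit `is None` comparison fails more cleanly. A minimal sketch
+(illustrative, not repo code), using the same attribute path as the snippet
+above; adjust with a `.model` prefix if the HF causal-LM wrapper nests the
+decoder layers:
+
+```python
+attn = model.layers[0].self_attn
+for name in ("q_proj", "k_proj", "v_proj", "o_proj"):
+    proj = getattr(attn, name)
+    assert proj.bias is None, f"{name} has a bias; per-head splitting must split it too"
+```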
+ +**o_proj weight layout** for per-head splitting: +```python +# o_proj.weight: [model_dim, model_dim] = [2048, 2048] +# This maps concatenated head outputs → model_dim +# +# Split by INPUT dimension (head_dim chunks): +# o_proj_h.weight = o_proj.weight[:, h*head_dim : (h+1)*head_dim] +# Shape: [2048, 128] +# +# head_output[l, h] = attn_values_h @ o_proj_h.weight.T +# Shape: [batch, seq, 128] @ [128, 2048] → [batch, seq, 2048] +``` + +### 3.5 Phase 2: Joint CPT (DO NOT IMPLEMENT — FUTURE WORK) + +In Phase 2, OLMo would be unfrozen and co-trained with the predictor using +differential learning rates (OLMo: 3e-5, predictor: 1e-4). The training +loop should be DESIGNED to support this (parameter groups with different +LRs) but the actual unfreezing logic should NOT be implemented yet. + +--- + +## 4. OLMo2-1B Modification — Implementation Guide + +This is the hardest part. The goal: intercept OLMo2-1B's forward pass to +apply the adjacency matrix A without forking the model code. + +### 4.1 Strategy: Hook-Based or Subclass + +**Option A (preferred): Monkey-patch the attention forward.** +- Load the model normally via HuggingFace +- Replace each layer's attention module's forward method with a wrapped + version that accepts and applies A +- Pro: no code duplication, stays in sync with HF updates +- Con: fragile if HF changes internal API + +**Option B: Subclass the model.** +- Subclass `OlmoForCausalLM` and override the relevant methods +- Pro: cleaner, more explicit +- Con: more code to maintain + +Choose whichever is cleaner after inspecting the actual OLMo2 source. +The key requirement is: **when A is not provided or is all-ones, the +output must be IDENTICAL to vanilla OLMo2-1B.** + +### 4.2 What Exactly to Modify + +**Reference pseudocode — this is the exact semantics to implement:** + +```python +def dagformer_forward(model, olmo_ids, A, input_norm="none"): + """ + Args: + model: OLMo2-1B loaded from HuggingFace + olmo_ids: [batch, seq_len] — tokenized by OLMo's tokenizer + A: [batch, 256, 256] — block-upper-triangular gate matrix + (produced by the predictor from Qwen embeddings — but this + function doesn't know about Qwen, it just takes A as input) + input_norm: normalization method for assembled head inputs + Returns: + logits: [batch, seq_len, vocab_size] + """ + batch, seq = olmo_ids.shape + model_dim = 2048 + num_layers = 16 + num_heads = 16 + + # Token embedding (always included, not gated) + embedding = model.embed_tokens(olmo_ids) # [batch, seq, 2048] + + # Storage for per-head outputs (in model_dim space) + head_outputs = {} # (layer, head) -> [batch, seq, model_dim] + mlp_outputs = {} # layer -> [batch, seq, model_dim] + + for l in range(num_layers): + layer_module = model.layers[l] + + # === ASSEMBLE PER-HEAD INPUTS === + # Base: embedding + all prior MLP outputs (shared, ungated) + base_l = embedding.clone() + for prev_l in range(l): + base_l = base_l + mlp_outputs[prev_l] + + # Per-head: add gated head outputs from earlier layers + # This is the KEY difference from standard transformer: + # each head h in layer l gets its OWN input. 
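+        # Added clarifying comments (not part of the original pseudocode):
+        # the node index of head h in this layer is j = l * num_heads + h, so the
+        # gates feeding layer l live in the column block A[:, :, l*16:(l+1)*16];
+        # the batched implementation shown after this pseudocode slices that block.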
+ per_head_inputs = [] + for h in range(num_heads): + j = l * num_heads + h + assembled = base_l.clone() # [batch, seq, model_dim] + + # Gated sum of all prior head outputs + gated_sum = torch.zeros_like(base_l) + for src_l in range(l): + for src_h in range(num_heads): + i = src_l * num_heads + src_h + gate = A[:, i, j] # [batch] + gated_sum = gated_sum + gate[:, None, None] * head_outputs[(src_l, src_h)] + + # Apply normalization ONLY to gated_sum, then add to base (see §6.1) + assembled = base_l + apply_norm(gated_sum, method=input_norm) + per_head_inputs.append(assembled) + + # === RUN ATTENTION WITH PER-HEAD INPUTS === + # This requires splitting the attention computation: + # 1. Each head h gets its own Q from per_head_inputs[h] + # 2. Each head h gets its own K, V from per_head_inputs[h] + # 3. Each head computes attention independently + # 4. Each head's output is projected by its slice of o_proj + # + # In practice: need to split q_proj, k_proj, v_proj into per-head + # projections, run attention per-head, then apply per-head o_proj. + # + # Efficient batch implementation: + # Stack per_head_inputs → [batch, num_heads, seq, model_dim] + # Apply qkv projection in batch across heads + # Run attention, apply o_proj per-head + stacked = torch.stack(per_head_inputs, dim=1) # [batch, 16, seq, 2048] + normed = layer_module.input_layernorm(stacked) # RMSNorm per-head + # NOTE: RMSNorm(2048) operates on the last dimension. When applied to + # [batch, 16, seq, 2048], it normalizes each head's input independently. + # When A=1, all 16 heads have identical inputs, so RMSNorm produces + # identical outputs — matching the standard single-RMSNorm behavior. + # No numerical divergence occurs in the baseline case. + + for h in range(num_heads): + q = q_proj_head(normed[:, h], head=h) # [batch, seq, head_dim] + k = k_proj_head(normed[:, h], head=h) + v = v_proj_head(normed[:, h], head=h) + attn_out_h = attention(q, k, v) # [batch, seq, head_dim] + head_outputs[(l, h)] = o_proj_head(attn_out_h, head=h) # [batch, seq, model_dim] + + # === RUN MLP (ungated, standard) === + # MLP input = base_l + ungated sum of THIS layer's head outputs + attn_output_l = sum(head_outputs[(l, h)] for h in range(num_heads)) + mlp_input = base_l + attn_output_l # standard residual + mlp_outputs[l] = layer_module.mlp(layer_module.post_attention_layernorm(mlp_input)) + + # Final: assemble output state = embedding + all heads + all MLPs + # IMPORTANT: final output uses UNGATED sum of ALL head outputs. + # A only controls intermediate routing (which heads feed into which). + # The final output is NOT gated — every head's output contributes + # equally to the final state. This is a deliberate design choice: + # (1) it matches the standard residual stream when A=1, + # (2) it means A controls *how* heads compute (via their inputs), + # not *whether* their outputs are used in the final prediction. + final_state = embedding + for l in range(num_layers): + final_state = final_state + mlp_outputs[l] + for h in range(num_heads): + final_state = final_state + head_outputs[(l, h)] + + logits = model.lm_head(model.norm(final_state)) + return logits +``` + +⚠ **This pseudocode is for SEMANTIC CLARITY. 
The actual implementation +MUST use batched operations for performance:** + +```python +# Efficient version for the gated sum: +# Stack all prior head outputs into a tensor +prior_outputs = torch.stack([head_outputs[(l', h')] + for l' in range(l) for h' in range(16)], dim=1) +# prior_outputs: [batch, l*16, seq, model_dim] + +# Slice A for connections into this layer +A_slice = A[:, :l*16, l*16:(l+1)*16] # [batch, l*16, 16] + +# Batched gated sum: einsum or matmul +# per_head_gated[h] = sum_i A_slice[:, i, h] * prior_outputs[:, i] +per_head_gated = torch.einsum('bih,bisd->bhsd', A_slice, prior_outputs) +# per_head_gated: [batch, 16, seq, model_dim] + +# Apply normalization ONLY to per_head_gated, then add base (consistent with §6.1) +per_head_gated = apply_norm(per_head_gated, method=input_norm) +assembled = base_l.unsqueeze(1) + per_head_gated # [batch, 16, seq, model_dim] +``` + +**Splitting Q/K/V projections per-head**: OLMo2's `q_proj`, `k_proj`, +`v_proj` are single linear layers projecting from `model_dim` to +`num_heads * head_dim`. To apply different inputs per head: + +```python +# Option A: reshape the weight matrix and apply per-head +W_q = layer.self_attn.q_proj.weight # [2048, 2048] +W_q_heads = W_q.view(num_heads, head_dim, model_dim) # [16, 128, 2048] +# For each head h: q_h = assembled[:, h] @ W_q_heads[h].T + +# Option B: run the full projection on each head's input separately +# Less efficient but simpler + +# Option C (RECOMMENDED): run full projection on stacked inputs +# assembled: [batch, 16, seq, 2048] +# Reshape to [batch*16, seq, 2048], run q_proj, reshape back +``` + +**RoPE (Rotary Position Embeddings)**: OLMo2 applies RoPE to Q and K +AFTER projection, BEFORE the attention dot product. This is critical for +per-head computation: + +```python +# Standard OLMo2 attention (simplified): +q = q_proj(hidden_states) # [batch, seq, num_heads * head_dim] +k = k_proj(hidden_states) # [batch, seq, num_kv_heads * head_dim] +q, k = apply_rotary_emb(q, k, cos, sin, position_ids) +# Then attention(q, k, v) + +# DAGFormer per-head version: +# For each head h with its own input: +q_h = q_proj_head(assembled_h, head=h) # [batch, seq, head_dim] +k_h = k_proj_head(assembled_h, head=h) # [batch, seq, head_dim] +q_h, k_h = apply_rotary_emb(q_h, k_h, cos, sin, position_ids) # SAME cos/sin/position_ids for all heads +v_h = v_proj_head(assembled_h, head=h) # no RoPE on V +attn_out_h = attention(q_h, k_h, v_h) +``` + +The cos/sin cache and position_ids are **shared across all heads** (they +depend on sequence position, not on head identity). Extract them once from +the model's rotary embedding module at the start of each layer, then reuse +for all 16 heads. If the implementation fails to apply RoPE, the baseline +reproduction sanity check WILL fail because position information is lost. + +**OLMo2-1B uses standard MHA (NOT GQA)**: OLMo2-1B has +`num_attention_heads = 16` and `num_key_value_heads = 16` (every Q head +has its own K and V head). This means per-head splitting is straightforward +— no KV sharing complications. **Verify at runtime:** +```python +config = model.config +assert config.num_attention_heads == 16 +assert config.num_key_value_heads == 16, \ + f"Expected MHA (16 KV heads), got {config.num_key_value_heads} — GQA detected, update splitting logic" +``` +If a future model uses GQA, the per-head splitting must account for shared +KV heads. But OLMo2-1B does not require this. 
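+
+To make Option A concrete, here is a minimal sketch (illustrative, not repo
+code) that applies the reshaped Q weights to all 16 per-head inputs in a single
+einsum. `assembled_normed` stands for the `[batch, 16, seq, 2048]` tensor
+produced by the per-head RMSNorm in the pseudocode above; K and V follow the
+same pattern, with RoPE then applied to Q and K using the shared cos/sin:
+
+```python
+W_q = layer.self_attn.q_proj.weight                   # [2048, 2048]
+W_q_heads = W_q.view(num_heads, head_dim, model_dim)  # [16, 128, 2048]
+
+# q[b, h, s, :] = assembled_normed[b, h, s, :] @ W_q_heads[h].T
+q = torch.einsum("bhsd,hed->bhse", assembled_normed, W_q_heads)  # [batch, 16, seq, 128]
+```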
+ +**Causal attention mask**: The standard causal mask within each head's +self-attention (preventing token t from attending to tokens t+1, t+2, ...) +is UNCHANGED. DAGFormer's adjacency matrix A controls **cross-layer** +routing (which heads' outputs feed into which heads' inputs). The causal +mask controls **within-sequence** attention (which token positions can +attend to which). These are orthogonal — A operates at the head/layer +level, causal mask operates at the token position level. Use OLMo's +existing causal mask implementation as-is. + +### 4.3 Sanity Checks (MUST PASS before proceeding) + +1. **Baseline reproduction**: Set A[i,j]=1 for all 30,720 valid cross-layer + entries, `input_norm: "none"`. NLL must match vanilla OLMo2-1B within + 0.01 nats. This works because A=1 makes every head's input equal to the + standard residual stream (see §2.2 proof). If this fails, the gate + injection or per-head splitting logic is wrong. +2. **A = all-zeros** (all 30,720 entries = 0): Every head sees only + embedding + MLP outputs, no cross-layer attention routing. NLL should + be significantly higher than baseline. +3. **A = random in [0,1]**: NLL should be between the all-ones and all-zeros + cases. +4. **Gradient check**: Create A as a leaf tensor with `requires_grad=True`, + compute NLL, call `.backward()`, verify `A.grad` is not None and has + nonzero entries at all 30,720 valid positions. +5. **Normalization smoke test**: For each `input_norm` in {gate_mean, + rms_post, ln_post, rms_pre}, run forward with A=all-ones. NLL will NOT + match baseline (normalization changes the scale), but must be finite + (no NaN/Inf). This confirms the norm implementations don't crash. +6. **Per-head input divergence**: Set A to a matrix where different heads + in the same layer have different gate values. Verify that the per-head + inputs are actually different (not collapsed to the same tensor). + This confirms the per-head routing works. + +--- + +## 5. 
Directory Structure + +``` +dagformer/ +├── CLAUDE.md # THIS FILE — the complete spec +├── README.md # Brief public description +├── pyproject.toml # Dependencies +│ +├── configs/ +│ ├── sanity_check.yaml # 1K steps, verify baseline NLL reproduction +│ ├── ablation_rank.yaml # r ∈ {8, 16, 32, 64} +│ ├── ablation_tau.yaml # τ_init/τ_final sweep +│ ├── ablation_lambda.yaml # sparsity coefficient sweep +│ └── phase1_full.yaml # Full Phase 1 run +│ +├── src/ +│ ├── __init__.py +│ ├── model/ +│ │ ├── __init__.py +│ │ ├── olmo_graph.py # Modified OLMo2 forward with A injection +│ │ ├── predictor.py # Qwen encoder + MLP decoder + Gumbel + cascade +│ │ └── pipeline.py # Combines predictor + OLMo into single forward +│ ├── data/ +│ │ ├── __init__.py +│ │ └── dolma.py # Streaming dataloader: produces raw text, +│ │ # tokenizes with BOTH Qwen and OLMo tokenizers, +│ │ # returns {qwen_ids, olmo_ids, labels} +│ ├── training/ +│ │ ├── __init__.py +│ │ ├── trainer.py # Training loop (pure PyTorch + DDP) +│ │ ├── schedulers.py # τ annealing, λ ramp, LR schedule +│ │ └── checkpointing.py # Save/load predictor + optimizer + schedule state +│ └── utils/ +│ ├── __init__.py +│ ├── logging.py # Wandb integration +│ └── topology.py # A matrix analysis utilities +│ +├── scripts/ +│ ├── train.py # Entry: python scripts/train.py --config configs/X.yaml +│ ├── eval.py # Evaluate NLL with/without predictor +│ ├── sanity_check.py # Verify A=1 reproduces baseline +│ └── visualize_topology.py # Plot A matrices and gate distributions +│ +└── tests/ + ├── test_olmo_graph.py # Baseline reproduction test + ├── test_predictor.py # Shape and gradient tests + └── test_gumbel.py # Gumbel-Sigmoid correctness +``` + +**File responsibilities (be precise):** + +- `olmo_graph.py`: ONLY handles injecting A into OLMo's forward. Does NOT + know about the predictor, Qwen, or training. Exports a function or class + that takes `(model, olmo_ids, A)` and returns logits. + +- `predictor.py`: ONLY the structure predictor. Takes raw text (or + pre-tokenized Qwen IDs), returns A. Contains Qwen loading, Qwen + tokenizer, MLP, Gumbel-Sigmoid, cascading gate. Does NOT know about + OLMo or training. + +- `pipeline.py`: Glue. Takes raw text (or pre-tokenized batch dict with + both qwen_ids and olmo_ids), calls predictor to get A, calls modified + OLMo forward with A, returns loss. This is what the trainer calls. + +- `trainer.py`: Pure training loop. Loads config, builds pipeline, runs + forward/backward/step. Handles DDP, logging, checkpointing. No model + logic here. + +--- + +## 6. Implementation Order (FOLLOW THIS EXACTLY) + +### Step 0: Project Scaffolding +- Create directory structure above +- Write `pyproject.toml`: + ``` + dependencies: torch>=2.2, transformers>=4.40, datasets, wandb, pyyaml, einops + ``` +- Verify model loading: + ```python + from transformers import AutoModelForCausalLM, AutoModel, AutoTokenizer + olmo = AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B") + qwen = AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B") + ``` +- Verify Dolma streaming: + ```python + from datasets import load_dataset + ds = load_dataset("allenai/dolma", name="v1_7", split="train", streaming=True) + print(next(iter(ds))) # verify a document loads + ``` + +### Step 1: `olmo_graph.py` — Modified OLMo Forward +**This is the foundation. Get it right.** + +1. Load OLMo2-1B, inspect its architecture: + - Find the attention layer class name + - Find where head outputs are computed and merged + - Find the residual connection logic +2. 
Implement adjacency injection
+3. Run sanity checks (§4.3) — ALL SIX must pass
+4. Do not proceed to Step 2 until Step 1 passes all checks
+
+### Step 2: `predictor.py` — Structure Predictor
+1. Implement Qwen encoder wrapper (frozen, mean-pooled)
+2. Implement MLP decoder with low-rank heads
+3. Implement Gumbel-Sigmoid with 3 modes: (a) training: noise + τ,
+   (b) eval soft: σ(Z/τ) no noise, (c) eval hard: (Z>0).float() no noise
+4. Implement block-upper-triangular mask (based on layer index, NOT torch.triu)
+5. Implement cascading gate
+6. Test: output shape [batch, 256, 256], values in [0,1], block-upper-tri
+7. Test: `loss = A.sum(); loss.backward()` produces gradients in MLP params
+
+### Step 3: `pipeline.py` — End-to-End
+1. Wire predictor output into OLMo modified forward
+2. Verify full gradient chain: NLL.backward() updates predictor MLP
+3. Profile memory on single GPU (must fit seq_len=1024, batch=1 on 48GB)
+
+### Step 4: Training Infrastructure
+1. YAML config → Python dataclass (not raw dicts)
+2. `schedulers.py`: τ annealing, λ ramp, LR cosine decay
+3. `trainer.py`: training loop
+4. `logging.py`: Wandb metrics (see §7)
+5. `checkpointing.py`: save/load
+
+### Step 5: Sanity Check Training Run
+- Config: `sanity_check.yaml` — 1K steps, batch=4, high τ=5.0
+- Verify: loss decreases over 1K steps
+- Verify: A is not collapsing (not all-ones, not all-zeros)
+- Verify: gradient norms are reasonable
+- **STOP HERE if loss doesn't decrease.** Debug before proceeding.
+
+### Step 6: Ablations (later)
+- Rank r, τ schedule, sparsity λ, cascading gate on/off
+- **Input normalization** (see §6.1 below)
+
+### 6.1 Input Normalization Ablation (IMPORTANT)
+
+When head j receives gated inputs from multiple source heads across
+different layers, the magnitudes of these representations can differ
+significantly. Layer 0 outputs and layer 14 outputs live at different
+scales. The choice of how to normalize the aggregated input to each
+head is a critical design decision.
+
+**The problem** — referring to the per-head assembly from §2.2:
+```python
+gated_sum = Σ_{i: layer(i) < l} A[i,j] * head_output[i]
+# If 50 sources have A[i,j] > 0, gated_sum has much larger magnitude
+# than any single head_output. Scale depends on sparsity pattern of A.
+
+assembled = base_l + gated_sum  # base_l = embedding + prior MLPs
+# The gated_sum component can dwarf or be dwarfed by base_l.
+```
+
+**Normalization is applied ONLY to the gated_sum, before adding to base_l:**
+```python
+assembled = base_l + normalize(gated_sum, method=config.input_norm)
+```
+This way, base_l (which is the standard, ungated component) is preserved
+as-is, and only the novel gated routing is normalized.
+
+**Ablation candidates (implement all, sweep in configs):**
+
+| ID | Method | Formula | Learnable params | Rationale |
+|----|--------|---------|-----------------|-----------|
+| `none` | Raw weighted sum | `gated_sum` as-is | 0 | Baseline. When A=1 this reproduces vanilla OLMo. |
+| `gate_mean` | Divide by gate sum | `gated_sum / (Σ_i A[i,j] + ε)` | 0 | Normalizes for fan-in. ε=1e-8. |
+| `rms_post` | RMSNorm after sum | `RMSNorm(gated_sum)` | 2048 (one gain vector) | One shared `nn.RMSNorm(model_dim)` instance, applied to each head's gated_sum. |
+| `ln_post` | LayerNorm after sum | `LayerNorm(gated_sum)` | 2×2048 (gain + bias) | One shared `nn.LayerNorm(model_dim)` instance. Affine params are trainable and counted as predictor params. 
| +| `rms_pre` | RMSNorm each source before sum | `Σ_i A[i,j] * RMSNorm_i(head_output[i])` | 256 × 2048 = 524,288 | One `nn.RMSNorm(model_dim)` **per source node** (256 total). Each head gets its own learnable gain, allowing fine-grained per-head scale correction before mixing. | + +All learnable norm params (if any) are part of the predictor's parameter +group and trained with the same LR and optimizer as the MLP decoder. + +**Default:** Start with `none` (to verify baseline reproduction), then +switch to `gate_mean` for actual training. If that doesn't work, try +`rms_post`. + +**Implementation:** The normalization method is a config string +(`input_norm: "gate_mean"`). The `olmo_graph.py` code dispatches on this +config. All five methods must be implemented. + +**Config example for ablation:** +```yaml +# In configs/ablation_norm.yaml +sweep: + input_norm: ["none", "gate_mean", "rms_post", "ln_post", "rms_pre"] +``` + +--- + +## 7. Logging & Monitoring + +Log ALL of these to Wandb at every training step: + +| Metric | Formula / Source | Why | +|--------|-----------------|-----| +| `train/nll` | Cross-entropy loss | Primary objective | +| `train/sparsity_loss` | λ_t · mean(A) | Regularization term | +| `train/total_loss` | nll + sparsity_loss | What optimizer sees | +| `eval/nll_soft` | NLL with **deterministic** soft gates: A = σ(Z / τ), NO Gumbel noise | Smooth relaxation perf | +| `eval/nll_hard` | NLL with hard gates: A = (Z > 0).float(), NO Gumbel noise | Inference-mode perf | +| `eval/nll_baseline` | NLL with A = all-ones | Should be constant | +| `topology/mean_A` | mean(A) | Overall gate activation | +| `topology/seq_gate_frac` | Fraction of sequential gates > 0.5 | | +| `topology/hyp_gate_frac` | Fraction of hyperconnection gates > 0.5 | | +| `topology/jaccard_var` | Variance of pairwise Jaccard across batch | Context-dependency | +| `schedule/tau` | Current temperature | | +| `schedule/lambda` | Current sparsity coefficient | | +| `grad/predictor_norm` | Total L2 norm of predictor gradients | | + +**Collapse alarm:** If `topology/mean_A` < 0.01 or > 0.99 for 100 +consecutive steps, log a WARNING. The predictor has degenerated. + +--- + +## 8. Config Format + +Use YAML. Example (`configs/sanity_check.yaml`): + +```yaml +# Model +olmo_model_id: "allenai/OLMo-2-0425-1B" +qwen_model_id: "Qwen/Qwen3-Embedding-0.6B" + +# Predictor +predictor_hidden_dim: 1024 +predictor_rank: 32 +cascading_gate_k: 5.0 +input_norm: "none" # one of: none, gate_mean, rms_post, ln_post, rms_pre + # use "none" for sanity check, "gate_mean" for training + +# Data +dataset: "allenai/dolma" +dataset_name: "v1_7" # Dolma version / subset +seq_len: 1024 # OLMo token count per packed sequence +batch_size: 4 +micro_batch_size: 4 # per-GPU micro batch; if < batch_size, use gradient accumulation +qwen_input_prefix: "" # use raw text directly (see §3.1.1) + +# Eval +eval_skip: 1000000 # skip this many examples to reach held-out region +eval_size: 1000 # number of eval sequences (cached in memory) + +# Training +total_steps: 1000 +lr: 3e-4 +weight_decay: 0.01 +optimizer: "adamw" # only adamw supported + +# Schedules +tau_init: 5.0 +tau_final: 0.2 +tau_schedule: "cosine" +lambda_max: 0.0 # no sparsity for sanity check +lambda_warmup_frac: 0.2 + +# Logging +wandb_project: "dagformer" +wandb_run_name: "sanity-check" +log_every: 10 +eval_every: 100 + +# Checkpointing +save_every: 500 +save_dir: "checkpoints/" + +# Hardware +num_gpus: 1 +``` + +Parse into a `@dataclass` with validation. Crash on unknown keys. 
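+
+For illustration, a minimal sketch (not the repo's actual config module) of the
+dataclass-plus-strict-parsing rule; the field list is abbreviated to a handful
+of keys from the example above:
+
+```python
+from dataclasses import dataclass, fields
+
+import yaml
+
+
+@dataclass
+class Config:
+    olmo_model_id: str = "allenai/OLMo-2-0425-1B"
+    qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B"
+    seq_len: int = 1024
+    batch_size: int = 4
+    input_norm: str = "none"
+    tau_init: float = 5.0
+    tau_final: float = 0.2
+
+
+def load_config(path: str) -> Config:
+    with open(path) as fh:
+        raw = yaml.safe_load(fh)
+    unknown = set(raw) - {f.name for f in fields(Config)}
+    if unknown:
+        raise ValueError(f"Unknown config keys: {sorted(unknown)}")  # crash loudly
+    cfg = Config(**raw)
+    assert cfg.input_norm in {"none", "gate_mean", "rms_post", "ln_post", "rms_pre"}
+    return cfg
+```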
+ +--- + +## 9. Key Invariants (ALWAYS enforce) + +1. **Baseline reproduction**: A=1 (all 30,720 entries) with `input_norm: + "none"` → NLL matches vanilla OLMo within 0.01. This validates the + entire gate injection and per-head splitting logic. Test BEFORE and + AFTER any architectural change. + +2. **DAG constraint**: A is block-upper-triangular based on LAYER indices. + A[i,j] = 0 whenever `j//16 <= i//16`. Enforced by multiplicative mask, + never by loss or gradient clipping. Do NOT use `torch.triu()`. + +3. **Gradient flow**: After every forward-backward, assert that all + predictor parameters have non-None, non-zero gradients. + +4. **Memory budget**: Must fit on 4×A40 for seq_len=1024. If OOM, reduce + batch size. Do NOT change the architecture to fix memory. + +5. **Frozen models stay frozen**: OLMo and Qwen must have + `requires_grad=False` on ALL parameters. Verify this at startup. + In Phase 1, the only trainable parameters are the MLP decoder. + +6. **Deterministic eval**: Eval uses NO Gumbel noise, ever. Two eval modes: + (a) Soft: A = σ(Z/τ), continuous [0,1]. (b) Hard: A = (Z>0).float(), + binary {0,1}, with hard cascading gate (g_j = 1 if inc_j>0, else 0). + Always report both `eval/nll_soft` and `eval/nll_hard`. + +--- + +## 10. Oracle Search Reference Numbers + +These are the results from the completed oracle search. Use them to +validate that the learned predictor is heading in the right direction. + +| Metric | Value | +|--------|-------| +| Windows evaluated | 50 | +| Window size | 1024 tokens | +| Improvement rate | 100% (all 50 improved) | +| Baseline NLL (median) | 2.58 | +| Oracle NLL (median) | 0.12 | +| NLL delta (median) | +2.38 | +| Oracle sequential gates ON | ~91% | +| Oracle hyperconnection gates ON | ~70% | +| Oracle search steps per window | 500 | +| Oracle search method | Surrogate gradient (STE) | + +The learned predictor does NOT need to match oracle performance. The +**decision gate** for Phase 1 success is: predictor NLL ≤ dense baseline +NLL (i.e., the predictor must not make things WORSE). + +--- + +## 11. Dependencies (exact) + +```toml +[project] +name = "dagformer" +version = "0.1.0" +requires-python = ">=3.10" +dependencies = [ + "torch>=2.2", + "transformers>=4.40", + "datasets", + "wandb", + "pyyaml", + "einops", +] +``` + +**NOT used (by design):** +- No HuggingFace Accelerate +- No PyTorch Lightning +- No DeepSpeed +- No custom CUDA kernels + +Multi-GPU via `torch.nn.parallel.DistributedDataParallel` only. + +--- + +## 12. What NOT to Do + +- **Do NOT implement Phase 2** (joint OLMo training). Design the code to + support it (param groups, differential LR) but do not implement unfreezing. +- **Do NOT implement a diffusion-based predictor.** The MLP decoder is + the current design. Diffusion is future work. +- **Do NOT write custom CUDA kernels.** Use dense matmuls with masking. +- **Do NOT support other base models.** Hardcode OLMo2-1B for now. +- **Do NOT use Accelerate or Lightning.** Pure PyTorch. +- **Do NOT run hyperparameter search.** Manual ablations only. +- **Do NOT fork OLMo2's source code.** Load from HF and modify via + hooks, monkey-patching, or subclassing. +- **Do NOT use `nn.DataParallel`.** Use `DistributedDataParallel` only. + +--- + +## 13. 
Code Style + +- Type hints on all function signatures +- Docstrings on all public functions and classes +- Config as `@dataclass`, not raw dicts +- `assert` for shape checks in every forward pass (e.g., + `assert A.shape == (batch, 256, 256)`) +- No silent failures — crash loudly with informative messages +- Prefer explicit loops over clever one-liners when clarity matters +- One class per file is fine; don't over-split +- Use `einops.rearrange` for complex tensor reshaping (clearer than + chains of `.view().permute().contiguous()`) + +--- + +## 14. Quick Reference — Model IDs and Shapes + +| Thing | Value | +|-------|-------| +| OLMo model ID | `allenai/OLMo-2-0425-1B` | +| Qwen model ID | `Qwen/Qwen3-Embedding-0.6B` | +| OLMo layers | 16 | +| OLMo heads per layer | 16 | +| Total nodes | 256 | +| A matrix shape | `[batch, 256, 256]` | +| A constraint | Block-upper-triangular: `A[i,j]=0` when `j//16 <= i//16` | +| A free entries | 30,720 (cross-layer only) | +| Predictor rank r | 32 (default) | +| Predictor hidden dim | 1024 (default) | +| Sequence length | 1024 | +| Gumbel-Sigmoid τ range | 5.0 → 0.2 | +| Cascading gate k | 5.0 | +| Input normalization | `none` (sanity check), `gate_mean` (training default), ablate all 5 | +| Sparsity λ range | 0 → 0.01 | -- cgit v1.2.3