From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Mon, 9 Feb 2026 11:00:39 -0600 Subject: Initial implementation: DAGFormer Phase 1 - olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A - Proportional attribution for post-norm decomposition - All 6 GPU sanity checks pass (baseline diff = 0.000001) - predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate - pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL) - trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing - dolma.py: Streaming Dolma v1.7 with sequence packing - 43/43 unit tests pass Co-Authored-By: Claude Opus 4.6 --- readme.md | 371 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 371 insertions(+) create mode 100644 readme.md (limited to 'readme.md') diff --git a/readme.md b/readme.md new file mode 100644 index 0000000..d1649a2 --- /dev/null +++ b/readme.md @@ -0,0 +1,371 @@ +# DAGFormer + +> **One-liner**: Train a lightweight neural network to predict per-token dynamic +> DAG topologies for OLMo2-1B, replacing expensive oracle search with a +> single forward pass. + +## Project Context + +The oracle search (separate codebase) proved that **context-dependent +topologies exist** which dramatically reduce NLL (2.58 → 0.12 median, +100% improvement rate across 50 windows). However, oracle search takes +~500 optimization steps per window and cannot scale. This codebase trains +a **structure predictor** end-to-end: it reads the current context, +predicts a soft adjacency matrix, and the language model executes under +that topology — all differentiable, no oracle data needed. + +--- + +## Architecture Overview + +``` +Context tokens + │ + ├──────────────────────────┐ + ▼ ▼ + Qwen-3-Embedding-0.6B OLMo2-1B + (frozen encoder) (16L × 16H = 256 nodes) + │ │ + ▼ │ + Structure Predictor │ + (trainable MLP) │ + │ │ + ▼ │ + Gumbel-Sigmoid │ + + Cascading Gate │ + │ │ + ▼ │ + Soft A ∈ [0,1]^{256×256} ───▶ │ (applied as gate mask) + upper-triangular │ + per-token ▼ + NLL Loss + │ + ◀── ∇ backprop to predictor +``` + +### Gate Structure (inherited from oracle search) + +| Gate type | Count | Semantics | +|------------------|---------|----------------------------------------| +| Sequential | 256 | 16 layers × 16 heads (residual → head) | +| Hyperconnection | 30,720 | head_i → head_j for all j > layer(i) | +| **Total** | **~31K**| Binary decisions per context window | + +The adjacency matrix A is 256×256, upper-triangular (DAG constraint). +Row i, col j means "head i sends output to head j". Sequential gates +are the diagonal-block entries; hyperconnection gates are off-diagonal. + +### Oracle Topology Statistics (reference targets) + +- Sequential gates: ~91% ON +- Hyperconnection gates: ~70% ON +- Jaccard similarity between windows: < 0.8 (topologies are context-dependent) + +--- + +## Structure Predictor Architecture + +``` +Input: context embedding e ∈ R^d (from Qwen, pooled or [CLS]) + │ + ▼ + MLP: Linear(d, h) → GELU → Linear(h, h) → GELU + │ + ▼ + Low-rank head: Linear(h, 256·r), Linear(h, 256·r) + │ │ + ▼ ▼ + U ∈ R^{256×r} V ∈ R^{256×r} + │ + ▼ + Z = U V^T ∈ R^{256×256} (logits) + │ + ▼ + UpperTriMask ⊙ σ((Z + G) / τ) (Gumbel-Sigmoid) + │ + ▼ + Cascading Gate: + g_j = σ(k · Σ_i A[i][j]) + A[j, :] *= g_j + │ + ▼ + Soft A ∈ [0,1]^{256×256} +``` + +**Key design choices:** + +1. 
**Low-rank factorization** — Instead of predicting 65K entries, predict + U,V ∈ R^{256×r} where r ∈ {8, 16, 32, 64}. Inductive bias toward + structured topology. Also enables future diffusion head replacement. + +2. **Gumbel-Sigmoid** — Continuous relaxation of binary gates. + Temperature τ anneals from τ_init to τ_final via cosine schedule. + - τ large → soft (exploration) + - τ small → sharp (near-binary) + - Training: soft; Inference: hard threshold at 0.5 + +3. **Cascading gate** — If a node has no incoming edges, it should have + no outgoing edges (no information to propagate). Enforced as a + differentiable soft constraint: + ``` + incoming_j = Σ_i A[i][j] + g_j = σ(k · incoming_j) # k=5, fixed or learnable + A[j, :] = A[j, :] * g_j # kill outgoing if no incoming + ``` + +--- + +## OLMo2-1B Modification + +The base OLMo2-1B model needs a **modified forward pass** that: + +1. Accepts a soft adjacency matrix A ∈ [0,1]^{256×256} per token +2. Gates the residual stream connections accordingly: + - Each attention head output is multiplied by its gate value before + being added to the residual stream + - Hyperconnections: head_i's output is routed to head_j's input, + weighted by A[i][j] +3. When A is all-ones (or not provided), behavior is identical to vanilla + OLMo2-1B (this is the **sanity check** — must reproduce baseline NLL) + +**Implementation strategy**: Monkey-patch or subclass OLMo2's attention +and residual logic. Do NOT fork the entire model — maintain compatibility +with HuggingFace `olmo2` model loading. + +--- + +## Training Phases + +### Phase 1: Learn Topology (Frozen OLMo) + +| Component | Status | +|---------------------|----------| +| OLMo2-1B | ❄ frozen | +| Qwen-3-Embedding | ❄ frozen | +| Structure Predictor | 🔥 trainable | + +| Hyperparameter | Value | +|----------------------|----------------------| +| Data | Dolma (streamed) | +| Tokens | 5–10B | +| Sequence length | 1024 | +| Batch size | TBD (start 32) | +| LR | 3e-4 | +| Optimizer | AdamW (β1=0.9, β2=0.999, wd=0.01) | +| LR schedule | Cosine decay to 0 | +| τ annealing | 5 → 0.2 (cosine) | +| Sparsity λ | 0 → 0.01 (linear ramp over first 20% steps) | +| Hardware | A40 × 4 | +| **Gate criterion** | NLL ≤ dense baseline on held-out eval | + +**Sparsity loss**: λ · mean(A) — encourages the predictor to turn off +unnecessary connections rather than leaving everything ON. + +### Phase 2: Joint Continued Pre-Training (future) + +| Component | Status | +|---------------------|----------| +| OLMo2-1B | 🔥 unfrozen | +| Qwen-3-Embedding | ❄ frozen | +| Structure Predictor | 🔥 trainable | + +| Hyperparameter | Value | +|----------------------|----------------------| +| Tokens | 20–50B | +| LR (OLMo) | 3e-5 | +| LR (Predictor) | 1e-4 | +| τ continues | → 0.1 | +| Hardware | A100 × 4 | + +Phase 2 is out of scope for initial implementation. Build the training +loop to support it (differential LR groups) but don't implement the +unfreezing logic yet. 
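Both the τ annealing and the sparsity ramp in the Phase 1 table are simple closed-form functions of the step counter. A minimal sketch is below; the function names and defaults are illustrative, not the final `training/schedulers.py` API.

```python
import math

def tau_schedule(step: int, total_steps: int,
                 tau_init: float = 5.0, tau_final: float = 0.2) -> float:
    """Cosine anneal of the Gumbel-Sigmoid temperature from tau_init to tau_final."""
    progress = min(step / max(total_steps, 1), 1.0)
    cosine = 0.5 * (1.0 + math.cos(math.pi * progress))   # 1 at step 0, 0 at the end
    return tau_final + (tau_init - tau_final) * cosine

def sparsity_lambda(step: int, total_steps: int,
                    lambda_max: float = 0.01, ramp_frac: float = 0.2) -> float:
    """Linear ramp of the sparsity coefficient over the first ramp_frac of training."""
    ramp_steps = max(int(ramp_frac * total_steps), 1)
    return lambda_max * min(step / ramp_steps, 1.0)
```

Inside the loop the two combine as `loss = nll + sparsity_lambda(step, total_steps) * A.mean()`, with the current τ fed to the Gumbel-Sigmoid. Keeping OLMo and predictor parameters in separate AdamW param groups from day one makes the Phase 2 differential LRs a config change rather than a refactor.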
+ +--- + +## Directory Structure + +``` +dagformer/ +├── CLAUDE.md # This file — the spec +├── README.md # Public-facing project description +├── pyproject.toml # Dependencies and project metadata +│ +├── configs/ +│ ├── sanity_check.yaml # Tiny run: 1K steps, verify NLL matches baseline +│ ├── ablation_rank.yaml # Sweep r ∈ {8, 16, 32, 64} +│ ├── ablation_tau.yaml # Sweep τ_init, τ_final +│ ├── ablation_lambda.yaml # Sweep sparsity coefficient +│ └── phase1_full.yaml # Full Phase 1 training config +│ +├── src/ +│ ├── __init__.py +│ │ +│ ├── model/ +│ │ ├── __init__.py +│ │ ├── olmo_graph.py # Modified OLMo2 forward with adjacency injection +│ │ ├── predictor.py # Structure predictor (Qwen + MLP + Gumbel) +│ │ └── pipeline.py # Combines predictor + OLMo into one forward call +│ │ +│ ├── data/ +│ │ ├── __init__.py +│ │ └── dolma.py # Streaming dataloader for Dolma +│ │ +│ ├── training/ +│ │ ├── __init__.py +│ │ ├── trainer.py # Pure PyTorch training loop +│ │ ├── schedulers.py # τ annealing, sparsity ramp, LR schedule +│ │ └── checkpointing.py # Save/load with topology statistics +│ │ +│ └── utils/ +│ ├── __init__.py +│ ├── logging.py # Wandb + console logging +│ └── topology.py # A matrix analysis: sparsity, Jaccard, per-layer stats +│ +├── scripts/ +│ ├── train.py # Entry point: python scripts/train.py --config configs/... +│ ├── eval.py # Evaluate: compare NLL with/without predictor +│ ├── sanity_check.py # Verify: A=all-ones reproduces baseline NLL +│ └── visualize_topology.py # Plot adjacency matrices, gate distributions +│ +└── tests/ + ├── test_olmo_graph.py # Forward pass matches baseline when A=1 + ├── test_predictor.py # Output shapes, gradient flow + └── test_gumbel.py # Gumbel-Sigmoid properties at various τ +``` + +--- + +## Implementation Order + +Build and **test each module in isolation** before combining. 
+ +### Step 0: Environment Setup +- [ ] `pyproject.toml` with deps: torch, transformers, datasets, wandb, pyyaml, einops +- [ ] Verify OLMo2-1B loads: `AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B")` +- [ ] Verify Qwen loads: `AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")` +- [ ] Verify Dolma streaming: `load_dataset("allenai/dolma", ..., streaming=True)` + +### Step 1: `model/olmo_graph.py` — Modified OLMo Forward +- [ ] Load OLMo2-1B, identify where head outputs merge into residual +- [ ] Implement gate injection: multiply head outputs by A values +- [ ] **Sanity check**: A = all-ones → NLL matches vanilla forward within 0.01 +- [ ] **Sanity check**: A = all-zeros → NLL is very high (model is broken) +- [ ] Verify gradients flow through A (A.requires_grad=True, check A.grad is not None) + +### Step 2: `model/predictor.py` — Structure Predictor +- [ ] Qwen encoder wrapper (frozen, returns pooled embedding) +- [ ] MLP decoder: e → h → (U, V) → Z = UV^T +- [ ] Gumbel-Sigmoid: σ((Z + G) / τ) with configurable τ +- [ ] Upper-triangular mask +- [ ] Cascading gate +- [ ] **Test**: output shape is (batch, 256, 256), values in [0,1], upper-tri +- [ ] **Test**: gradient flows from output back to MLP parameters + +### Step 3: `model/pipeline.py` — End-to-End Forward +- [ ] Combine predictor + OLMo: tokens → A → modified_forward → NLL +- [ ] Verify full gradient chain: NLL.backward() updates predictor params +- [ ] Profile memory: should fit on single A40 (48GB) for seq_len=1024, batch=1 + +### Step 4: `training/` — Training Infrastructure +- [ ] Config loading (yaml → dataclass) +- [ ] τ annealing schedule (cosine: τ_init → τ_final) +- [ ] Sparsity λ ramp (linear: 0 → λ_max over warmup fraction) +- [ ] LR schedule (cosine decay) +- [ ] Training loop: forward → loss → backward → step → log +- [ ] Wandb logging: NLL, sparsity(A), τ, gate statistics per layer +- [ ] Checkpointing: save predictor weights + optimizer + step + τ + +### Step 5: Sanity Check Run +- [ ] `configs/sanity_check.yaml`: 1K steps, small batch, high τ +- [ ] Verify: loss decreases, A is not collapsing to all-ones or all-zeros +- [ ] Verify: gradient norms are reasonable (not exploding/vanishing) +- [ ] **Decision gate**: if loss doesn't decrease in 1K steps, debug before proceeding + +### Step 6: Ablations +- [ ] Rank r ∈ {8, 16, 32, 64}: which gives best NLL-sparsity tradeoff? +- [ ] τ schedule: (5→0.2) vs (2→0.1) vs (10→0.5) +- [ ] Sparsity λ: 0 vs 0.001 vs 0.01 vs 0.1 +- [ ] Cascading gate: with vs without + +--- + +## Key Invariants (must always hold) + +1. **Baseline reproduction**: When the predictor outputs A = all-ones, + the NLL must match vanilla OLMo2-1B within 1%. + +2. **DAG constraint**: A is always upper-triangular (no cycles). + Enforced by mask, not by loss. + +3. **Gradient flow**: `loss.backward()` must produce non-None gradients + for all predictor parameters. Check after every architectural change. + +4. **Memory budget**: Phase 1 must fit on 4× A40 (48GB each) with + seq_len=1024. If it doesn't, reduce batch size before changing + architecture. + +5. **Deterministic eval**: At eval time, use hard threshold (A > 0.5) + with no Gumbel noise. Eval NLL must be reported with hard gates. 
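Invariants 1 and 3 are cheap to assert mechanically. The sketch below shows one way `scripts/sanity_check.py` might do it; `forward_with_adjacency`, `.loss`, and `.predictor` are placeholder names for whatever the pipeline actually exposes.

```python
import torch

def check_invariants(pipeline, olmo, batch, tol: float = 0.01) -> None:
    """Assert baseline reproduction (invariant 1) and gradient flow (invariant 3)."""
    # Invariant 1: A = all-ones must reproduce vanilla OLMo2-1B NLL within ~1%.
    with torch.no_grad():
        nll_vanilla = olmo(**batch).loss
        A_ones = torch.ones(256, 256, device=nll_vanilla.device)
        nll_gated = pipeline.forward_with_adjacency(batch, A_ones).loss  # placeholder API
    rel_diff = ((nll_gated - nll_vanilla).abs() / nll_vanilla).item()
    assert rel_diff < tol, f"A=1 forward drifts from baseline by {rel_diff:.4%}"

    # Invariant 3: every trainable predictor parameter must receive a gradient.
    pipeline(batch).loss.backward()
    missing = [name for name, p in pipeline.predictor.named_parameters()
               if p.requires_grad and p.grad is None]
    assert not missing, f"no gradient reached: {missing}"
```

Invariant 5 is the same pattern on the eval side: run the predictor with hard thresholding and no Gumbel noise, and report that NLL as the eval number.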
+ +--- + +## Logging & Monitoring + +Log to **Wandb** at every step: + +| Metric | What it tells you | +|------------------------|------------------------------------------| +| `train/nll` | Primary objective | +| `train/sparsity_loss` | λ · mean(A) | +| `train/total_loss` | NLL + sparsity | +| `eval/nll_soft` | NLL with soft gates (Gumbel, current τ) | +| `eval/nll_hard` | NLL with hard gates (threshold 0.5) | +| `eval/nll_baseline` | NLL with A=1 (should be constant) | +| `topology/sparsity` | 1 - mean(A) | +| `topology/seq_gate_on` | Fraction of sequential gates > 0.5 | +| `topology/hyp_gate_on` | Fraction of hyperconnection gates > 0.5 | +| `topology/jaccard_var` | Variance of Jaccard similarity across batch | +| `schedule/tau` | Current temperature | +| `schedule/lambda` | Current sparsity coefficient | +| `gradients/predictor_norm` | Total gradient norm | + +**Collapse detection**: If `topology/sparsity` < 0.01 or > 0.99 for +100 consecutive steps, something is wrong. Log a warning. + +--- + +## Dependencies + +``` +torch >= 2.2 +transformers >= 4.40 +datasets +wandb +pyyaml +einops +``` + +No Accelerate, no Lightning. Pure PyTorch with `torch.nn.parallel.DistributedDataParallel` +for multi-GPU. Keep it simple and transparent. + +--- + +## What NOT to Build (out of scope) + +- Phase 2 (joint CPT) — design for it, don't implement yet +- Diffusion-based topology predictor — future work +- Custom CUDA kernels for sparse attention — use dense ops with masking +- Support for models other than OLMo2-1B — hardcode for now +- Fancy hyperparameter search — manual ablations are fine + +--- + +## Code Style + +- Type hints everywhere +- Docstrings on all public functions +- Config dataclasses, not raw dicts +- `assert` liberally for shape checks in forward passes +- No silent failures: if something is wrong, crash loudly +- Prefer explicit over clever: `for layer in layers` over `map(lambda ...)` -- cgit v1.2.3