author    YurenHao0426 <blackhao0426@gmail.com>    2026-02-09 11:00:39 -0600
committer YurenHao0426 <blackhao0426@gmail.com>    2026-02-09 11:00:39 -0600
commit    13ddc8dc583d8b1355909970cb8c27f85b7d3c8b (patch)
tree      073534138604c1c49021ca7e334322262129f6ac /readme.md
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'readme.md')
-rw-r--r--  readme.md  371
1 files changed, 371 insertions, 0 deletions
diff --git a/readme.md b/readme.md
new file mode 100644
index 0000000..d1649a2
--- /dev/null
+++ b/readme.md
@@ -0,0 +1,371 @@
+# DAGFormer
+
+> **One-liner**: Train a lightweight neural network to predict per-token dynamic
+> DAG topologies for OLMo2-1B, replacing expensive oracle search with a
+> single forward pass.
+
+## Project Context
+
+The oracle search (separate codebase) proved that **context-dependent
+topologies exist** which dramatically reduce NLL (2.58 → 0.12 median,
+100% improvement rate across 50 windows). However, oracle search takes
+~500 optimization steps per window and cannot scale. This codebase trains
+a **structure predictor** end-to-end: it reads the current context,
+predicts a soft adjacency matrix, and the language model executes under
+that topology — all differentiable, no oracle data needed.
+
+---
+
+## Architecture Overview
+
+```
+Context tokens
+ │
+ ├──────────────────────────┐
+ ▼ ▼
+ Qwen-3-Embedding-0.6B OLMo2-1B
+ (frozen encoder) (16L × 16H = 256 nodes)
+ │ │
+ ▼ │
+ Structure Predictor │
+ (trainable MLP) │
+ │ │
+ ▼ │
+ Gumbel-Sigmoid │
+ + Cascading Gate │
+ │ │
+ ▼ │
+ Soft A ∈ [0,1]^{256×256} ───▶ │ (applied as gate mask)
+ upper-triangular │
+ per-token ▼
+ NLL Loss
+ │
+ ◀── ∇ backprop to predictor
+```
+
+### Gate Structure (inherited from oracle search)
+
+| Gate type | Count | Semantics |
+|------------------|---------|----------------------------------------|
+| Sequential | 256 | 16 layers × 16 heads (residual → head) |
+| Hyperconnection  | 30,720  | head_i → head_j where layer(j) > layer(i) |
+| **Total** | **~31K**| Binary decisions per context window |
+
+The adjacency matrix A is 256×256, upper-triangular (DAG constraint).
+Row i, col j means "head i sends output to head j". Sequential gates
+are the diagonal-block entries; hyperconnection gates are off-diagonal.
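+
+For concreteness, a minimal sketch of one way to enumerate these gates. The
+node indexing (index = 16·layer + head) and the placement of sequential gates
+on the diagonal of A are assumptions; the counts match the table above.
+
+```
+import torch
+
+N_LAYERS, N_HEADS = 16, 16
+N = N_LAYERS * N_HEADS                       # 256 nodes; index = 16*layer + head (assumed)
+
+layer_of = torch.arange(N) // N_HEADS        # layer index of each node
+# hyperconnection positions: head_i -> head_j with layer(j) > layer(i)
+hyper_mask = layer_of.unsqueeze(1) < layer_of.unsqueeze(0)
+# sequential gates (residual -> head), one per head; assumed to live on the diagonal
+seq_mask = torch.eye(N, dtype=torch.bool)
+
+assert int(hyper_mask.sum()) == 30_720       # matches the table above
+assert int(seq_mask.sum()) == 256
+```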
+
+### Oracle Topology Statistics (reference targets)
+
+- Sequential gates: ~91% ON
+- Hyperconnection gates: ~70% ON
+- Jaccard similarity between windows: < 0.8 (topologies are context-dependent)
+
+---
+
+## Structure Predictor Architecture
+
+```
+Input: context embedding e ∈ R^d (from Qwen, pooled or [CLS])
+ │
+ ▼
+ MLP: Linear(d, h) → GELU → Linear(h, h) → GELU
+ │
+ ▼
+ Low-rank head: Linear(h, 256·r), Linear(h, 256·r)
+ │ │
+ ▼ ▼
+ U ∈ R^{256×r} V ∈ R^{256×r}
+ │
+ ▼
+ Z = U V^T ∈ R^{256×256} (logits)
+ │
+ ▼
+ UpperTriMask ⊙ σ((Z + G) / τ) (Gumbel-Sigmoid)
+ │
+ ▼
+ Cascading Gate:
+ g_j = σ(k · Σ_i A[i][j])
+ A[j, :] *= g_j
+ │
+ ▼
+ Soft A ∈ [0,1]^{256×256}
+```
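+
+A minimal PyTorch sketch of the block above. The class and argument names are
+assumptions; only the shapes and the low-rank product Z = U Vᵀ come from the
+diagram.
+
+```
+import torch
+import torch.nn as nn
+
+class LowRankHead(nn.Module):
+    """Context embedding e ∈ R^d -> logits Z = U Vᵀ ∈ R^{256×256} via rank-r factors."""
+
+    def __init__(self, d: int, h: int, r: int, n_nodes: int = 256):
+        super().__init__()
+        self.n_nodes, self.r = n_nodes, r
+        self.mlp = nn.Sequential(nn.Linear(d, h), nn.GELU(), nn.Linear(h, h), nn.GELU())
+        self.to_u = nn.Linear(h, n_nodes * r)
+        self.to_v = nn.Linear(h, n_nodes * r)
+
+    def forward(self, e: torch.Tensor) -> torch.Tensor:   # e: (batch, d)
+        hidden = self.mlp(e)
+        U = self.to_u(hidden).view(-1, self.n_nodes, self.r)
+        V = self.to_v(hidden).view(-1, self.n_nodes, self.r)
+        return U @ V.transpose(1, 2)                       # (batch, 256, 256) logits Z
+
+# e.g. LowRankHead(d=1024, h=512, r=32)(torch.randn(2, 1024)).shape == (2, 256, 256)
+```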
+
+**Key design choices:**
+
+1. **Low-rank factorization** — Instead of predicting 65K entries, predict
+ U,V ∈ R^{256×r} where r ∈ {8, 16, 32, 64}. Inductive bias toward
+ structured topology. Also enables future diffusion head replacement.
+
+2. **Gumbel-Sigmoid** — Continuous relaxation of binary gates (see the
+   sketch after this list). Temperature τ anneals from τ_init to τ_final
+   via a cosine schedule.
+   - τ large → soft (exploration)
+   - τ small → sharp (near-binary)
+   - Training: soft; Inference: hard threshold at 0.5
+
+3. **Cascading gate** — If a node has no incoming edges, it should have
+ no outgoing edges (no information to propagate). Enforced as a
+ differentiable soft constraint:
+ ```
+ incoming_j = Σ_i A[i][j]
+ g_j = σ(k · incoming_j) # k=5, fixed or learnable
+ A[j, :] = A[j, :] * g_j # kill outgoing if no incoming
+ ```
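+
+A minimal sketch combining design choices 2 and 3 on the logits Z from the
+predictor. Whether the diagonal survives the DAG mask is an assumption (see
+the gate-layout note earlier); the hard eval path is covered under Key
+Invariants below.
+
+```
+import torch
+
+def gumbel_sigmoid(logits: torch.Tensor, tau: float) -> torch.Tensor:
+    """σ((Z + G) / τ) with G = log u − log(1 − u), i.e. logistic noise (difference of two Gumbels)."""
+    u = torch.rand_like(logits).clamp(1e-6, 1 - 1e-6)
+    noise = torch.log(u) - torch.log1p(-u)
+    return torch.sigmoid((logits + noise) / tau)
+
+def soft_adjacency(Z: torch.Tensor, tau: float, k: float = 5.0) -> torch.Tensor:
+    """Z: (batch, 256, 256) logits -> soft upper-triangular A with the cascading gate applied."""
+    n = Z.size(-1)
+    dag_mask = torch.triu(torch.ones(n, n, device=Z.device))   # diagonal kept (assumption)
+    A = gumbel_sigmoid(Z, tau) * dag_mask
+    incoming = A.sum(dim=-2)                    # incoming_j = Σ_i A[i, j], shape (batch, 256)
+    g = torch.sigmoid(k * incoming)             # soft attenuation when a node has little incoming mass
+    return A * g.unsqueeze(-1)                  # A[j, :] *= g_j
+```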
+
+---
+
+## OLMo2-1B Modification
+
+The base OLMo2-1B model needs a **modified forward pass** that:
+
+1. Accepts a soft adjacency matrix A ∈ [0,1]^{256×256} per token
+2. Gates the residual stream connections accordingly:
+ - Each attention head output is multiplied by its gate value before
+ being added to the residual stream
+ - Hyperconnections: head_i's output is routed to head_j's input,
+ weighted by A[i][j]
+3. When A is all-ones (or not provided), behavior is identical to vanilla
+ OLMo2-1B (this is the **sanity check** — must reproduce baseline NLL)
+
+**Implementation strategy**: Monkey-patch or subclass OLMo2's attention
+and residual logic. Do NOT fork the entire model — maintain compatibility
+with HuggingFace `olmo2` model loading.
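+
+A sketch of the core gating math only, not the monkey-patch itself; the tensor
+layout is an assumption.
+
+```
+import torch
+
+def gate_head_outputs(head_out: torch.Tensor, seq_gates: torch.Tensor) -> torch.Tensor:
+    """Scale each head's contribution to the residual stream by its sequential gate.
+
+    head_out:  (batch, seq, n_heads, head_dim)  per-head attention outputs before the merge
+    seq_gates: (batch, n_heads)                 gate values in [0, 1] read off A
+    """
+    return head_out * seq_gates[:, None, :, None]
+```
+
+Hyperconnection routing is analogous: head_i's output added into head_j's
+input is weighted by A[i, j]. Where exactly this hooks into the HuggingFace
+modules is what Step 1 of the implementation order works out.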
+
+---
+
+## Training Phases
+
+### Phase 1: Learn Topology (Frozen OLMo)
+
+| Component | Status |
+|---------------------|----------|
+| OLMo2-1B | ❄ frozen |
+| Qwen-3-Embedding | ❄ frozen |
+| Structure Predictor | 🔥 trainable |
+
+| Hyperparameter | Value |
+|----------------------|----------------------|
+| Data | Dolma (streamed) |
+| Tokens | 5–10B |
+| Sequence length | 1024 |
+| Batch size | TBD (start 32) |
+| LR | 3e-4 |
+| Optimizer | AdamW (β1=0.9, β2=0.999, wd=0.01) |
+| LR schedule | Cosine decay to 0 |
+| τ annealing | 5 → 0.2 (cosine) |
+| Sparsity λ | 0 → 0.01 (linear ramp over first 20% steps) |
+| Hardware | A40 × 4 |
+| **Gate criterion** | NLL ≤ dense baseline on held-out eval |
+
+**Sparsity loss**: λ · mean(A) — encourages the predictor to turn off
+unnecessary connections rather than leaving everything ON.
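+
+A sketch of the resulting Phase 1 objective; the metric keys follow the
+logging table further down, and the handling of labels is an assumption.
+
+```
+import torch
+import torch.nn.functional as F
+
+def phase1_loss(logits: torch.Tensor, labels: torch.Tensor, A: torch.Tensor, lam: float):
+    """NLL under the predicted topology plus the sparsity penalty λ · mean(A).
+
+    labels are assumed already shifted (and ignore-masked) by the caller.
+    """
+    nll = F.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
+    sparsity = lam * A.mean()
+    return nll + sparsity, {"train/nll": nll.item(), "train/sparsity_loss": sparsity.item()}
+```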
+
+### Phase 2: Joint Continued Pre-Training (future)
+
+| Component | Status |
+|---------------------|----------|
+| OLMo2-1B | 🔥 unfrozen |
+| Qwen-3-Embedding | ❄ frozen |
+| Structure Predictor | 🔥 trainable |
+
+| Hyperparameter | Value |
+|----------------------|----------------------|
+| Tokens | 20–50B |
+| LR (OLMo) | 3e-5 |
+| LR (Predictor) | 1e-4 |
+| τ continues | → 0.1 |
+| Hardware | A100 × 4 |
+
+Phase 2 is out of scope for initial implementation. Build the training
+loop to support it (differential LR groups) but don't implement the
+unfreezing logic yet.
+
+---
+
+## Directory Structure
+
+```
+dagformer/
+├── CLAUDE.md # This file — the spec
+├── README.md # Public-facing project description
+├── pyproject.toml # Dependencies and project metadata
+│
+├── configs/
+│ ├── sanity_check.yaml # Tiny run: 1K steps, verify NLL matches baseline
+│ ├── ablation_rank.yaml # Sweep r ∈ {8, 16, 32, 64}
+│ ├── ablation_tau.yaml # Sweep τ_init, τ_final
+│ ├── ablation_lambda.yaml # Sweep sparsity coefficient
+│ └── phase1_full.yaml # Full Phase 1 training config
+│
+├── src/
+│ ├── __init__.py
+│ │
+│ ├── model/
+│ │ ├── __init__.py
+│ │ ├── olmo_graph.py # Modified OLMo2 forward with adjacency injection
+│ │ ├── predictor.py # Structure predictor (Qwen + MLP + Gumbel)
+│ │ └── pipeline.py # Combines predictor + OLMo into one forward call
+│ │
+│ ├── data/
+│ │ ├── __init__.py
+│ │ └── dolma.py # Streaming dataloader for Dolma
+│ │
+│ ├── training/
+│ │ ├── __init__.py
+│ │ ├── trainer.py # Pure PyTorch training loop
+│ │ ├── schedulers.py # τ annealing, sparsity ramp, LR schedule
+│ │ └── checkpointing.py # Save/load with topology statistics
+│ │
+│ └── utils/
+│ ├── __init__.py
+│ ├── logging.py # Wandb + console logging
+│ └── topology.py # A matrix analysis: sparsity, Jaccard, per-layer stats
+│
+├── scripts/
+│ ├── train.py # Entry point: python scripts/train.py --config configs/...
+│ ├── eval.py # Evaluate: compare NLL with/without predictor
+│ ├── sanity_check.py # Verify: A=all-ones reproduces baseline NLL
+│ └── visualize_topology.py # Plot adjacency matrices, gate distributions
+│
+└── tests/
+ ├── test_olmo_graph.py # Forward pass matches baseline when A=1
+ ├── test_predictor.py # Output shapes, gradient flow
+ └── test_gumbel.py # Gumbel-Sigmoid properties at various τ
+```
+
+---
+
+## Implementation Order
+
+Build and **test each module in isolation** before combining.
+
+### Step 0: Environment Setup
+- [ ] `pyproject.toml` with deps: torch, transformers, datasets, wandb, pyyaml, einops
+- [ ] Verify OLMo2-1B loads: `AutoModelForCausalLM.from_pretrained("allenai/OLMo-2-0425-1B")`
+- [ ] Verify Qwen loads: `AutoModel.from_pretrained("Qwen/Qwen3-Embedding-0.6B")`
+- [ ] Verify Dolma streaming: `load_dataset("allenai/dolma", ..., streaming=True)`
+
+### Step 1: `model/olmo_graph.py` — Modified OLMo Forward
+- [ ] Load OLMo2-1B, identify where head outputs merge into residual
+- [ ] Implement gate injection: multiply head outputs by A values
+- [ ] **Sanity check**: A = all-ones → NLL matches vanilla forward within 0.01
+- [ ] **Sanity check**: A = all-zeros → NLL is very high (model is broken)
+- [ ] Verify gradients flow through A (A.requires_grad=True, check A.grad is not None)
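+
+A sketch of the first sanity check in this list; the `adjacency=` keyword and
+the per-batch A shape are assumed interfaces, not the actual `olmo_graph.py` API.
+
+```
+import torch
+
+@torch.no_grad()
+def sanity_check_all_ones(vanilla_model, graph_model, input_ids: torch.Tensor, tol: float = 0.01):
+    """A = all-ones must reproduce the vanilla NLL within `tol`."""
+    labels = input_ids.clone()
+    base = vanilla_model(input_ids=input_ids, labels=labels).loss
+    A = torch.ones(input_ids.size(0), 256, 256, device=input_ids.device)   # assumed shape
+    gated = graph_model(input_ids=input_ids, labels=labels, adjacency=A).loss
+    assert torch.allclose(base, gated, atol=tol), \
+        f"baseline {base.item():.4f} vs gated {gated.item():.4f}"
+```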
+
+### Step 2: `model/predictor.py` — Structure Predictor
+- [ ] Qwen encoder wrapper (frozen, returns pooled embedding)
+- [ ] MLP decoder: e → h → (U, V) → Z = UV^T
+- [ ] Gumbel-Sigmoid: σ((Z + G) / τ) with configurable τ
+- [ ] Upper-triangular mask
+- [ ] Cascading gate
+- [ ] **Test**: output shape is (batch, 256, 256), values in [0,1], upper-tri
+- [ ] **Test**: gradient flows from output back to MLP parameters
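+
+A sketch of the two test items above; the `predictor(e) -> A` call signature
+is an assumed interface.
+
+```
+import torch
+
+def test_predictor_output_and_gradients(predictor, batch: int = 2, d: int = 1024):
+    """Shapes, value range, upper-triangularity, and gradient flow for predictor(e) -> A."""
+    e = torch.randn(batch, d)
+    A = predictor(e)
+    assert A.shape == (batch, 256, 256)
+    assert A.min() >= 0 and A.max() <= 1
+    assert torch.allclose(A, A * torch.triu(torch.ones(256, 256)))   # nothing below the diagonal
+    A.sum().backward()
+    assert all(p.grad is not None for p in predictor.parameters() if p.requires_grad)
+```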
+
+### Step 3: `model/pipeline.py` — End-to-End Forward
+- [ ] Combine predictor + OLMo: tokens → A → modified_forward → NLL
+- [ ] Verify full gradient chain: NLL.backward() updates predictor params
+- [ ] Profile memory: should fit on single A40 (48GB) for seq_len=1024, batch=1
+
+### Step 4: `training/` — Training Infrastructure
+- [ ] Config loading (yaml → dataclass)
+- [ ] τ annealing schedule (cosine: τ_init → τ_final)
+- [ ] Sparsity λ ramp (linear: 0 → λ_max over warmup fraction)
+- [ ] LR schedule (cosine decay)
+- [ ] Training loop: forward → loss → backward → step → log
+- [ ] Wandb logging: NLL, sparsity(A), τ, gate statistics per layer
+- [ ] Checkpointing: save predictor weights + optimizer + step + τ
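+
+A sketch of the two non-LR schedules named above, as pure functions of the
+global step (defaults taken from the Phase 1 table).
+
+```
+import math
+
+def tau_schedule(step: int, total_steps: int, tau_init: float = 5.0, tau_final: float = 0.2) -> float:
+    """Cosine anneal τ from tau_init to tau_final over the run."""
+    t = min(step / max(total_steps, 1), 1.0)
+    return tau_final + 0.5 * (tau_init - tau_final) * (1 + math.cos(math.pi * t))
+
+def lambda_schedule(step: int, total_steps: int, lam_max: float = 0.01, warmup_frac: float = 0.2) -> float:
+    """Linear ramp 0 -> lam_max over the first warmup_frac of training, then flat."""
+    warmup_steps = max(int(warmup_frac * total_steps), 1)
+    return lam_max * min(step / warmup_steps, 1.0)
+```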
+
+### Step 5: Sanity Check Run
+- [ ] `configs/sanity_check.yaml`: 1K steps, small batch, high τ
+- [ ] Verify: loss decreases, A is not collapsing to all-ones or all-zeros
+- [ ] Verify: gradient norms are reasonable (not exploding/vanishing)
+- [ ] **Decision gate**: if loss doesn't decrease in 1K steps, debug before proceeding
+
+### Step 6: Ablations
+- [ ] Rank r ∈ {8, 16, 32, 64}: which gives best NLL-sparsity tradeoff?
+- [ ] τ schedule: (5→0.2) vs (2→0.1) vs (10→0.5)
+- [ ] Sparsity λ: 0 vs 0.001 vs 0.01 vs 0.1
+- [ ] Cascading gate: with vs without
+
+---
+
+## Key Invariants (must always hold)
+
+1. **Baseline reproduction**: When the predictor outputs A = all-ones,
+ the NLL must match vanilla OLMo2-1B within 1%.
+
+2. **DAG constraint**: A is always upper-triangular (no cycles).
+ Enforced by mask, not by loss.
+
+3. **Gradient flow**: `loss.backward()` must produce non-None gradients
+ for all predictor parameters. Check after every architectural change.
+
+4. **Memory budget**: Phase 1 must fit on 4× A40 (48GB each) with
+ seq_len=1024. If it doesn't, reduce batch size before changing
+ architecture.
+
+5. **Deterministic eval**: At eval time, use hard threshold (A > 0.5)
+ with no Gumbel noise. Eval NLL must be reported with hard gates.
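+
+Invariant 5 as a small helper (a sketch). At eval time A_soft itself should
+already be computed without Gumbel noise, i.e. σ(Z/τ); this only binarizes and
+re-masks it.
+
+```
+import torch
+
+@torch.no_grad()
+def harden(A_soft: torch.Tensor, threshold: float = 0.5) -> torch.Tensor:
+    """Deterministic eval gates: hard 0/1 threshold with the DAG mask re-applied."""
+    dag_mask = torch.triu(torch.ones_like(A_soft[0]))
+    return (A_soft > threshold).float() * dag_mask
+```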
+
+---
+
+## Logging & Monitoring
+
+Log to **Wandb** at every step:
+
+| Metric | What it tells you |
+|------------------------|------------------------------------------|
+| `train/nll` | Primary objective |
+| `train/sparsity_loss` | λ · mean(A) |
+| `train/total_loss` | NLL + sparsity |
+| `eval/nll_soft` | NLL with soft gates (Gumbel, current τ) |
+| `eval/nll_hard` | NLL with hard gates (threshold 0.5) |
+| `eval/nll_baseline` | NLL with A=1 (should be constant) |
+| `topology/sparsity` | 1 - mean(A) |
+| `topology/seq_gate_on` | Fraction of sequential gates > 0.5 |
+| `topology/hyp_gate_on` | Fraction of hyperconnection gates > 0.5 |
+| `topology/jaccard_var` | Variance of Jaccard similarity across batch |
+| `schedule/tau` | Current temperature |
+| `schedule/lambda` | Current sparsity coefficient |
+| `gradients/predictor_norm` | Total gradient norm |
+
+**Collapse detection**: If `topology/sparsity` < 0.01 or > 0.99 for
+100 consecutive steps, something is wrong. Log a warning.
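+
+A sketch of this check as a small stateful helper; how the warning is surfaced
+(a plain print here) is an assumption.
+
+```
+class CollapseDetector:
+    """Warn when topology/sparsity sits outside (low, high) for `patience` consecutive steps."""
+
+    def __init__(self, low: float = 0.01, high: float = 0.99, patience: int = 100):
+        self.low, self.high, self.patience = low, high, patience
+        self.streak = 0
+
+    def update(self, sparsity: float) -> bool:
+        self.streak = self.streak + 1 if (sparsity < self.low or sparsity > self.high) else 0
+        if self.streak >= self.patience:
+            print(f"WARNING: topology/sparsity = {sparsity:.3f} out of range for {self.streak} steps")
+            return True
+        return False
+```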
+
+---
+
+## Dependencies
+
+```
+torch >= 2.2
+transformers >= 4.40
+datasets
+wandb
+pyyaml
+einops
+```
+
+No Accelerate, no Lightning. Pure PyTorch with `torch.nn.parallel.DistributedDataParallel`
+for multi-GPU. Keep it simple and transparent.
+
+---
+
+## What NOT to Build (out of scope)
+
+- Phase 2 (joint CPT) — design for it, don't implement yet
+- Diffusion-based topology predictor — future work
+- Custom CUDA kernels for sparse attention — use dense ops with masking
+- Support for models other than OLMo2-1B — hardcode for now
+- Fancy hyperparameter search — manual ablations are fine
+
+---
+
+## Code Style
+
+- Type hints everywhere
+- Docstrings on all public functions
+- Config dataclasses, not raw dicts
+- `assert` liberally for shape checks in forward passes
+- No silent failures: if something is wrong, crash loudly
+- Prefer explicit over clever: `for layer in layers` over `map(lambda ...)`