From 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Mon, 9 Feb 2026 11:00:39 -0600
Subject: Initial implementation: DAGFormer Phase 1

- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via a 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass

Co-Authored-By: Claude Opus 4.6
---
 structure_predictor.html | 289 +++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 289 insertions(+)
 create mode 100644 structure_predictor.html

diff --git a/structure_predictor.html b/structure_predictor.html
new file mode 100644
index 0000000..7460649
--- /dev/null
+++ b/structure_predictor.html
@@ -0,0 +1,289 @@

Structure Predictor v4
Lookahead Structure Predictor for OLMo2-1B

Per-token context-conditioned DAG · Head-level 256×256 upper-triangular adjacency · Cascading activation gate

⚡ Topology predicted before each forward pass · Fully differentiable via continuous relaxation

TOPOLOGY PREDICTION (per token) · ⚡ runs before OLMo forward pass

Current Context (changes every generation step)
  → Qwen-3-Embedding-0.6B (frozen): context → d-dim vector e
  → Structure Predictor (trainable):
      · Low-rank parameterization: e → MLP → U, V ∈ ℝ^{256×r} → Z = UVᵀ
      · Gumbel-Sigmoid + upper-tri mask: A_raw = UpperTriMask ⊙ σ((Z + G) / τ)
      · Cascading activation gate:
            gⱼ = σ(k · Σᵢ A_raw[i][j])   // incoming sum
            A[j, :] = gⱼ · A_raw[j, :]   // gate outgoing
        no input → gⱼ ≈ 0 → no output · fully differentiable · k = learnable or fixed
  → Soft adjacency A ∈ [0,1]^{256×256}: upper-tri · per-token dynamic · cascading-gated

OLMO2-1B INFERENCE

tokens → OLMo2-1B: 16 layers × 16 heads = 256 nodes
[diagram: per-layer head grids L0, L1, L2, …, L8 ("heavily pruned"), …, L15, with active and pruned edges between heads]
Weighted head connectivity: input_j = Σᵢ A[i][j] · output_i (soft → differentiable)
apply A → LM Head → logits → NLL Loss → ∇ NLL

TRAINING — FULLY DIFFERENTIABLE PIPELINE

Phase 1: Train Predictor Only
  OLMo2-1B frozen 🔒 · Qwen-3-Emb frozen 🔒 · Predictor trainable 🔓 (only these params update)
  Goal: learn topologies that lower NLL vs. the dense baseline
  Also generates (context, topology) pairs for a future diffusion head

  ↓ then

Phase 2: Joint Training (CPT)
  OLMo2-1B unfrozen 🔓 (adapts to predicted topologies) · Qwen-3-Emb frozen 🔒 · Predictor trainable 🔓 (co-evolves with OLMo)
  Goal: OLMo + Predictor co-alignment
  Optional: swap the MLP decoder for a diffusion head (multi-modal topologies)

Legend: active · pruned · hyperconn · gradient · forward · cascading gate (new)
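The prediction pipeline in the diagram (low-rank Z = UVᵀ, Gumbel-Sigmoid under an upper-triangular mask, cascading gate, then soft head routing) can be sketched as a small PyTorch module. All names and default sizes here (`TopologyPredictor`, `emb_dim`, `rank`, `hidden`, `mix_head_outputs`) are illustrative assumptions, not the actual predictor.py API:

```python
import torch
import torch.nn as nn

N_HEADS = 256  # 16 layers x 16 heads in OLMo2-1B


class TopologyPredictor(nn.Module):
    """Context embedding e -> soft DAG adjacency A in [0,1]^{256x256} (sketch)."""

    def __init__(self, emb_dim=1024, rank=16, hidden=512, tau=1.0, k=4.0):
        super().__init__()
        self.rank, self.tau, self.k = rank, tau, k
        # e -> MLP -> U, V in R^{256 x r}
        self.mlp = nn.Sequential(
            nn.Linear(emb_dim, hidden), nn.GELU(),
            nn.Linear(hidden, 2 * N_HEADS * rank),
        )
        # strictly upper-triangular mask: edges only point to later heads (DAG)
        self.register_buffer(
            "mask", torch.triu(torch.ones(N_HEADS, N_HEADS), diagonal=1)
        )

    def forward(self, e):  # e: (B, emb_dim) context vector from the embedder
        uv = self.mlp(e).view(-1, 2, N_HEADS, self.rank)
        U, V = uv[:, 0], uv[:, 1]
        Z = U @ V.transpose(-1, -2)  # (B, 256, 256) low-rank logits Z = U V^T
        # Gumbel-Sigmoid relaxation: add Logistic(0,1) noise, temperature tau
        u = torch.rand_like(Z).clamp(1e-6, 1 - 1e-6)
        G = torch.log(u) - torch.log1p(-u)
        A_raw = self.mask * torch.sigmoid((Z + G) / self.tau)
        # cascading gate: g_j = sigmoid(k * sum_i A_raw[i, j]) over incoming
        # edges, then gate each head's outgoing row: A[j, :] = g_j * A_raw[j, :]
        g = torch.sigmoid(self.k * A_raw.sum(dim=-2))
        return g.unsqueeze(-1) * A_raw


def mix_head_outputs(A, head_out):
    """Soft routing from the diagram: input_j = sum_i A[i, j] * output_i."""
    # head_out: (B, 256, d_head) stacked per-head outputs
    return torch.einsum("bij,bid->bjd", A, head_out)
```

Because every step is a smooth function of the MLP's parameters, NLL gradients from the frozen language model can flow through A back into U and V, which is the differentiability the diagram relies on.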

Cascading Gate enforces: no incoming edges → no outgoing edges · differentiable via soft sigmoid gate
Future: Phase 1 data → train a diffusion decoder to capture multi-modal optimal topologies
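In Phase 1 as described, only the structure predictor receives gradient updates while OLMo2-1B and the Qwen3 embedder stay frozen. A minimal sketch of that freezing setup, using tiny `nn.Linear` stand-ins for the real models (the actual trainer.py wiring with DDP and accumulation is not shown):

```python
import torch
import torch.nn as nn

olmo = nn.Linear(8, 8)       # stand-in for OLMo2-1B (frozen in Phase 1)
qwen = nn.Linear(8, 8)       # stand-in for the Qwen3 embedder (always frozen)
predictor = nn.Linear(8, 8)  # stand-in for the trainable structure predictor

for p in olmo.parameters():
    p.requires_grad_(False)
for p in qwen.parameters():
    p.requires_grad_(False)

# Only the predictor's parameters reach the optimizer in Phase 1;
# Phase 2 would additionally unfreeze olmo and pass its parameters in too.
opt = torch.optim.AdamW(predictor.parameters(), lr=1e-4)
```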

-- cgit v1.2.3