Lookahead Structure Predictor for OLMo2-1B

Per-token context-conditioned DAG · Head-level 256×256 upper-triangular adjacency · Cascading activation gate

⚡ Topology predicted before each forward pass · Fully differentiable via continuous relaxation

TOPOLOGY PREDICTION (per token) ⚡ runs before each OLMo forward pass

- Current context: changes at every generation step.
- Qwen-3-Embedding-0.6B (frozen): encodes the context into a d-dim vector e.
- Structure Predictor (trainable): decodes e into a soft adjacency matrix in three steps (see the first sketch below):
  1. Low-rank parameterization: e → MLP → U, V ∈ ℝ^{256×r}, then Z = UVᵀ.
  2. Gumbel-Sigmoid with upper-triangular mask: A_raw = UpperTriMask ⊙ σ((Z + G) / τ), where G is Gumbel noise and τ is the temperature.
  3. Cascading activation gate:
     gⱼ = σ(k · Σᵢ A_raw[i][j])  // gate driven by incoming edge mass
     A[j, :] = gⱼ · A_raw[j, :]  // gate the outgoing edges
     No input → gⱼ ≈ 0 → no output; fully differentiable; k is learnable or fixed.
- Output: soft adjacency A ∈ [0,1]^{256×256}, upper-triangular, per-token dynamic, cascading-gated.

OLMO2-1B INFERENCE

- OLMo2-1B's 16 layers × 16 heads form the 256 nodes of the DAG, spanning L0 through L15 (the original diagram shows L8 heavily pruned).
- Weighted head connectivity: input_j = Σᵢ A[i][j] · output_i (soft weighting, so the routing stays differentiable); see the second sketch below.
- A is applied during the forward pass; the LM head produces logits, and ∇NLL flows back through both OLMo and the predictor.

TRAINING: FULLY DIFFERENTIABLE PIPELINE

Phase 1: train the predictor only.
- OLMo2-1B frozen 🔒 · Qwen-3-Emb frozen 🔒 · Predictor trainable 🔓 (only these parameters update).
- Goal: learn topologies that lower NLL versus the dense baseline.
- Also generates (context, topology) pairs for a future diffusion head.

Phase 2: joint training (continued pretraining, CPT).
- OLMo2-1B unfrozen 🔓 (adapts to the predicted topologies) · Qwen-3-Emb frozen 🔒 · Predictor trainable 🔓 (co-evolves with OLMo).
- Goal: OLMo and predictor co-alignment.
- Optional: swap the MLP decoder for a diffusion head to capture multi-modal topologies.

A parameter-freezing sketch for the two phases follows the routing sketch below.
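A minimal PyTorch sketch of the predictor's decode path, under stated assumptions: the module layout and the hyperparameters `embed_dim`, `rank`, `hidden`, `tau` are illustrative, not from the source. The `gate_bias` term is also our assumption: plain σ(0) = 0.5, so without a negative shift an isolated node would still leak half of its outgoing mass.

```python
import torch
import torch.nn as nn

N_HEADS = 256  # 16 layers x 16 heads = 256 nodes


class StructurePredictor(nn.Module):
    """Sketch: context embedding e -> soft upper-triangular adjacency A.

    Hyperparameters (rank, hidden, tau, gate_bias) are illustrative.
    """

    def __init__(self, embed_dim=1024, rank=16, hidden=512, tau=1.0):
        super().__init__()
        self.rank, self.tau = rank, tau
        # e -> MLP -> U, V in R^{256 x r} (low-rank parameterization)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden),
            nn.GELU(),
            nn.Linear(hidden, 2 * N_HEADS * rank),
        )
        self.k = nn.Parameter(torch.tensor(5.0))          # gate sharpness (learnable)
        self.gate_bias = nn.Parameter(torch.tensor(4.0))  # assumed shift so sigma(-b) ~ 0
        # Strict upper-triangular mask: edges only point to later nodes (DAG)
        self.register_buffer(
            "mask", torch.triu(torch.ones(N_HEADS, N_HEADS), diagonal=1)
        )

    def forward(self, e):  # e: (batch, embed_dim)
        uv = self.mlp(e).view(-1, 2, N_HEADS, self.rank)
        U, V = uv[:, 0], uv[:, 1]
        Z = U @ V.transpose(-1, -2)  # Z = U V^T, shape (batch, 256, 256)
        # Gumbel-Sigmoid relaxation: logistic noise on the logits while training
        if self.training:
            u = torch.rand_like(Z).clamp(1e-6, 1 - 1e-6)
            Z = Z + u.logit()
        A_raw = self.mask * torch.sigmoid(Z / self.tau)
        # Cascading activation gate: incoming edge mass gates outgoing edges
        incoming = A_raw.sum(dim=-2)  # incoming[j] = sum_i A_raw[i, j]
        gate = torch.sigmoid(self.k * incoming - self.gate_bias)
        return gate.unsqueeze(-1) * A_raw  # A[j, :] = g_j * A_raw[j, :]
```

The low-rank factorization keeps the decoder head at 2·256·r outputs (8,192 at r = 16) instead of 256² ≈ 65k, which is why the MLP stays small.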
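Next, a sketch of the weighted head-connectivity rule applied at inference. The function name and tensor layout are assumptions, since the source does not specify how head outputs are flattened.

```python
import torch


def route_head_inputs(A, head_outputs):
    """Weighted head connectivity: input_j = sum_i A[i, j] * output_i.

    A:            (batch, 256, 256) soft adjacency from the predictor
    head_outputs: (batch, 256, head_dim) per-head outputs over the
                  16 x 16 layer-head grid (node index = layer * 16 + head)
    Soft weights keep the routing differentiable: gradients reach both
    the head outputs and the adjacency A.
    """
    # Sum over source nodes i: (b, i, j) x (b, i, d) -> (b, j, d)
    return torch.einsum("bij,bid->bjd", A, head_outputs)
```

In the actual model this aggregation would be interleaved with the layer-by-layer forward pass (the upper-triangular mask guarantees node j only receives from earlier nodes i < j); the aggregation rule itself is unchanged.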
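Finally, a sketch of the two-phase freezing schedule; the module handles and learning rates are placeholders, not values from the source.

```python
import torch


def configure_phase(phase, olmo, qwen_emb, predictor):
    """Freeze/unfreeze parameters for the two training phases (sketch).

    Phase 1: only the structure predictor updates (OLMo and Qwen frozen).
    Phase 2: OLMo is unfrozen and co-trains with the predictor; the Qwen
    embedding encoder stays frozen throughout.
    """
    for p in qwen_emb.parameters():
        p.requires_grad = False           # frozen in both phases
    for p in predictor.parameters():
        p.requires_grad = True            # trainable in both phases
    for p in olmo.parameters():
        p.requires_grad = (phase == 2)    # unfrozen only in phase 2
    trainable = [
        p for m in (predictor, olmo) for p in m.parameters() if p.requires_grad
    ]
    # Learning rates are illustrative placeholders
    return torch.optim.AdamW(trainable, lr=1e-4 if phase == 1 else 1e-5)
```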

The Cascading Gate enforces a reachability constraint: a head with no incoming edges produces no outgoing edges, and because the constraint is applied through a soft sigmoid rather than a hard threshold, it stays differentiable. (Strictly, σ(0) = 0.5, so the gate needs a negative bias or shifted argument for zero incoming mass to actually drive gⱼ toward 0; the sketches above assume one.)
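A quick numeric check of the cut-off behaviour, with illustrative sharpness k = 5 and the assumed bias b = 4:

```python
import torch

k, b = 5.0, 4.0                                # illustrative values
incoming = torch.tensor([0.0, 0.2, 1.0, 3.0])  # sum_i A_raw[i, j] per node
print(torch.sigmoid(k * incoming - b))
# tensor([0.0180, 0.0474, 0.7311, 1.0000]) -> isolated node is ~off
```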
Future: use the (context, topology) pairs collected in Phase 1 to train a diffusion decoder that captures the multi-modal distribution of optimal topologies.