| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:00:39 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-09 11:00:39 -0600 |
| commit | 13ddc8dc583d8b1355909970cb8c27f85b7d3c8b | |
| tree | 073534138604c1c49021ca7e334322262129f6ac /structure_predictor.html | |
Initial implementation: DAGFormer Phase 1
- olmo_graph.py: Modified OLMo2-1B forward with per-head routing via 256x256 adjacency matrix A
  - Proportional attribution for post-norm decomposition
  - All 6 GPU sanity checks pass (baseline diff = 0.000001)
- predictor.py: Qwen3-Embedding encoder + MLP decoder + Gumbel-Sigmoid + cascading gate
- pipeline.py: End-to-end glue (predictor -> A -> OLMo -> NLL)
- trainer.py: Full training loop with DDP, gradient accumulation, eval, checkpointing
- dolma.py: Streaming Dolma v1.7 with sequence packing
- 43/43 unit tests pass
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
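
The commit message names a Gumbel-Sigmoid relaxation and a cascading gate in predictor.py, and the diagram added by this commit states them as `A_raw = UpperTriMask ⊙ σ((Z + G) / τ)`, `gⱼ = σ(k · Σᵢ A_raw[i][j])` and `A[j, :] = gⱼ · A_raw[j, :]`. The following is a minimal pure-Python sketch of those three formulas, not the contents of predictor.py. The function names, the strict upper-triangular mask over node index, the logistic form of the noise G, and the fixed value of k are all assumptions made for illustration.

```python
import math
import random


def sigmoid(x):
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def gumbel_sigmoid_adjacency(logits, tau=1.0, rng=None):
    """A_raw = UpperTriMask * sigmoid((Z + G) / tau).

    Assumption: G is logistic noise (the difference of two Gumbel samples).
    Passing rng=None disables the noise and returns the plain relaxation.
    """
    n = len(logits)
    a_raw = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):  # assumed mask: strict upper triangle
            noise = 0.0
            if rng is not None:
                u = min(max(rng.random(), 1e-9), 1.0 - 1e-9)
                noise = math.log(u) - math.log(1.0 - u)
            a_raw[i][j] = sigmoid((logits[i][j] + noise) / tau)
    return a_raw


def cascading_gate(a_raw, k=5.0):
    """g_j = sigmoid(k * sum_i A_raw[i][j]); row j of A_raw is scaled by g_j."""
    n = len(a_raw)
    gates = [sigmoid(k * sum(a_raw[i][j] for i in range(n))) for j in range(n)]
    gated = [[gates[j] * a_raw[j][c] for c in range(n)] for j in range(n)]
    return gated, gates


# Noise-free example on a 3-node graph with all-zero logits.
a_raw = gumbel_sigmoid_adjacency([[0.0] * 3 for _ in range(3)], tau=1.0)
a, g = cascading_gate(a_raw, k=5.0)
```

With the plain logistic function as written in the diagram, a node whose incoming sum is exactly zero receives g = σ(0) = 0.5 rather than a value near 0. Obtaining the "no input → no output" behaviour described in the diagram would need a shifted or rescaled gate; this commit page does not show which form the implementation uses.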
Diffstat (limited to 'structure_predictor.html')
| -rw-r--r-- | structure_predictor.html | 289 |
1 files changed, 289 insertions, 0 deletions
diff --git a/structure_predictor.html b/structure_predictor.html new file mode 100644 index 0000000..7460649 --- /dev/null +++ b/structure_predictor.html @@ -0,0 +1,289 @@ +<!DOCTYPE html> +<html lang="en"> +<head> +<meta charset="UTF-8"> +<meta name="viewport" content="width=device-width, initial-scale=1.0"> +<title>Structure Predictor v4</title> +<style> + *{margin:0;padding:0;box-sizing:border-box} + body{background:#08080d;color:#ddd;font-family:'Segoe UI',system-ui,sans-serif;display:flex;justify-content:center;padding:24px 16px} + .c{width:1040px} + h1{text-align:center;font-size:18px;font-weight:600;color:#c8c8d0;margin-bottom:5px} + .sub{text-align:center;font-size:12px;color:#6a6a7a;margin-bottom:6px} + .sub2{text-align:center;font-size:11px;color:#8a7a4a;margin-bottom:28px} + .foot{text-align:center;font-size:10px;color:#555;margin-top:14px;line-height:1.6} +</style> +</head> +<body> +<div class="c"> +<h1>Lookahead Structure Predictor for OLMo2-1B</h1> +<p class="sub">Per-token context-conditioned DAG · Head-level 256×256 upper-triangular adjacency · Cascading activation gate</p> +<p class="sub2">⚡ Topology predicted before each forward pass · Fully differentiable via continuous relaxation</p> + +<svg viewBox="0 0 1040 780" xmlns="http://www.w3.org/2000/svg"> +<defs> + <marker id="ab" markerWidth="7" markerHeight="5" refX="6" refY="2.5" orient="auto"><polygon points="0 0,7 2.5,0 5" fill="#5a80bb"/></marker> + <marker id="ap" markerWidth="7" markerHeight="5" refX="6" refY="2.5" orient="auto"><polygon points="0 0,7 2.5,0 5" fill="#bb5a78"/></marker> + <marker id="ag" markerWidth="7" markerHeight="5" refX="6" refY="2.5" orient="auto"><polygon points="0 0,7 2.5,0 5" fill="#5abb78"/></marker> + <marker id="ar" markerWidth="7" markerHeight="5" refX="6" refY="2.5" orient="auto"><polygon points="0 0,7 2.5,0 5" fill="#cc4444"/></marker> +</defs> + +<!-- ============ LEFT: TOPOLOGY PREDICTION ============ --> +<rect x="15" y="12" width="500" height="555" rx="10" 
fill="none" stroke="#252530" stroke-dasharray="4 3"/> +<text x="265" y="32" text-anchor="middle" font-size="10" fill="#5a5a6a" letter-spacing="2" text-transform="uppercase">TOPOLOGY PREDICTION (per token)</text> +<text x="265" y="47" text-anchor="middle" font-size="9" fill="#e8a04c">⚡ runs before OLMo forward pass</text> + +<!-- Input --> +<rect x="115" y="62" width="300" height="38" rx="7" fill="#10141e" stroke="#3a5578" stroke-width="1.4"/> +<text x="265" y="80" text-anchor="middle" font-size="12.5" fill="#ddd">Current Context</text> +<text x="265" y="93" text-anchor="middle" font-size="9" fill="#6a6a7a">changes every generation step</text> + +<!-- Arrow --> +<line x1="265" y1="100" x2="265" y2="124" stroke="#5a80bb" stroke-width="1.3" marker-end="url(#ab)"/> + +<!-- Qwen --> +<rect x="105" y="128" width="320" height="46" rx="7" fill="#101420" stroke="#3a5590" stroke-width="1.4"/> +<text x="265" y="148" text-anchor="middle" font-size="12.5" fill="#ddd">Qwen-3-Embedding-0.6B</text> +<text x="265" y="164" text-anchor="middle" font-size="9" fill="#6a6a7a">frozen · context → d-dim vector e</text> + +<!-- Arrow --> +<line x1="265" y1="174" x2="265" y2="200" stroke="#5a80bb" stroke-width="1.3" marker-end="url(#ab)"/> +<text x="285" y="190" text-anchor="start" font-size="9" fill="#5a80bb">e</text> + +<!-- Decoder --> +<rect x="70" y="204" width="390" height="95" rx="7" fill="#18102a" stroke="#7040a0" stroke-width="1.4"/> +<text x="265" y="224" text-anchor="middle" font-size="12.5" fill="#ddd" font-weight="600">Structure Predictor (Trainable)</text> +<rect x="90" y="238" width="350" height="50" rx="5" fill="#140e22" stroke="#5a3580"/> +<text x="265" y="256" text-anchor="middle" font-size="10.5" fill="#b090d0">Low-Rank Parameterization</text> +<text x="265" y="272" text-anchor="middle" font-family="Consolas,monospace" font-size="9" fill="#9080b0">e → MLP → U, V ∈ ℝ^{256×r} → Z = UV^T</text> + +<!-- Arrow --> +<line x1="265" y1="299" x2="265" y2="322" stroke="#5a80bb" 
stroke-width="1.3" marker-end="url(#ab)"/> + +<!-- Gumbel-Sigmoid --> +<rect x="85" y="326" width="360" height="50" rx="7" fill="#1e1418" stroke="#a05070" stroke-width="1.4"/> +<text x="265" y="345" text-anchor="middle" font-size="12.5" fill="#ddd">Gumbel-Sigmoid + Upper-Tri Mask</text> +<text x="265" y="363" text-anchor="middle" font-family="Consolas,monospace" font-size="9" fill="#c08090">A_raw = UpperTriMask ⊙ σ((Z + G) / τ)</text> + +<!-- Arrow --> +<line x1="265" y1="376" x2="265" y2="400" stroke="#5a80bb" stroke-width="1.3" marker-end="url(#ab)"/> + +<!-- ★ CASCADING ACTIVATION GATE — NEW --> +<rect x="60" y="404" width="410" height="72" rx="7" fill="#1a1410" stroke="#c09040" stroke-width="1.4"/> +<text x="265" y="424" text-anchor="middle" font-size="12.5" fill="#e8c06a" font-weight="600">Cascading Activation Gate</text> +<text x="265" y="442" text-anchor="middle" font-family="Consolas,monospace" font-size="9" fill="#c0a060">gⱼ = σ( k · Σᵢ A_raw[i][j] ) // incoming sum</text> +<text x="265" y="458" text-anchor="middle" font-family="Consolas,monospace" font-size="9" fill="#c0a060">A[j, :] = gⱼ · A_raw[j, :] // gate outgoing</text> +<text x="265" y="472" text-anchor="middle" font-size="8.5" fill="#8a7a4a">no input → gⱼ≈0 → no output · fully differentiable · k = learnable or fixed</text> + +<!-- Arrow --> +<line x1="265" y1="476" x2="265" y2="500" stroke="#5a80bb" stroke-width="1.3" marker-end="url(#ab)"/> + +<!-- Soft Adjacency Output --> +<rect x="60" y="504" width="410" height="52" rx="7" fill="#1e0e14" stroke="#bb5a78" stroke-width="1.4"/> +<text x="265" y="524" text-anchor="middle" font-size="12.5" fill="#ddd" font-weight="600">Soft Adjacency A ∈ [0,1]^{256×256}</text> +<text x="265" y="542" text-anchor="middle" font-size="9" fill="#bb5a78">upper-tri · per-token dynamic · cascading-gated</text> + +<!-- ============ RIGHT: OLMo INFERENCE ============ --> +<rect x="545" y="12" width="480" height="445" rx="10" fill="none" stroke="#1a3020" stroke-dasharray="4 
3"/> +<text x="785" y="32" text-anchor="middle" font-size="10" fill="#5a5a6a" letter-spacing="2">OLMO2-1B INFERENCE</text> + +<!-- Context → OLMo --> +<path d="M 415 81 L 490 81 Q 520 81 520 100 L 520 81 Q 520 66 540 66 L 564 66" stroke="#5abb78" stroke-width="1.3" fill="none" marker-end="url(#ag)"/> +<text x="490" y="58" text-anchor="middle" font-size="9" fill="#5abb78">tokens</text> + +<!-- OLMo body --> +<rect x="566" y="48" width="444" height="290" rx="7" fill="#0c160e" stroke="#3a7a4a" stroke-width="1.4"/> +<text x="788" y="70" text-anchor="middle" font-size="12.5" fill="#ddd" font-weight="600">OLMo2-1B</text> +<text x="788" y="86" text-anchor="middle" font-size="9" fill="#6a6a7a">16 layers × 16 heads = 256 nodes</text> + +<!-- Layer rows — each 16 heads --> +<!-- L0 --> +<text x="590" y="110" text-anchor="start" font-size="10" fill="#4a7a5a">L0</text> +<g transform="translate(618,102)"> + <rect x="0" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="14" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="28" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="42" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="56" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="70" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="84" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="98" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="112" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="126" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="140" y="0" width="11" height="11" rx="2" 
fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="154" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="168" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="182" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="196" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="210" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> +</g> + +<!-- L1 (some pruned) --> +<text x="590" y="138" text-anchor="start" font-size="10" fill="#4a7a5a">L1</text> +<g transform="translate(618,130)"> + <rect x="0" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="14" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="28" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="42" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="56" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="70" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="84" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="98" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="112" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="126" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="140" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="154" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + 
<rect x="168" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="182" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="196" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="210" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> +</g> + +<!-- L2 --> +<text x="590" y="166" text-anchor="start" font-size="10" fill="#4a7a5a">L2</text> +<g transform="translate(618,158)"> + <rect x="0" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="14" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="28" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="42" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="56" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="70" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="84" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="98" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="112" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="126" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="140" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="154" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="168" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="182" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" 
stroke-width=".7"/> + <rect x="196" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="210" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> +</g> + +<text x="788" y="190" text-anchor="middle" font-size="13" fill="#2a5a3a">⋮</text> + +<!-- L8 heavily pruned --> +<text x="590" y="212" text-anchor="start" font-size="10" fill="#5a4a4a">L8</text> +<g transform="translate(618,204)"> + <rect x="0" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="14" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="28" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="42" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="56" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="70" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="84" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="98" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="112" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="126" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="140" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="154" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="168" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="182" y="0" width="11" 
height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="196" y="0" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <rect x="210" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> +</g> +<text x="870" y="212" text-anchor="start" font-size="9" fill="#7a4a4a">heavily pruned</text> + +<text x="788" y="238" text-anchor="middle" font-size="13" fill="#2a5a3a">⋮</text> + +<!-- L15 --> +<text x="590" y="260" text-anchor="start" font-size="10" fill="#4a7a5a">L15</text> +<g transform="translate(618,252)"> + <rect x="0" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="14" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="28" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="42" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="56" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="70" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="84" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="98" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="112" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="126" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="140" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="154" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="168" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="182" y="0" 
width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="196" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <rect x="210" y="0" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> +</g> + +<!-- Hyperconnections (skip) --> +<path d="M 624 114 Q 604 152 618 158" stroke="#bb5a78" stroke-width="1" fill="none" stroke-dasharray="3 2" marker-end="url(#ap)"/> +<path d="M 716 114 Q 598 198 618 258" stroke="#bb5a78" stroke-width="1" fill="none" stroke-dasharray="3 2" marker-end="url(#ap)"/> +<path d="M 828 160 Q 855 210 843 252" stroke="#bb5a78" stroke-width="1" fill="none" stroke-dasharray="3 2" marker-end="url(#ap)"/> + +<!-- Weighted connectivity note --> +<rect x="580" y="280" width="416" height="48" rx="5" fill="#0a120c" stroke="#1a3020"/> +<text x="788" y="298" text-anchor="middle" font-size="10" fill="#7ab88a">Weighted head connectivity</text> +<text x="788" y="314" text-anchor="middle" font-family="Consolas,monospace" font-size="9" fill="#8aaa90">input_j = Σᵢ A[i][j] · output_i (soft → differentiable)</text> + +<!-- Topology application arrow --> +<path d="M 470 530 Q 540 500 566 310" stroke="#bb5a78" stroke-width="1.5" fill="none" stroke-dasharray="5 3" marker-end="url(#ap)"/> +<text x="530" y="485" text-anchor="middle" font-size="10" fill="#bb5a78">apply A</text> + +<!-- LM Head --> +<rect x="660" y="350" width="256" height="34" rx="7" fill="#0c160e" stroke="#4a9a6a" stroke-width="1.4"/> +<text x="788" y="370" text-anchor="middle" font-size="12.5" fill="#ddd">LM Head → logits</text> + +<!-- Arrow to NLL --> +<line x1="788" y1="384" x2="788" y2="412" stroke="#5abb78" stroke-width="1.3" marker-end="url(#ag)"/> + +<!-- NLL Loss --> +<rect x="718" y="416" width="140" height="34" rx="7" fill="#0e1a10" stroke="#5abb78" stroke-width="1.4"/> +<text x="788" y="437" text-anchor="middle" font-size="12.5" fill="#8aee9a" font-weight="600">NLL Loss</text> + +<!-- 
============ GRADIENT FLOW ============ --> +<!-- Arrow: NLL → left, up (right of yellow box), into Predictor --> +<path d="M 718 435 L 500 435 Q 490 435 490 422 L 490 260 Q 490 250 480 250 L 462 250" stroke="#cc4444" stroke-width="1.5" fill="none" stroke-dasharray="5 3" marker-end="url(#ar)"/> +<text x="510" y="350" text-anchor="start" font-size="10" fill="#cc6666">∇ NLL</text> + +<!-- ============ TRAINING PHASES ============ --> +<rect x="15" y="600" width="1010" height="170" rx="10" fill="none" stroke="#252530" stroke-dasharray="4 3"/> +<text x="520" y="620" text-anchor="middle" font-size="10" fill="#5a5a6a" letter-spacing="2">TRAINING — FULLY DIFFERENTIABLE PIPELINE</text> + +<!-- Phase 1 --> +<rect x="30" y="635" width="470" height="120" rx="7" fill="#14140e" stroke="#8a8a3a"/> +<text x="265" y="656" text-anchor="middle" font-size="12" fill="#cccc6a" font-weight="600">Phase 1: Train Predictor Only</text> +<text x="55" y="678" text-anchor="start" font-family="Consolas,monospace" font-size="9" fill="#aaa86a">OLMo2-1B frozen 🔒</text> +<text x="55" y="694" text-anchor="start" font-family="Consolas,monospace" font-size="9" fill="#aaa86a">Qwen-3-Emb frozen 🔒</text> +<text x="55" y="710" text-anchor="start" font-family="Consolas,monospace" font-size="9" fill="#aaa86a">Predictor trainable 🔓 ← only these params update</text> +<text x="55" y="732" text-anchor="start" font-size="9" fill="#7a7a5a">Goal: learn topologies that lower NLL vs dense baseline</text> +<text x="55" y="746" text-anchor="start" font-size="9" fill="#7a7a5a">Also generates (context, topology) pairs for future diffusion head</text> + +<!-- Arrow between --> +<line x1="500" y1="695" x2="528" y2="695" stroke="#6a6a6a" stroke-width="1.3" marker-end="url(#ab)"/> +<text x="514" y="686" text-anchor="middle" font-size="9" fill="#6a6a6a">then</text> + +<!-- Phase 2 --> +<rect x="535" y="635" width="480" height="120" rx="7" fill="#0e1414" stroke="#3a8a8a"/> +<text x="775" y="656" text-anchor="middle" 
font-size="12" fill="#6acccc" font-weight="600">Phase 2: Joint Training (CPT)</text> +<text x="560" y="678" text-anchor="start" font-family="Consolas,monospace" font-size="9" fill="#6aaaa8">OLMo2-1B unfrozen 🔓 ← adapts to predicted topologies</text> +<text x="560" y="694" text-anchor="start" font-family="Consolas,monospace" font-size="9" fill="#6aaaa8">Qwen-3-Emb frozen 🔒</text> +<text x="560" y="710" text-anchor="start" font-family="Consolas,monospace" font-size="9" fill="#6aaaa8">Predictor trainable 🔓 ← co-evolves with OLMo</text> +<text x="560" y="732" text-anchor="start" font-size="9" fill="#4a7a7a">Goal: OLMo + Predictor co-alignment</text> +<text x="560" y="746" text-anchor="start" font-size="9" fill="#4a7a7a">Optional: swap MLP decoder → diffusion head (multi-modal topologies)</text> + +<!-- ============ LEGEND ============ --> +<g transform="translate(15,770)"> + <rect x="0" y="-8" width="11" height="11" rx="2" fill="#142a18" stroke="#2a6a3a" stroke-width=".7"/> + <text x="16" y="0" text-anchor="start" font-size="9" fill="#6a6a7a">active</text> + <rect x="60" y="-8" width="11" height="11" rx="2" fill="#18100e" stroke="#5a2a2a" stroke-width=".7" opacity=".35"/> + <text x="76" y="0" text-anchor="start" font-size="9" fill="#6a6a7a">pruned</text> + <line x1="125" y1="0" x2="158" y2="0" stroke="#bb5a78" stroke-width="1" stroke-dasharray="3 2"/> + <text x="164" y="0" text-anchor="start" font-size="9" fill="#6a6a7a">hyperconn</text> + <line x1="225" y1="0" x2="258" y2="0" stroke="#cc4444" stroke-width="1.3" stroke-dasharray="5 3"/> + <text x="264" y="0" text-anchor="start" font-size="9" fill="#6a6a7a">gradient</text> + <line x1="315" y1="0" x2="348" y2="0" stroke="#5abb78" stroke-width="1.3"/> + <text x="354" y="0" text-anchor="start" font-size="9" fill="#6a6a7a">forward</text> + <rect x="400" y="-8" width="11" height="11" rx="2" fill="#1a1410" stroke="#c09040" stroke-width=".7"/> + <text x="416" y="0" text-anchor="start" font-size="9" fill="#6a6a7a">cascading 
gate (new)</text> +</g> +</svg> + +<p class="foot"> + Cascading Gate enforces: no incoming edges → no outgoing edges · differentiable via soft sigmoid gate<br/> + Future: Phase 1 data → train diffusion decoder to capture multi-modal optimal topologies +</p> +</div> +</body> +</html> |
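
The diagram in the diff above gives the low-rank parameterization `Z = UV^T` and the weighted head connectivity `input_j = Σᵢ A[i][j] · output_i`. The following is a minimal pure-Python sketch of those two formulas only. The function names are hypothetical, head outputs are treated as plain feature vectors, and nothing here reproduces the post-norm handling that olmo_graph.py performs.

```python
def low_rank_logits(u, v):
    """Z = U V^T for U and V given as n x r nested lists."""
    return [[sum(ui[k] * vj[k] for k in range(len(ui))) for vj in v] for ui in u]


def route_head_inputs(adjacency, head_outputs):
    """input_j = sum_i A[i][j] * output_i, applied per feature dimension."""
    n = len(adjacency)
    dim = len(head_outputs[0])
    inputs = [[0.0] * dim for _ in range(n)]
    for i in range(n):
        for j in range(n):
            w = adjacency[i][j]
            if w != 0.0:  # skip masked or fully pruned edges
                for d in range(dim):
                    inputs[j][d] += w * head_outputs[i][d]
    return inputs


# Toy example: 3 nodes, rank 2, 2-dimensional head outputs.
z = low_rank_logits([[1, 0], [0, 1], [1, 1]], [[1, 2], [3, 4], [5, 6]])
adjacency = [[0.0, 0.5, 1.0], [0.0, 0.0, 0.25], [0.0, 0.0, 0.0]]
routed = route_head_inputs(adjacency, [[1.0, 2.0], [4.0, 0.0], [9.0, 9.0]])
```

For the 256-node graph in the diagram, a full logit matrix has 256 × 256 = 65,536 entries, while the two factors together have 2 × 256 × r entries, so the low-rank form is smaller whenever r < 128.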
