Diffstat (limited to 'src/model/pipeline.py')
-rw-r--r--  src/model/pipeline.py  144
1 file changed, 144 insertions(+), 0 deletions(-)
diff --git a/src/model/pipeline.py b/src/model/pipeline.py
new file mode 100644
index 0000000..bbfcabf
--- /dev/null
+++ b/src/model/pipeline.py
@@ -0,0 +1,144 @@
+"""End-to-end DAGFormer pipeline: raw text → predictor → A → OLMo → NLL.
+
+Glues the structure predictor (Qwen + MLP) with the modified OLMo forward.
+This is what the trainer calls. See CLAUDE.md §5 for file responsibilities.
+"""
+
+from __future__ import annotations
+
+from typing import Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
+from src.model.predictor import StructurePredictor
+
+
+class DAGFormerPipeline(nn.Module):
+ """Combines StructurePredictor + DAGFormerOLMo into a single forward pass.
+
+ Forward: raw_text → predictor → A → modified OLMo → logits → NLL
+
+ Only the predictor's MLP params are trainable. OLMo and Qwen are frozen.
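+
+    A minimal usage sketch (the input tensors and hyperparameter values are
+    illustrative assumptions, not prescribed settings):
+
+        pipeline = DAGFormerPipeline(device=torch.device("cuda"))
+        out = pipeline(raw_texts, olmo_ids, olmo_labels,
+                       tau=1.0, lambda_sparsity=1e-4, mode="train")
+        out["total_loss"].backward()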
+ """
+
+ def __init__(
+ self,
+ olmo_model_id: str = "allenai/OLMo-2-0425-1B",
+ qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B",
+ predictor_hidden_dim: int = 1024,
+ predictor_rank: int = 32,
+ cascading_gate_k: float = 5.0,
+ input_norm: str = "none",
+ qwen_input_prefix: str = "",
+ device: Optional[torch.device] = None,
+ ):
+ super().__init__()
+
+ # Load frozen OLMo2-1B
+ olmo = AutoModelForCausalLM.from_pretrained(
+ olmo_model_id,
+ torch_dtype=torch.bfloat16,
+ )
+ olmo.eval()
+ for p in olmo.parameters():
+ p.requires_grad_(False)
+
+ # Wrap OLMo with DAGFormer modification
+ self.olmo_wrapper = DAGFormerOLMo(model=olmo, input_norm=input_norm)
+
+ # Structure predictor (Qwen encoder + MLP decoder)
+ self.predictor = StructurePredictor(
+ qwen_model_id=qwen_model_id,
+ hidden_dim=predictor_hidden_dim,
+ rank=predictor_rank,
+ cascading_gate_k=cascading_gate_k,
+ qwen_input_prefix=qwen_input_prefix,
+ device=device,
+ )
+
+ self.vocab_size = olmo.config.vocab_size
+
+ if device is not None:
+ self.to(device)
+
+ def forward(
+ self,
+ raw_texts: list[str],
+ olmo_ids: torch.Tensor,
+ olmo_labels: torch.Tensor,
+ tau: float,
+ lambda_sparsity: float = 0.0,
+ mode: str = "train",
+ ) -> dict[str, torch.Tensor]:
+ """Full forward pass: text → A → logits → loss.
+
+ Args:
+ raw_texts: list of raw text strings (batch)
+ olmo_ids: [batch, seq_len] — OLMo tokenized input
+            olmo_labels: [batch, seq_len] — target token ids, aligned with
+                olmo_ids (the shift for next-token prediction happens below)
+ tau: Gumbel-Sigmoid temperature
+ lambda_sparsity: sparsity coefficient (λ_t)
+ mode: "train", "eval_soft", or "eval_hard"
+
+ Returns:
+ dict with keys:
+ "total_loss": nll + lambda * mean(A) — what the optimizer sees
+ "nll": cross-entropy loss
+ "sparsity_loss": lambda * mean(A)
+ "A": [batch, 256, 256] adjacency matrix
+ """
+ # Step 1: Predict adjacency matrix
+ A = self.predictor(raw_texts, tau=tau, mode=mode)
+ # A: [batch, 256, 256]
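+        # NOTE (assumed semantics, see StructurePredictor): "train" samples
+        # edges via Gumbel-Sigmoid at temperature tau, "eval_soft" uses the
+        # sigmoid probabilities directly, "eval_hard" thresholds them to 0/1.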
+
+ # Step 2: Modified OLMo forward with A
+ logits = self.olmo_wrapper(olmo_ids, A)
+ # logits: [batch, seq_len, vocab_size]
+
+ # Step 3: Compute NLL (next-token prediction)
+ # Shift: logits[:, :-1] predicts labels[:, 1:]
+ nll = F.cross_entropy(
+ logits[:, :-1].contiguous().view(-1, self.vocab_size),
+ olmo_labels[:, 1:].contiguous().view(-1),
+ )
+
+ # Step 4: Sparsity regularization
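+        # mean(A) is the average edge weight over all [batch, 256, 256]
+        # entries, so this term penalizes dense graphs; lambda_sparsity is
+        # the λ_t of the docstring, presumably scheduled per step by the
+        # trainer.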
+ sparsity_loss = lambda_sparsity * A.mean()
+ total_loss = nll + sparsity_loss
+
+ return {
+ "total_loss": total_loss,
+ "nll": nll,
+ "sparsity_loss": sparsity_loss,
+ "A": A,
+ }
+
+ def forward_baseline(
+ self,
+ olmo_ids: torch.Tensor,
+ olmo_labels: torch.Tensor,
+ ) -> torch.Tensor:
+ """Forward with A=all-ones (baseline reproduction).
+
+ Used for eval/nll_baseline metric.
+ """
+ batch = olmo_ids.shape[0]
+ A = create_all_ones_A(batch).to(olmo_ids.device)
+ with torch.no_grad():
+ logits = self.olmo_wrapper(olmo_ids, A)
+ nll = F.cross_entropy(
+ logits[:, :-1].contiguous().view(-1, self.vocab_size),
+ olmo_labels[:, 1:].contiguous().view(-1),
+ )
+ return nll
+
+ def get_trainable_parameters(self) -> list[nn.Parameter]:
+ """Return only the trainable parameters (predictor MLP + any norm params)."""
+ params = list(self.predictor.get_trainable_parameters())
+ # Also include input normalizer params if they exist
+ params.extend(self.olmo_wrapper.input_normalizer.parameters())
+ return params