"""End-to-end DAGFormer pipeline: raw text → predictor → A → OLMo → NLL.
Glues the structure predictor (Qwen + MLP) with the modified OLMo forward.
This is what the trainer calls. See CLAUDE.md §5 for file responsibilities.
"""
from __future__ import annotations
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoModelForCausalLM, AutoTokenizer
from src.model.olmo_graph import DAGFormerOLMo, create_all_ones_A
from src.model.predictor import StructurePredictor
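

# The predictor is expected to relax the discrete adjacency matrix A with a
# Gumbel-Sigmoid at temperature `tau` (see `DAGFormerPipeline.forward`). The
# actual sampling lives in `StructurePredictor`; the helper below is only an
# illustrative sketch of that relaxation, assuming the standard logistic-noise
# reparameterization — it is not called by the pipeline.
def _gumbel_sigmoid_sketch(logits: torch.Tensor, tau: float) -> torch.Tensor:
    """Differentiable approximation of Bernoulli(sigmoid(logits)) samples.

    Adds logistic noise to the logits and squashes with a temperature-scaled
    sigmoid: tau → 0 approaches hard 0/1 samples, large tau washes out
    toward 0.5.
    """
    u = torch.rand_like(logits).clamp_(1e-6, 1.0 - 1e-6)
    logistic_noise = torch.log(u) - torch.log1p(-u)  # Logistic(0, 1) sample
    return torch.sigmoid((logits + logistic_noise) / tau)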


class DAGFormerPipeline(nn.Module):
    """Combines StructurePredictor + DAGFormerOLMo into a single forward pass.

    Forward: raw_text → predictor → A → modified OLMo → logits → NLL

    Only the predictor's MLP params are trainable. OLMo and Qwen are frozen.
    """

    def __init__(
        self,
        olmo_model_id: str = "allenai/OLMo-2-0425-1B",
        qwen_model_id: str = "Qwen/Qwen3-Embedding-0.6B",
        predictor_hidden_dim: int = 1024,
        predictor_rank: int = 32,
        cascading_gate_k: float = 5.0,
        input_norm: str = "none",
        qwen_input_prefix: str = "",
        device: Optional[torch.device] = None,
    ):
        super().__init__()

        # Load frozen OLMo2-1B
        olmo = AutoModelForCausalLM.from_pretrained(
            olmo_model_id,
            torch_dtype=torch.bfloat16,
        )
        olmo.eval()
        for p in olmo.parameters():
            p.requires_grad_(False)

        # Wrap OLMo with DAGFormer modification
        self.olmo_wrapper = DAGFormerOLMo(model=olmo, input_norm=input_norm)

        # Structure predictor (Qwen encoder + MLP decoder)
        self.predictor = StructurePredictor(
            qwen_model_id=qwen_model_id,
            hidden_dim=predictor_hidden_dim,
            rank=predictor_rank,
            cascading_gate_k=cascading_gate_k,
            qwen_input_prefix=qwen_input_prefix,
            device=device,
        )

        self.vocab_size = olmo.config.vocab_size

        if device is not None:
            self.to(device)

    def forward(
        self,
        raw_texts: list[str],
        olmo_ids: torch.Tensor,
        olmo_labels: torch.Tensor,
        tau: float,
        lambda_sparsity: float = 0.0,
        mode: str = "train",
    ) -> dict[str, torch.Tensor]:
        """Full forward pass: text → A → logits → loss.

        Args:
            raw_texts: list of raw text strings (batch)
            olmo_ids: [batch, seq_len] — OLMo tokenized input
            olmo_labels: [batch, seq_len] — shifted labels for NLL
            tau: Gumbel-Sigmoid temperature
            lambda_sparsity: sparsity coefficient (λ_t)
            mode: "train", "eval_soft", or "eval_hard"

        Returns:
            dict with keys:
                "total_loss": nll + lambda * mean(A) — what the optimizer sees
                "nll": cross-entropy loss
                "sparsity_loss": lambda * mean(A)
                "A": [batch, 256, 256] adjacency matrix
        """
        # Step 1: Predict adjacency matrix
        A = self.predictor(raw_texts, tau=tau, mode=mode)
        # A: [batch, 256, 256]
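        # NOTE (assumption): `mode` presumably selects how A is produced in
        # StructurePredictor — stochastic Gumbel-Sigmoid samples for "train"
        # (see _gumbel_sigmoid_sketch above), deterministic sigmoid
        # probabilities for "eval_soft", and thresholded 0/1 edges for
        # "eval_hard".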

        # Step 2: Modified OLMo forward with A
        logits = self.olmo_wrapper(olmo_ids, A)
        # logits: [batch, seq_len, vocab_size]

        # Step 3: Compute NLL (next-token prediction)
        # olmo_labels is already shifted (chunk[1:seq_len+1]), no additional shift needed
        nll = F.cross_entropy(
            logits.contiguous().view(-1, self.vocab_size),
            olmo_labels.contiguous().view(-1),
        )

        # Step 4: Sparsity regularization
        sparsity_loss = lambda_sparsity * A.mean()
        total_loss = nll + sparsity_loss

        return {
            "total_loss": total_loss,
            "nll": nll,
            "sparsity_loss": sparsity_loss,
            "A": A,
        }

    def forward_baseline(
        self,
        olmo_ids: torch.Tensor,
        olmo_labels: torch.Tensor,
    ) -> torch.Tensor:
        """Forward with A = all-ones (baseline reproduction).

        With every edge enabled, the modified forward should reproduce the
        unmodified model; used for the eval/nll_baseline metric.
        """
        batch = olmo_ids.shape[0]
        A = create_all_ones_A(batch).to(olmo_ids.device)
        with torch.no_grad():
            logits = self.olmo_wrapper(olmo_ids, A)
        # olmo_labels is already shifted, no additional shift needed
        nll = F.cross_entropy(
            logits.contiguous().view(-1, self.vocab_size),
            olmo_labels.contiguous().view(-1),
        )
        return nll

    def get_trainable_parameters(self) -> list[nn.Parameter]:
        """Return only the trainable parameters (predictor MLP + any norm params)."""
        params = list(self.predictor.get_trainable_parameters())
        # Also include input normalizer params if they exist
        params.extend(self.olmo_wrapper.input_normalizer.parameters())
        return params
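

# Minimal smoke-test sketch of the intended call pattern: build the pipeline,
# hand only the trainable parameters to the optimizer, and run one forward and
# backward pass. The shift convention (ids = chunk[:-1], labels = chunk[1:])
# mirrors the one documented in `forward`; the text, tau, and lambda values
# are illustrative, not the trainer's actual schedule. This also assumes the
# wrapper accepts sequences shorter than the predictor's fixed 256-token grid;
# the real trainer presumably feeds fixed-length 256-token chunks.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    pipeline = DAGFormerPipeline(device=device)

    # Only the predictor MLP (plus any norm params) should receive gradients.
    optimizer = torch.optim.AdamW(pipeline.get_trainable_parameters(), lr=1e-4)

    tokenizer = AutoTokenizer.from_pretrained("allenai/OLMo-2-0425-1B")
    text = "DAGFormer predicts a sparse attention structure from raw text."
    chunk = tokenizer(text, return_tensors="pt").input_ids.to(device)
    olmo_ids, olmo_labels = chunk[:, :-1], chunk[:, 1:]  # next-token shift

    out = pipeline(
        raw_texts=[text],
        olmo_ids=olmo_ids,
        olmo_labels=olmo_labels,
        tau=1.0,
        lambda_sparsity=1e-3,
        mode="train",
    )
    out["total_loss"].backward()
    optimizer.step()
    print({k: float(v) for k, v in out.items() if v.dim() == 0})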