1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
|
"""Style-weighted theta fitting.
Up-weights style-indicative tokens (punctuation, newlines, first-person
pronouns, discourse markers and hedges, intensity and sentiment words) in the
cross-entropy (CE) loss.
"""
import torch
import torch.nn.functional as F
# Upper bound on the number of sequence positions projected to vocabulary
# logits in one pass.
CHUNK_SIZE = 128

# Loss weight for each style-indicative token. Keys are lower-case strings that
# are compared against the decoded text of individual tokens; a token that is
# not listed keeps the default weight of 1.0. Entries are grouped by weight
# within each category.
STYLE_INDICATORS = {
    # Punctuation and formatting
    **dict.fromkeys(('!', '?'), 3.0),
    **dict.fromkeys(('.', ',', ';', ':'), 2.0),
    **dict.fromkeys(('...',), 3.0),
    **dict.fromkeys(('-', '"', "'"), 2.0),
    **dict.fromkeys(('\n', '\n\n'), 4.0),
    # First-person pronouns
    **dict.fromkeys(
        ('i', 'me', 'my', 'mine', 'myself', 'we', 'us', 'our', 'ours'),
        3.0,
    ),
    # Discourse markers and hedges
    **dict.fromkeys(
        (
            'however', 'although', 'moreover', 'furthermore',
            'basically', 'actually', 'honestly', 'frankly',
            'definitely', 'absolutely', 'obviously',
        ),
        2.0,
    ),
    # Sentiment/intensity
    **dict.fromkeys(
        (
            'very', 'really', 'quite', 'extremely',
            'amazing', 'terrible', 'awesome', 'horrible',
            'love', 'hate', 'loved', 'hated',
        ),
        2.0,
    ),
}
def build_token_weights(label_ids: torch.Tensor, tokenizer, device: str) -> torch.Tensor:
"""Build per-token weights for style-weighted loss.
Args:
label_ids: (T,) tensor of token ids
tokenizer: The tokenizer to decode token ids
Returns:
weights: (T,) tensor of per-token weights (≥ 1.0)
"""
weights = torch.ones(label_ids.shape[0], device=device, dtype=torch.float32)
for i, tok_id in enumerate(label_ids):
tok_str = tokenizer.decode([tok_id.item()]).strip().lower()
if tok_str in STYLE_INDICATORS:
weights[i] = STYLE_INDICATORS[tok_str]
return weights
def _chunked_weighted_ce_kl(h_prime, h_base, lm_w, lm_b, y, weights, beta,
                            chunk_size=None):
    """Compute the weighted CE sum and the KL sum over a sequence, in chunks.

    The sequence is projected to vocabulary logits ``chunk_size`` positions at
    a time. For every chunk the per-token cross-entropy of the adapted hidden
    states is multiplied by the token weight and summed; when ``beta > 0`` the
    divergence ``KL(p_base || p_adapted)`` is summed as well (unweighted).

    Args:
        h_prime: Adapted hidden states; dimension 0 is the sequence position.
        h_base: Reference hidden states aligned with ``h_prime``. No gradient
            flows through the distribution derived from them.
        lm_w: LM head weight, used as ``F.linear(h, lm_w, lm_b)``.
        lm_b: LM head bias or ``None``.
        y: Target token ids, one per sequence position.
        weights: Per-token loss weights, one per sequence position.
        beta: KL coefficient; only its sign is used here, to decide whether
            the KL term (and the base logits it needs) is computed at all.
        chunk_size: Positions per chunk. Defaults to the module-level
            ``CHUNK_SIZE``.

    Returns:
        ``(total_ce, total_kl, total_weight)``: the weighted CE sum and the KL
        sum (tensors, or the float ``0.0`` when nothing was accumulated) and
        the sum of the weights as a Python float. Nothing is normalised here.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    seq_len = h_prime.shape[0]
    use_kl = beta > 0
    total_ce = 0.0
    total_kl = 0.0

    for start in range(0, seq_len, chunk_size):
        end = min(start + chunk_size, seq_len)

        # A single log-softmax serves both the CE and the KL term;
        # nll_loss(log_softmax(x)) is exactly what cross_entropy(x) computes.
        log_p = F.log_softmax(F.linear(h_prime[start:end], lm_w, lm_b), dim=-1)
        ce_per_tok = F.nll_loss(log_p, y[start:end], reduction='none')  # (chunk,)
        total_ce = total_ce + (ce_per_tok * weights[start:end]).sum()

        if use_kl:
            # The reference distribution is a constant target, so build it
            # without recording an autograd graph - and only when it is used.
            with torch.no_grad():
                p0 = F.softmax(F.linear(h_base[start:end], lm_w, lm_b), dim=-1)
            total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum')
            del p0
        del log_p, ce_per_tok

    # The weights are constants: sum them once (a single host sync) instead of
    # calling .item() for every chunk.
    total_weight = weights[:seq_len].sum().item()
    return total_ce, total_kl, total_weight
def fit_theta_weighted(
    cached_h: list,  # list of (h_states, label_ids)
    lm_head_weight: torch.Tensor,
    lm_head_bias: torch.Tensor | None,
    head_module,
    tokenizer,
    d: int = 64,
    lr: float = 0.05,
    steps: int = 30,
    beta: float = 0.05,
    lam: float = 1e-4,
    max_grad_norm: float = 5.0,
    device: str = "cuda:1",
    max_tokens_per_item: int = 128,
    verbose: bool = False,
    max_theta_norm: float | None = None,
) -> torch.Tensor:
    """Fit theta with style-weighted CE loss and a per-item token budget.

    ``theta`` starts at zero and is optimised with Adam for ``steps``
    full-batch iterations on

        (sum of weighted CE + beta * sum of KL) / max(sum of weights, 1)
            + lam * ||theta||^2

    accumulated over all items. After each step the gradient is clipped to
    ``max_grad_norm`` and ``theta`` is rescaled so that its norm does not
    exceed ``max_theta_norm``.

    Args:
        cached_h: List of ``(h_states, label_ids)`` pairs. Only the first
            ``max_tokens_per_item`` positions of each pair are used.
        lm_head_weight: LM head weight. Used as a constant: it is detached,
            cast to float32 and moved to ``device``.
        lm_head_bias: LM head bias or ``None``; treated like the weight.
        head_module: Callable invoked as ``head_module(h, theta)``; its result
            is passed to the loss as the adapted hidden states.
        tokenizer: Tokenizer used to decode label ids for the token weights.
        d: Number of elements of ``theta``.
        lr: Adam learning rate.
        steps: Number of optimisation steps.
        beta: Coefficient of the KL term.
        lam: Coefficient of the squared-L2 penalty on ``theta``.
        max_grad_norm: Gradient clipping threshold.
        device: Device on which ``theta`` and all loss computations live.
        max_tokens_per_item: Maximum number of positions used per item.
        verbose: Print the loss and ``|theta|`` every 10 steps and on the
            last step.
        max_theta_norm: Upper bound on ``|theta|`` enforced after every step.
            ``None`` (default) reuses ``max_grad_norm``.

    Returns:
        The fitted ``theta`` as a detached float32 tensor of shape ``(d,)``.
    """
    if max_theta_norm is None:
        max_theta_norm = max_grad_norm

    theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
    optimizer = torch.optim.Adam([theta], lr=lr)

    # The LM head is a fixed projection here. Detach it so that backward()
    # never accumulates a vocabulary-sized gradient into a model parameter the
    # caller may have passed in, and place it on the compute device.
    lm_w = lm_head_weight.detach().to(device=device, dtype=torch.float32)
    lm_b = None
    if lm_head_bias is not None:
        lm_b = lm_head_bias.detach().to(device=device, dtype=torch.float32)

    # Truncate every item to the token budget and pre-compute what does not
    # change between steps: the token weights and the labels on the device.
    # Hidden states stay where they were cached and are moved per step.
    processed_items = []
    for h_cached, y_cached in cached_h:
        n_tokens = min(h_cached.shape[0], max_tokens_per_item)
        y_trunc = y_cached[:n_tokens]
        w = build_token_weights(y_trunc, tokenizer, device)
        processed_items.append((h_cached[:n_tokens], y_trunc.to(device), w))

    for step in range(steps):
        total_loss = 0.0
        total_weight = 0.0
        for h_cached, y, w in processed_items:
            h = h_cached.to(device).float()
            h_prime = head_module(h, theta)
            ce, kl, w_sum = _chunked_weighted_ce_kl(
                h_prime, h.detach(), lm_w, lm_b, y, w, beta
            )
            total_loss = total_loss + ce + beta * kl
            total_weight += w_sum
            del h, h_prime

        # max(..., 1.0) keeps the division defined when there are no tokens.
        loss = total_loss / max(total_weight, 1.0) + lam * theta.square().sum()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
        optimizer.step()

        # Project theta back onto the ball of radius max_theta_norm.
        with torch.no_grad():
            norm = theta.norm()
            if norm > max_theta_norm:
                theta.mul_(max_theta_norm / norm)

        if verbose and (step % 10 == 0 or step == steps - 1):
            print(f" Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}")
        del total_loss, loss

    return theta.detach()
|