1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
|
"""Fit theta_u for a low-rank LM-head weight update.
The update is W'_u = W + gamma * alpha * C diag(theta_u) A, so the
per-token logit correction depends on the current hidden state.
"""
import torch
import torch.nn.functional as F
# Default number of rows (dim 0 of the cached hidden states) processed per chunk.
CHUNK_SIZE = 16


def _backward_chunked_ce_kl(
    h_cpu,
    lm_w,
    lm_bias,
    y_cpu,
    head_update,
    theta,
    beta,
    blend_gamma,
    device,
    total_tokens,
    chunk_size=None,
):
    """Back-propagate the CE + beta * KL loss of one cached sequence in chunks.

    The rows of ``h_cpu`` / ``y_cpu`` are processed ``chunk_size`` at a time.
    For every chunk the loss ``(ce + beta * kl) / max(total_tokens, 1)`` is
    back-propagated immediately, so gradients accumulate in ``theta.grad``
    while only one chunk of logits is alive at any time.

    Args:
        h_cpu: Cached hidden states, sliced along dim 0 (presumably
            ``(tokens, hidden)`` -- confirm against the caller).
        lm_w: LM-head weight passed to ``F.linear``.
        lm_bias: LM-head bias passed to ``F.linear``, or ``None``.
        y_cpu: Target token ids, sliced along dim 0 in lock-step with ``h_cpu``.
        head_update: Object exposing ``logit_delta(h_chunk, theta)``; its
            result is added to the base logits scaled by ``blend_gamma``.
        theta: Tensor being fitted; must take part in autograd.
        beta: Weight of the KL term. The KL term is skipped when ``beta <= 0``.
        blend_gamma: Scale applied to the logit correction.
        device: Device each chunk is moved to before computing.
        total_tokens: Normaliser for the back-propagated loss (clamped to >= 1).
        chunk_size: Rows per chunk; ``None`` falls back to ``CHUNK_SIZE``.

    Returns:
        ``(ce_sum, kl_sum)`` as Python floats: the un-normalised sums of the
        cross-entropy and of KL(p_base || p_updated) over all rows.
    """
    rows_per_chunk = CHUNK_SIZE if chunk_size is None else chunk_size
    normaliser = max(total_tokens, 1)
    n_rows = h_cpu.shape[0]

    total_ce_value = 0.0
    total_kl_value = 0.0
    for start in range(0, n_rows, rows_per_chunk):
        end = min(start + rows_per_chunk, n_rows)
        h_chunk = h_cpu[start:end].to(device).float()
        y_chunk = y_cpu[start:end].to(device)

        # The base head is a constant of the fit: keeping it out of the
        # autograd graph avoids back-propagating through the vocab-sized
        # matmul and accumulating a gradient on the LM-head weight/bias.
        with torch.no_grad():
            base_logits = F.linear(h_chunk, lm_w, lm_bias)

        delta_logits = head_update.logit_delta(h_chunk, theta)
        logits = base_logits + blend_gamma * delta_logits
        ce = F.cross_entropy(logits, y_chunk, reduction="sum")

        if beta > 0:
            log_p = F.log_softmax(logits, dim=-1)
            p0 = F.softmax(base_logits, dim=-1)
            # kl_div(input=log q, target=p) sums p * (log p - log q).
            kl = F.kl_div(log_p, p0, reduction="sum")
        else:
            kl = torch.zeros((), device=device)

        ((ce + beta * kl) / normaliser).backward()
        total_ce_value += ce.item()
        total_kl_value += kl.item()

        # Drop the references before the next chunk is materialised so the
        # previous chunk's logits can be freed.
        if beta > 0:
            del log_p, p0
        del h_chunk, y_chunk, base_logits, delta_logits, logits, ce, kl

    return total_ce_value, total_kl_value
def fit_theta_lm_head_update(
    cached_h: list[tuple[torch.Tensor, torch.Tensor]],
    lm_head_weight: torch.Tensor,
    lm_head_bias: torch.Tensor | None,
    head_update,
    d: int = 64,
    lr: float = 0.05,
    steps: int = 30,
    beta: float = 0.05,
    lam: float = 1e-4,
    blend_gamma: float = 0.5,
    max_grad_norm: float = 5.0,
    device: str = "cuda:0",
    verbose: bool = False,
    max_theta_norm: float | None = None,
) -> torch.Tensor:
    """Fit the user vector theta_u for an LM-head update.

    ``theta`` starts at zero and is optimised with Adam for ``steps`` full
    passes over ``cached_h``. Each pass accumulates the gradient of
    ``(CE + beta * KL) / total_tokens + lam * ||theta||^2``, clips it to
    ``max_grad_norm``, takes one optimiser step and finally rescales
    ``theta`` back onto a norm ball if it left it.

    Args:
        cached_h: ``(hidden_states, targets)`` pairs; the first dimension of
            each ``targets`` tensor is counted as its number of tokens.
        lm_head_weight: Weight of the base LM head. It is treated as a
            constant: detached, cast to float32 and moved to ``device``.
        lm_head_bias: Bias of the base LM head, or ``None``; handled like the
            weight.
        head_update: Object forwarded to the chunked loss helper, which calls
            its ``logit_delta`` method.
        d: Length of ``theta``.
        lr: Adam learning rate.
        steps: Number of optimisation steps (full passes over ``cached_h``).
        beta: Weight of the KL term.
        lam: Coefficient of the squared-L2 penalty on ``theta``.
        blend_gamma: Scale applied to the logit correction.
        max_grad_norm: Gradient-clipping threshold. Also used as the radius
            of the ``theta`` norm ball when ``max_theta_norm`` is ``None``.
        device: Device on which ``theta`` lives and chunks are evaluated.
        verbose: When true, print the loss every 10 steps and on the last one.
        max_theta_norm: Optional explicit radius for the post-step projection
            of ``theta``; ``None`` keeps the historical behaviour of reusing
            ``max_grad_norm``.

    Returns:
        The fitted ``theta``, detached from the autograd graph.
    """
    theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
    optimizer = torch.optim.Adam([theta], lr=lr)

    # The base head is frozen during the fit. Detaching keeps backward() from
    # accumulating a vocab-sized gradient on the model's own parameters, and
    # moving it to `device` matches where the hidden-state chunks are evaluated.
    lm_w = lm_head_weight.detach().to(device=device, dtype=torch.float32)
    if lm_head_bias is None:
        lm_b = None
    else:
        lm_b = lm_head_bias.detach().to(device=device, dtype=torch.float32)

    theta_radius = max_grad_norm if max_theta_norm is None else max_theta_norm
    total_tokens = sum(y_cpu.shape[0] for _, y_cpu in cached_h)
    normaliser = max(total_tokens, 1)

    for step in range(steps):
        total_ce_value = 0.0
        total_kl_value = 0.0
        optimizer.zero_grad()

        # Data term: gradients are accumulated into theta.grad chunk by chunk.
        for h_cpu, y_cpu in cached_h:
            ce_value, kl_value = _backward_chunked_ce_kl(
                h_cpu, lm_w, lm_b, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens
            )
            total_ce_value += ce_value
            total_kl_value += kl_value

        # L2 penalty, back-propagated separately on top of the data gradient.
        reg = lam * theta.square().sum()
        reg.backward()
        reg_value = reg.item()
        del reg

        torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
        optimizer.step()

        # Project theta back onto the norm ball if the step left it.
        with torch.no_grad():
            norm = theta.norm()
            if norm > theta_radius:
                theta.mul_(theta_radius / norm)

        if verbose and (step % 10 == 0 or step == steps - 1):
            # The reported loss uses the values computed before this step's update.
            loss_value = (total_ce_value + beta * total_kl_value) / normaliser + reg_value
            print(f" Step {step:3d}: loss={loss_value:.4f}, |theta|={theta.norm().item():.4f}")

    return theta.detach()
|