summaryrefslogtreecommitdiff
path: root/resulets/adapt/fit_theta_lm_head_update.py
blob: 7a38e62f5206c95e927164cb04df945b0ac04081 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
"""Fit theta_u for a low-rank LM-head weight update.

The update is W'_u = W + gamma * alpha * C diag(theta_u) A, so the
per-token logit correction depends on the current hidden state.
"""

import torch
import torch.nn.functional as F


# Number of leading-dimension rows of a cached hidden-state tensor that
# _backward_chunked_ce_kl moves to the device and back-propagates per chunk.
CHUNK_SIZE = 16


def _backward_chunked_ce_kl(
    h_cpu, lm_w, lm_bias, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens,
    chunk_size=None,
):
    """Back-propagate the CE + beta * KL loss of one sequence, chunk by chunk.

    For every chunk of rows the loss ``(ce + beta * kl) / max(total_tokens, 1)``
    is back-propagated immediately, so gradients accumulate in ``theta.grad``
    while only one chunk of logits is alive at a time.

    Args:
        h_cpu: Hidden states indexed by row along dim 0 (moved to ``device``
            chunk by chunk and cast to float32).
        lm_w: LM-head weight passed to ``F.linear``.
        lm_bias: LM-head bias passed to ``F.linear``, or ``None``.
        y_cpu: Target class indices, one per row of ``h_cpu``.
        head_update: Object exposing ``logit_delta(h_chunk, theta)`` that
            returns the logit correction for a chunk.
        theta: Tensor being fitted; receives the accumulated gradient.
        beta: Weight of the KL term; the KL term is skipped when ``beta <= 0``.
        blend_gamma: Scale applied to the logit correction.
        device: Device on which each chunk is processed.
        total_tokens: Normaliser for the loss (clamped below at 1).
        chunk_size: Rows per chunk. Defaults to the module-level
            ``CHUNK_SIZE`` when ``None``.

    Returns:
        Tuple ``(ce_sum, kl_sum)`` of Python floats: the un-normalised summed
        cross-entropy and KL(base || updated) over all rows.

    Raises:
        ValueError: If ``chunk_size`` is not a positive integer.
    """
    if chunk_size is None:
        chunk_size = CHUNK_SIZE
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    denom = max(total_tokens, 1)
    use_kl = beta > 0
    num_rows = h_cpu.shape[0]
    total_ce_value = 0.0
    total_kl_value = 0.0

    for start in range(0, num_rows, chunk_size):
        end = min(start + chunk_size, num_rows)
        # Detach so no graph is built back into the cached hidden states.
        h_chunk = h_cpu[start:end].detach().to(device).float()
        y_chunk = y_cpu[start:end].to(device)

        # The base logits do not depend on theta. Computing them without
        # autograd avoids storing a graph into (and accumulating .grad on)
        # the LM-head weight/bias when those tensors require grad.
        with torch.no_grad():
            base_logits = F.linear(h_chunk, lm_w, lm_bias)
        delta_logits = head_update.logit_delta(h_chunk, theta)
        logits = base_logits + blend_gamma * delta_logits

        # One log-softmax serves both the cross-entropy and the KL term.
        log_p = F.log_softmax(logits, dim=-1)
        ce = F.nll_loss(log_p, y_chunk, reduction='sum')
        loss = ce

        if use_kl:
            p0 = F.softmax(base_logits, dim=-1)
            # F.kl_div(input=log q, target=p) sums p * (log p - log q),
            # i.e. KL(base || updated).
            kl = F.kl_div(log_p, p0, reduction='sum')
            loss = ce + beta * kl
            total_kl_value += kl.item()
            del p0, kl

        (loss / denom).backward()
        total_ce_value += ce.item()

        # Drop references before the next chunk is allocated to keep peak
        # device memory at roughly one chunk of logits.
        del h_chunk, y_chunk, base_logits, delta_logits, logits, log_p, ce, loss

    return total_ce_value, total_kl_value


def fit_theta_lm_head_update(
    cached_h: list,
    lm_head_weight: torch.Tensor,
    lm_head_bias: torch.Tensor | None,
    head_update,
    d: int = 64,
    lr: float = 0.05,
    steps: int = 30,
    beta: float = 0.05,
    lam: float = 1e-4,
    blend_gamma: float = 0.5,
    max_grad_norm: float = 5.0,
    device: str = "cuda:0",
    verbose: bool = False,
    max_theta_norm: float | None = None,
) -> torch.Tensor:
    """Fit the user vector theta_u for an LM-head update.

    Runs ``steps`` full-batch Adam updates on a zero-initialised vector of
    length ``d``. Each step accumulates the gradient of the token-averaged
    ``CE + beta * KL`` loss over every cached sequence, adds the L2 penalty
    ``lam * ||theta||^2``, clips the gradient norm, steps the optimiser and
    finally rescales theta so its norm does not exceed the cap.

    Args:
        cached_h: Sequence of ``(hidden_states, targets)`` pairs; the number
            of tokens is taken from ``targets.shape[0]``.
        lm_head_weight: LM-head weight. It is detached, moved to ``device``
            and cast to float32; the caller's tensor never receives gradients.
        lm_head_bias: Optional LM-head bias, treated like the weight.
        head_update: Object exposing ``logit_delta(h, theta)``.
        d: Length of theta.
        lr: Adam learning rate.
        steps: Number of optimisation steps.
        beta: Weight of the KL term.
        lam: L2 regularisation coefficient.
        blend_gamma: Scale applied to the logit correction.
        max_grad_norm: Gradient-norm clipping threshold.
        device: Device on which theta and the LM head live during fitting.
        verbose: Print the loss every 10 steps and on the last step.
        max_theta_norm: Upper bound on ``||theta||`` enforced after each
            step. ``None`` (default) reuses ``max_grad_norm``, which is the
            historical behaviour.

    Returns:
        The fitted theta, detached from the autograd graph.
    """
    theta_cap = max_grad_norm if max_theta_norm is None else max_theta_norm

    theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
    optimizer = torch.optim.Adam([theta], lr=lr)

    # Detach so back-propagation stops at theta: without this, every step
    # accumulates a full vocab-sized gradient on the model's LM head. Also
    # place the head on the same device as the chunks it is multiplied with.
    lm_w = lm_head_weight.detach().to(device=device, dtype=torch.float32)
    lm_b = (
        None
        if lm_head_bias is None
        else lm_head_bias.detach().to(device=device, dtype=torch.float32)
    )

    total_tokens = sum(y_cpu.shape[0] for _, y_cpu in cached_h)
    denom = max(total_tokens, 1)

    for step in range(steps):
        total_ce_value = 0.0
        total_kl_value = 0.0
        optimizer.zero_grad()

        for h_cpu, y_cpu in cached_h:
            ce_value, kl_value = _backward_chunked_ce_kl(
                h_cpu, lm_w, lm_b, y_cpu, head_update, theta, beta, blend_gamma, device, total_tokens
            )
            total_ce_value += ce_value
            total_kl_value += kl_value

        reg = lam * theta.square().sum()
        reg.backward()
        torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
        optimizer.step()

        # Rescale theta so its norm does not exceed theta_cap.
        with torch.no_grad():
            norm = theta.norm()
            if norm > theta_cap:
                theta.mul_(theta_cap / norm)

        if verbose and (step % 10 == 0 or step == steps - 1):
            # Loss is reported at the pre-step theta; |theta| is post-step.
            loss_value = (total_ce_value + beta * total_kl_value) / denom + reg.item()
            print(f"  Step {step:3d}: loss={loss_value:.4f}, |theta|={theta.norm().item():.4f}")

    return theta.detach()