path: root/adapt/fit_theta.py
"""Fit theta_u on cached hidden states.

Loss = CE(lm_head(h + alpha * B(theta ⊙ A@h)), y) + beta * KL(p_theta || p_0) + lambda * ||theta||^2
"""

import torch
import torch.nn.functional as F

# Maximum chunk size for logit computation to avoid OOM
CHUNK_SIZE = 128
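
# Illustrative sizing (assumed numbers, not measured): with a 128k-token
# vocab in fp32, full-sequence logits for 4096 cached tokens would take
# about 4096 * 128000 * 4 bytes ≈ 2.1 GB per logits tensor; chunking to
# 128 rows caps each at roughly 65 MB.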


def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta):
    """Compute summed CE and KL in chunks to avoid OOM from huge vocab logits.

    ``beta`` only gates whether the KL term is computed here; the actual
    beta-weighting happens in the caller.
    """
    seq_len = h_prime.shape[0]
    total_ce = 0.0
    total_kl = 0.0

    for start in range(0, seq_len, CHUNK_SIZE):
        end = min(start + CHUNK_SIZE, seq_len)
        hp_chunk = h_prime[start:end]
        hb_chunk = h_base[start:end]
        y_chunk = y[start:end]

        logits = F.linear(hp_chunk, lm_w, lm_bias)
        total_ce = total_ce + F.cross_entropy(logits, y_chunk, reduction='sum')

        if beta > 0:
            # KL(p_theta || p_0), matching the module docstring. With
            # log_target=True, F.kl_div(input, target) computes
            # sum exp(target) * (target - input), so the detached base
            # log-probs go in as `input` and the adapted log-probs as
            # `target`.
            base_logits = F.linear(hb_chunk, lm_w, lm_bias)
            log_p = F.log_softmax(logits, dim=-1)
            log_p0 = F.log_softmax(base_logits.detach(), dim=-1)
            total_kl = total_kl + F.kl_div(log_p0, log_p, reduction='sum', log_target=True)
            del base_logits, log_p, log_p0

        # Free the chunk's logits before computing the next chunk's
        del logits

    return total_ce, total_kl
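
# Sanity-check sketch for the KL direction used above (illustrative only,
# not part of the pipeline); run standalone to confirm that kl_div with
# log_target=True matches KL(p_theta || p_0):
#
#     logits, base = torch.randn(4, 10), torch.randn(4, 10)
#     log_p = F.log_softmax(logits, -1)
#     log_p0 = F.log_softmax(base, -1)
#     lhs = F.kl_div(log_p0, log_p, reduction='sum', log_target=True)
#     rhs = (log_p.exp() * (log_p - log_p0)).sum()
#     assert torch.allclose(lhs, rhs)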


def fit_theta(
    cached_h: list,       # list of (h_states: (T_i, H), label_ids: (T_i,))
    lm_head_weight: torch.Tensor,  # (vocab_size, H)
    lm_head_bias: torch.Tensor | None,
    head_module,           # CVHHead or UnconditionalHead
    d: int = 64,
    lr: float = 0.05,
    steps: int = 30,
    beta: float = 0.05,
    lam: float = 1e-4,
    max_grad_norm: float = 5.0,
    device: str = "cuda:1",
    verbose: bool = False,
) -> torch.Tensor:
    """Fit a user vector theta_u on cached hidden states.

    Memory-efficient: computes logits in chunks, no pre-computation of base logits.
    """
    theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
    optimizer = torch.optim.Adam([theta], lr=lr)

    lm_w = lm_head_weight.float()
    lm_b = lm_head_bias.float() if lm_head_bias is not None else None

    for step in range(steps):
        total_loss = 0.0
        total_tokens = 0

        for h_cpu, y_cpu in cached_h:
            h = h_cpu.to(device).float()
            y = y_cpu.to(device)

            # Apply head to get personalized hidden states
            h_prime = head_module(h, theta)

            # Compute CE + KL in chunks
            ce, kl = _chunked_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, beta)

            total_loss = total_loss + ce + beta * kl
            total_tokens += y.shape[0]

            # Free GPU memory
            del h, y, h_prime

        # Average over tokens + L2 reg
        loss = total_loss / max(total_tokens, 1) + lam * theta.square().sum()

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
        optimizer.step()

        # Project theta back onto an L2 ball; note this reuses
        # max_grad_norm as the radius for the parameter itself.
        with torch.no_grad():
            norm = theta.norm()
            if norm > max_grad_norm:
                theta.mul_(max_grad_norm / norm)

        if verbose and (step % 10 == 0 or step == steps - 1):
            print(f"  Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}")

        # Drop references so this step's autograd graph can be freed
        del total_loss, loss

    return theta.detach()
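

# ---------------------------------------------------------------------------
# Minimal smoke test (illustrative sketch only, not part of the pipeline).
# `_ToyHead` is a hypothetical stand-in for CVHHead / UnconditionalHead,
# assuming only the interface used above: head_module(h, theta) -> h'.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    class _ToyHead(torch.nn.Module):
        def __init__(self, d, hidden):
            super().__init__()
            # Fixed random projection of theta into the hidden space
            self.register_buffer("proj", torch.randn(hidden, d) / d ** 0.5)

        def forward(self, h, theta):
            # (T, H) + (H,) broadcast: shift every position by B(theta)
            return h + theta @ self.proj.T

    torch.manual_seed(0)
    dev = "cuda:1" if torch.cuda.device_count() > 1 else "cpu"
    T, H, V, d = 50, 32, 100, 8

    # Synthetic cached hidden states and labels, three "documents"
    cached = [(torch.randn(T, H), torch.randint(0, V, (T,))) for _ in range(3)]
    lm_w = torch.randn(V, H, device=dev)

    theta = fit_theta(
        cached, lm_w, None, _ToyHead(d, H).to(dev),
        d=d, steps=5, device=dev, verbose=True,
    )
    print("fitted theta:", tuple(theta.shape), "norm:", theta.norm().item())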