From 8fe28101366dd32562b8c5534d7fe359b252bdf3 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Fri, 3 Apr 2026 15:12:34 -0500
Subject: Initial commit: UPH project codebase and experiment results

Includes model code, evaluation scripts, configs, analysis outputs, and
experiment results for the User Prior Head (UPH) personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context)
---
 adapt/__init__.py           |   0
 adapt/cache_hidden.py       |  33 ++++++++++
 adapt/fit_theta.py          | 107 +++++++++++++++++++++++++++++++
 adapt/fit_theta_weighted.py | 153 ++++++++++++++++++++++++++++++++++++++++++++
 4 files changed, 293 insertions(+)
 create mode 100644 adapt/__init__.py
 create mode 100644 adapt/cache_hidden.py
 create mode 100644 adapt/fit_theta.py
 create mode 100644 adapt/fit_theta_weighted.py

diff --git a/adapt/__init__.py b/adapt/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/adapt/cache_hidden.py b/adapt/cache_hidden.py
new file mode 100644
index 0000000..421c3b7
--- /dev/null
+++ b/adapt/cache_hidden.py
@@ -0,0 +1,33 @@
+"""Cache hidden states from the support set using the frozen base model."""
+
+import torch
+from data.templates import build_support_prompt
+
+
+def cache_support_hidden_states(
+    wrapper,
+    support_items: list,
+    task: str,
+) -> list:
+    """Cache hidden states for each support-set item.
+
+    Args:
+        wrapper: QwenWrapper instance
+        support_items: List of dicts with 'support_input' and 'support_output'
+        task: 'review' or 'topic'
+
+    Returns:
+        List of (h_states, label_ids) tuples, both detached and moved to CPU
+    """
+    cached = []
+
+    for item in support_items:
+        input_text = build_support_prompt(item['support_input'], task)
+        target_text = " " + item['support_output']  # Space prefix for clean tokenization
+
+        h_states, label_ids = wrapper.get_hidden_states_teacher_forced(input_text, target_text)
+
+        if h_states is not None and h_states.shape[0] > 0:
+            cached.append((h_states.detach().cpu(), label_ids.detach().cpu()))
+
+    return cached
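A minimal usage sketch of the caching step. `wrapper` is a QwenWrapper
instance (construction not shown), and the two support records are
illustrative stand-ins for the real dataset schema:

    # Sketch only: assumes `wrapper` exposes
    # get_hidden_states_teacher_forced(input_text, target_text) -> (h_states, label_ids)
    # exactly as cache_support_hidden_states uses it above.
    from adapt.cache_hidden import cache_support_hidden_states

    support_items = [
        {"support_input": "The battery died within a week.", "support_output": "1 star"},
        {"support_input": "Fast shipping and works great!", "support_output": "5 stars"},
    ]
    cached = cache_support_hidden_states(wrapper, support_items, task="review")
    # cached: list of (h_states (T_i, H), label_ids (T_i,)) pairs on CPU,
    # ready to be passed to fit_theta / fit_theta_weighted below.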
diff --git a/adapt/fit_theta.py b/adapt/fit_theta.py
new file mode 100644
index 0000000..f5b047b
--- /dev/null
+++ b/adapt/fit_theta.py
@@ -0,0 +1,107 @@
+"""Fit theta_u on cached hidden states.
+
+Loss = CE(lm_head(h + alpha * B(theta ⊙ A@h)), y) + beta * KL(p_0 || p_theta) + lambda * ||theta||^2
+"""
+
+import torch
+import torch.nn.functional as F
+
+# Maximum chunk size for logit computation to avoid OOM
+CHUNK_SIZE = 128
+
+
+def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta):
+    """Compute CE + KL in chunks to avoid OOM from huge vocab logits."""
+    seq_len = h_prime.shape[0]
+    total_ce = 0.0
+    total_kl = 0.0
+
+    for start in range(0, seq_len, CHUNK_SIZE):
+        end = min(start + CHUNK_SIZE, seq_len)
+        hp_chunk = h_prime[start:end]
+        hb_chunk = h_base[start:end]
+        y_chunk = y[start:end]
+
+        logits = F.linear(hp_chunk, lm_w, lm_bias)
+        base_logits = F.linear(hb_chunk, lm_w, lm_bias)
+
+        total_ce = total_ce + F.cross_entropy(logits, y_chunk, reduction='sum')
+
+        if beta > 0:
+            # KL(p_0 || p_theta): F.kl_div expects log-probs as input, probs as target
+            log_p = F.log_softmax(logits, dim=-1)
+            p0 = F.softmax(base_logits.detach(), dim=-1)
+            total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum')
+
+        # Free intermediates
+        del logits, base_logits
+        if beta > 0:
+            del log_p, p0
+
+    return total_ce, total_kl
+
+
+def fit_theta(
+    cached_h: list,  # list of (h_states: (T_i, H), label_ids: (T_i,))
+    lm_head_weight: torch.Tensor,  # (vocab_size, H)
+    lm_head_bias: torch.Tensor | None,
+    head_module,  # CVHHead or UnconditionalHead
+    d: int = 64,
+    lr: float = 0.05,
+    steps: int = 30,
+    beta: float = 0.05,
+    lam: float = 1e-4,
+    max_grad_norm: float = 5.0,
+    device: str = "cuda:1",
+    verbose: bool = False,
+) -> torch.Tensor:
+    """Fit a user vector theta_u on cached hidden states.
+
+    Memory-efficient: logits are computed in chunks and base logits are
+    never pre-computed or stored.
+    """
+    theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
+    optimizer = torch.optim.Adam([theta], lr=lr)
+
+    lm_w = lm_head_weight.float()
+    lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+
+    for step in range(steps):
+        total_loss = 0.0
+        total_tokens = 0
+
+        for h_cpu, y_cpu in cached_h:
+            h = h_cpu.to(device).float()
+            y = y_cpu.to(device)
+
+            # Apply head to get personalized hidden states
+            h_prime = head_module(h, theta)
+
+            # Compute CE + KL in chunks
+            ce, kl = _chunked_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, beta)
+
+            total_loss = total_loss + ce + beta * kl
+            total_tokens += y.shape[0]
+
+            # Free GPU memory
+            del h, y, h_prime
+
+        # Average over tokens + L2 reg
+        loss = total_loss / max(total_tokens, 1) + lam * theta.square().sum()
+
+        optimizer.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
+        optimizer.step()
+
+        # Cap theta's L2 norm (reuses max_grad_norm as the norm cap)
+        with torch.no_grad():
+            norm = theta.norm()
+            if norm > max_grad_norm:
+                theta.mul_(max_grad_norm / norm)
+
+        if verbose and (step % 10 == 0 or step == steps - 1):
+            print(f"  Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}")
+
+        # Free graph
+        del total_loss, loss
+
+    return theta.detach()
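A runnable sketch of the fitting loop on random tensors. The real CVHHead /
UnconditionalHead modules live elsewhere in the repo, so the low-rank head
below is an illustrative stand-in matching the documented form
h' = h + alpha * B(theta ⊙ A@h):

    import torch
    from adapt.fit_theta import fit_theta

    class ToyLowRankHead(torch.nn.Module):
        """Stand-in head: h' = h + alpha * B @ (theta ⊙ (A @ h))."""
        def __init__(self, hidden: int, d: int, alpha: float = 1.0):
            super().__init__()
            self.A = torch.nn.Parameter(torch.randn(d, hidden) / hidden ** 0.5)
            self.B = torch.nn.Parameter(torch.randn(hidden, d) / d ** 0.5)
            self.alpha = alpha

        def forward(self, h, theta):  # h: (T, H), theta: (d,)
            return h + self.alpha * (theta * (h @ self.A.T)) @ self.B.T

    H, V, d = 16, 50, 8
    head = ToyLowRankHead(H, d)
    for p in head.parameters():      # the head stays frozen; only theta is fit
        p.requires_grad_(False)

    lm_w = torch.randn(V, H)         # stand-in for the frozen lm_head weight
    cached = [(torch.randn(5, H), torch.randint(0, V, (5,)))]
    theta = fit_theta(cached, lm_w, None, head, d=d, steps=5, device="cpu")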
diff --git a/adapt/fit_theta_weighted.py b/adapt/fit_theta_weighted.py
new file mode 100644
index 0000000..f19d5ad
--- /dev/null
+++ b/adapt/fit_theta_weighted.py
@@ -0,0 +1,153 @@
+"""Style-weighted theta fitting.
+
+Weights style-indicative tokens (punctuation, newlines, first-person pronouns,
+discourse markers, and sentiment/intensity words) more heavily in the CE loss.
+"""
+
+import torch
+import torch.nn.functional as F
+
+# Maximum chunk size for logit computation
+CHUNK_SIZE = 128
+
+# Style-indicator tokens and their weights. A token matches if its decoded
+# string, stripped and lower-cased, equals one of these keys.
+STYLE_INDICATORS = {
+    # Punctuation and formatting
+    '!': 3.0, '?': 3.0, '.': 2.0, ',': 2.0, ';': 2.0, ':': 2.0,
+    '...': 3.0, '-': 2.0, '"': 2.0, "'": 2.0,
+    '\n': 4.0, '\n\n': 4.0,
+    # First-person pronouns
+    'i': 3.0, 'me': 3.0, 'my': 3.0, 'mine': 3.0, 'myself': 3.0,
+    'we': 3.0, 'us': 3.0, 'our': 3.0, 'ours': 3.0,
+    # Discourse markers and hedges
+    'however': 2.0, 'although': 2.0, 'moreover': 2.0, 'furthermore': 2.0,
+    'basically': 2.0, 'actually': 2.0, 'honestly': 2.0, 'frankly': 2.0,
+    'definitely': 2.0, 'absolutely': 2.0, 'obviously': 2.0,
+    # Sentiment/intensity
+    'very': 2.0, 'really': 2.0, 'quite': 2.0, 'extremely': 2.0,
+    'amazing': 2.0, 'terrible': 2.0, 'awesome': 2.0, 'horrible': 2.0,
+    'love': 2.0, 'hate': 2.0, 'loved': 2.0, 'hated': 2.0,
+}
+
+
+def build_token_weights(label_ids: torch.Tensor, tokenizer, device: str) -> torch.Tensor:
+    """Build per-token weights for the style-weighted loss.
+
+    Args:
+        label_ids: (T,) tensor of token ids
+        tokenizer: The tokenizer used to decode token ids
+        device: Device on which to allocate the weight tensor
+
+    Returns:
+        weights: (T,) tensor of per-token weights (≥ 1.0)
+    """
+    weights = torch.ones(label_ids.shape[0], device=device, dtype=torch.float32)
+
+    for i, tok_id in enumerate(label_ids):
+        raw = tokenizer.decode([tok_id.item()])
+        tok_str = raw.strip().lower()
+        if not tok_str and raw:
+            # Whitespace-only tokens (e.g. '\n') strip to ''; match them raw
+            tok_str = raw
+        if tok_str in STYLE_INDICATORS:
+            weights[i] = STYLE_INDICATORS[tok_str]
+
+    return weights
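A quick check of the weighting table on a hand-rolled tokenizer (the class
below is a minimal stand-in; in the repo the model's own tokenizer is passed
in):

    import torch
    from adapt.fit_theta_weighted import build_token_weights

    class FakeTokenizer:
        """Minimal stand-in; decodes unknown ids to neutral placeholder strings."""
        vocab = {0: "I", 1: " really", 2: " loved", 3: " it", 4: "."}
        def decode(self, ids):
            return "".join(self.vocab.get(i, f" tok{i}") for i in ids)

    w = build_token_weights(torch.tensor([0, 1, 2, 3, 4]), FakeTokenizer(), device="cpu")
    # -> tensor([3., 2., 2., 1., 2.])  ("i", "really", "loved", "." are style tokens)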
+
+
+def _chunked_weighted_ce_kl(h_prime, h_base, lm_w, lm_b, y, weights, beta):
+    """Compute weighted CE + KL in chunks."""
+    seq_len = h_prime.shape[0]
+    total_ce = 0.0
+    total_kl = 0.0
+    total_weight = 0.0
+
+    for start in range(0, seq_len, CHUNK_SIZE):
+        end = min(start + CHUNK_SIZE, seq_len)
+        hp_chunk = h_prime[start:end]
+        hb_chunk = h_base[start:end]
+        y_chunk = y[start:end]
+        w_chunk = weights[start:end]
+
+        logits = F.linear(hp_chunk, lm_w, lm_b)
+        base_logits = F.linear(hb_chunk, lm_w, lm_b)
+
+        # Weighted CE: per-token CE * weight
+        ce_per_tok = F.cross_entropy(logits, y_chunk, reduction='none')  # (chunk,)
+        total_ce = total_ce + (ce_per_tok * w_chunk).sum()
+        total_weight += w_chunk.sum().item()
+
+        if beta > 0:
+            # KL term is left unweighted; only the CE is style-weighted
+            log_p = F.log_softmax(logits, dim=-1)
+            p0 = F.softmax(base_logits.detach(), dim=-1)
+            total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum')
+
+        del logits, base_logits
+        if beta > 0:
+            del log_p, p0
+
+    return total_ce, total_kl, total_weight
+
+
+def fit_theta_weighted(
+    cached_h: list,  # list of (h_states, label_ids)
+    lm_head_weight: torch.Tensor,
+    lm_head_bias: torch.Tensor | None,
+    head_module,
+    tokenizer,
+    d: int = 64,
+    lr: float = 0.05,
+    steps: int = 30,
+    beta: float = 0.05,
+    lam: float = 1e-4,
+    max_grad_norm: float = 5.0,
+    device: str = "cuda:1",
+    max_tokens_per_item: int = 128,
+    verbose: bool = False,
+) -> torch.Tensor:
+    """Fit theta with a style-weighted CE loss and a per-item token budget."""
+    theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
+    optimizer = torch.optim.Adam([theta], lr=lr)
+
+    lm_w = lm_head_weight.float()
+    lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+
+    # Pre-build token weights and truncate each item to max_tokens_per_item
+    processed_items = []
+    for h_cpu, y_cpu in cached_h:
+        T = min(h_cpu.shape[0], max_tokens_per_item)
+        h_trunc = h_cpu[:T]
+        y_trunc = y_cpu[:T]
+        w = build_token_weights(y_trunc, tokenizer, device)
+        processed_items.append((h_trunc, y_trunc, w))
+
+    for step in range(steps):
+        total_loss = 0.0
+        total_weight = 0.0
+
+        for h_cpu, y_cpu, w in processed_items:
+            h = h_cpu.to(device).float()
+            y = y_cpu.to(device)
+
+            h_prime = head_module(h, theta)
+            ce, kl, w_sum = _chunked_weighted_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, w, beta)
+
+            total_loss = total_loss + ce + beta * kl
+            total_weight += w_sum
+
+            del h, y, h_prime
+
+        # Weighted average over tokens + L2 reg
+        loss = total_loss / max(total_weight, 1.0) + lam * theta.square().sum()
+
+        optimizer.zero_grad()
+        loss.backward()
+        torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
+        optimizer.step()
+
+        # Cap theta's L2 norm (reuses max_grad_norm as the norm cap)
+        with torch.no_grad():
+            norm = theta.norm()
+            if norm > max_grad_norm:
+                theta.mul_(max_grad_norm / norm)
+
+        if verbose and (step % 10 == 0 or step == steps - 1):
+            print(f"  Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}")
+
+        del total_loss, loss
+
+    return theta.detach()
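Putting the pieces together on random tensors, reusing the ToyLowRankHead and
FakeTokenizer stand-ins from the sketches above (in the repo, the head,
tokenizer, and lm_head all come from the model wrapper):

    import torch
    import torch.nn.functional as F
    from adapt.fit_theta_weighted import fit_theta_weighted

    H, V, d = 16, 50, 8
    lm_w = torch.randn(V, H)                      # stand-in frozen lm_head weight
    cached = [(torch.randn(6, H), torch.randint(0, V, (6,)))]

    theta = fit_theta_weighted(
        cached, lm_w, None, head, FakeTokenizer(), d=d, steps=5, device="cpu",
    )

    # At inference, the fitted theta personalizes the next-token distribution:
    h_t = torch.randn(1, H)                       # last hidden state from the frozen model
    logits = F.linear(head(h_t, theta), lm_w)     # personalized logits, (1, V)
    probs = logits.softmax(dim=-1)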