author    YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
committer YurenHao0426 <Blackhao0426@gmail.com>  2026-04-03 15:12:34 -0500
commit    8fe28101366dd32562b8c5534d7fe359b252bdf3 (patch)
tree      c92a92184fb2f46f265ab84c1f754c3d5d6597bc /adapt
Initial commit: UPH project codebase and experiment results
Includes model code, evaluation scripts, configs, analysis outputs, and experiment results for the User Prior Head personalization method.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
Diffstat (limited to 'adapt')
-rw-r--r--  adapt/__init__.py              0
-rw-r--r--  adapt/cache_hidden.py         33
-rw-r--r--  adapt/fit_theta.py           107
-rw-r--r--  adapt/fit_theta_weighted.py  153
4 files changed, 293 insertions, 0 deletions
diff --git a/adapt/__init__.py b/adapt/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/adapt/__init__.py
diff --git a/adapt/cache_hidden.py b/adapt/cache_hidden.py
new file mode 100644
index 0000000..421c3b7
--- /dev/null
+++ b/adapt/cache_hidden.py
@@ -0,0 +1,33 @@
+"""Cache hidden states from support set using frozen model."""
+
+import torch
+from data.templates import build_support_prompt
+
+
+def cache_support_hidden_states(
+ wrapper,
+ support_items: list,
+ task: str,
+) -> list:
+ """Cache hidden states from support set items.
+
+ Args:
+ wrapper: QwenWrapper instance
+ support_items: List of dicts with 'support_input' and 'support_output'
+ task: 'review' or 'topic'
+
+ Returns:
+ List of (h_states, label_ids) tuples
+ """
+ cached = []
+
+ for item in support_items:
+ input_text = build_support_prompt(item['support_input'], task)
+ target_text = " " + item['support_output'] # Space prefix for clean tokenization
+
+ h_states, label_ids = wrapper.get_hidden_states_teacher_forced(input_text, target_text)
+
+ if h_states is not None and h_states.shape[0] > 0:
+ cached.append((h_states.detach().cpu(), label_ids.detach().cpu()))
+
+ return cached
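
A minimal smoke test for cache_support_hidden_states, not part of the commit: the StubWrapper below is a hypothetical stand-in that mimics only the one QwenWrapper method this function calls, and the snippet assumes it runs from the repo root so that adapt.cache_hidden's data.templates import resolves.

import torch
from adapt.cache_hidden import cache_support_hidden_states

class StubWrapper:
    """Stand-in for QwenWrapper; returns random final-layer states."""
    def get_hidden_states_teacher_forced(self, input_text, target_text):
        # Fake one hidden state per whitespace token of the target.
        T = max(len(target_text.split()), 1)
        return torch.randn(T, 16), torch.randint(0, 100, (T,))

support_items = [
    {"support_input": "Battery is great, camera is just okay.",
     "support_output": "4 stars"},
]
cached = cache_support_hidden_states(StubWrapper(), support_items, task="review")
print(len(cached), cached[0][0].shape)  # 1 torch.Size([2, 16])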
diff --git a/adapt/fit_theta.py b/adapt/fit_theta.py
new file mode 100644
index 0000000..f5b047b
--- /dev/null
+++ b/adapt/fit_theta.py
@@ -0,0 +1,107 @@
+"""Fit theta_u on cached hidden states.
+
+Loss = CE(lm_head(h + alpha * B(theta ⊙ A@h)), y) + beta * KL(p_0 || p_theta) + lam * ||theta||^2
+"""
+
+import torch
+import torch.nn.functional as F
+
+# Maximum chunk size for logit computation to avoid OOM
+CHUNK_SIZE = 128
+
+
+def _chunked_ce_kl(h_prime, h_base, lm_w, lm_bias, y, beta):
+ """Compute CE + KL in chunks to avoid OOM from huge vocab logits."""
+ seq_len = h_prime.shape[0]
+ total_ce = 0.0
+ total_kl = 0.0
+
+ for start in range(0, seq_len, CHUNK_SIZE):
+ end = min(start + CHUNK_SIZE, seq_len)
+ hp_chunk = h_prime[start:end]
+ hb_chunk = h_base[start:end]
+ y_chunk = y[start:end]
+
+ logits = F.linear(hp_chunk, lm_w, lm_bias)
+ base_logits = F.linear(hb_chunk, lm_w, lm_bias)
+
+ total_ce = total_ce + F.cross_entropy(logits, y_chunk, reduction='sum')
+
+ if beta > 0:
+ log_p = F.log_softmax(logits, dim=-1)
+ p0 = F.softmax(base_logits.detach(), dim=-1)
+ total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum')
+
+ # Free intermediates
+ del logits, base_logits
+ if beta > 0:
+ del log_p, p0
+
+ return total_ce, total_kl
+
+
+def fit_theta(
+ cached_h: list, # list of (h_states: (T_i, H), label_ids: (T_i,))
+ lm_head_weight: torch.Tensor, # (vocab_size, H)
+ lm_head_bias: torch.Tensor | None,
+ head_module, # CVHHead or UnconditionalHead
+ d: int = 64,
+ lr: float = 0.05,
+ steps: int = 30,
+ beta: float = 0.05,
+ lam: float = 1e-4,
+ max_grad_norm: float = 5.0,
+ device: str = "cuda:1",
+ verbose: bool = False,
+) -> torch.Tensor:
+ """Fit a user vector theta_u on cached hidden states.
+
+ Memory-efficient: computes logits in chunks, no pre-computation of base logits.
+ """
+ theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
+ optimizer = torch.optim.Adam([theta], lr=lr)
+
+ lm_w = lm_head_weight.float()
+ lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+
+ for step in range(steps):
+ total_loss = 0.0
+ total_tokens = 0
+
+ for h_cpu, y_cpu in cached_h:
+ h = h_cpu.to(device).float()
+ y = y_cpu.to(device)
+
+ # Apply head to get personalized hidden states
+ h_prime = head_module(h, theta)
+
+ # Compute CE + KL in chunks
+ ce, kl = _chunked_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, beta)
+
+ total_loss = total_loss + ce + beta * kl
+ total_tokens += y.shape[0]
+
+ # Free GPU memory
+ del h, y, h_prime
+
+ # Average over tokens + L2 reg
+ loss = total_loss / max(total_tokens, 1) + lam * theta.square().sum()
+
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
+ optimizer.step()
+
+        # Project theta back onto an L2 ball (radius reuses max_grad_norm)
+ with torch.no_grad():
+ norm = theta.norm()
+ if norm > max_grad_norm:
+ theta.mul_(max_grad_norm / norm)
+
+ if verbose and (step % 10 == 0 or step == steps - 1):
+ print(f" Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}")
+
+ # Free graph
+ del total_loss, loss
+
+ return theta.detach()
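
A toy end-to-end run of fit_theta, not part of the commit: ToyHead is a hypothetical stand-in for CVHHead that implements the h + alpha * B(theta ⊙ A@h) update from the docstring, and all shapes are deliberately tiny; the call runs on CPU under those assumptions.

import torch
from adapt.fit_theta import fit_theta

H, d, V = 16, 4, 50  # toy hidden size, theta dimension, vocab size

class ToyHead(torch.nn.Module):
    """Hypothetical low-rank residual head: h' = h + alpha * B(theta * (A h))."""
    def __init__(self, hidden, rank, alpha=1.0):
        super().__init__()
        self.A = torch.nn.Parameter(torch.randn(rank, hidden) / hidden ** 0.5)
        self.B = torch.nn.Parameter(torch.randn(hidden, rank) / rank ** 0.5)
        self.alpha = alpha

    def forward(self, h, theta):
        # (T, H) -> (T, d), gate by theta, project back to (T, H) as a residual
        return h + self.alpha * (theta * (h @ self.A.T)) @ self.B.T

cached = [(torch.randn(10, H), torch.randint(0, V, (10,)))]  # fake cached item
lm_w = torch.randn(V, H)                                     # fake lm_head weight

theta = fit_theta(cached, lm_w, None, ToyHead(H, d),
                  d=d, steps=5, device="cpu", verbose=True)
print(theta.shape)  # torch.Size([4])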
diff --git a/adapt/fit_theta_weighted.py b/adapt/fit_theta_weighted.py
new file mode 100644
index 0000000..f19d5ad
--- /dev/null
+++ b/adapt/fit_theta_weighted.py
@@ -0,0 +1,153 @@
+"""Style-weighted theta fitting.
+
+Weights style-indicative tokens (punctuation, newlines, first-person pronouns,
+discourse markers, function words) more heavily in the CE loss.
+"""
+
+import torch
+import torch.nn.functional as F
+
+# Maximum chunk size for logit computation
+CHUNK_SIZE = 128
+
+# Style-indicator token strings and their loss weights. A token counts as a
+# style indicator if its decoded text (spaces stripped, lowercased) matches a key below.
+STYLE_INDICATORS = {
+ # Punctuation and formatting
+ '!': 3.0, '?': 3.0, '.': 2.0, ',': 2.0, ';': 2.0, ':': 2.0,
+ '...': 3.0, '-': 2.0, '"': 2.0, "'": 2.0,
+ '\n': 4.0, '\n\n': 4.0,
+ # First-person pronouns
+ 'i': 3.0, 'me': 3.0, 'my': 3.0, 'mine': 3.0, 'myself': 3.0,
+ 'we': 3.0, 'us': 3.0, 'our': 3.0, 'ours': 3.0,
+ # Discourse markers and hedges
+ 'however': 2.0, 'although': 2.0, 'moreover': 2.0, 'furthermore': 2.0,
+ 'basically': 2.0, 'actually': 2.0, 'honestly': 2.0, 'frankly': 2.0,
+ 'definitely': 2.0, 'absolutely': 2.0, 'obviously': 2.0,
+ # Sentiment/intensity
+ 'very': 2.0, 'really': 2.0, 'quite': 2.0, 'extremely': 2.0,
+ 'amazing': 2.0, 'terrible': 2.0, 'awesome': 2.0, 'horrible': 2.0,
+ 'love': 2.0, 'hate': 2.0, 'loved': 2.0, 'hated': 2.0,
+}
+
+
+def build_token_weights(label_ids: torch.Tensor, tokenizer, device: str) -> torch.Tensor:
+ """Build per-token weights for style-weighted loss.
+
+ Args:
+ label_ids: (T,) tensor of token ids
+ tokenizer: The tokenizer to decode token ids
+
+ Returns:
+ weights: (T,) tensor of per-token weights (≥ 1.0)
+ """
+ weights = torch.ones(label_ids.shape[0], device=device, dtype=torch.float32)
+
+ for i, tok_id in enumerate(label_ids):
+        # Strip spaces only: a bare .strip() would remove '\n', so the
+        # newline entries in STYLE_INDICATORS could never match.
+        tok_str = tokenizer.decode([tok_id.item()]).strip(' ').lower()
+ if tok_str in STYLE_INDICATORS:
+ weights[i] = STYLE_INDICATORS[tok_str]
+
+ return weights
+
+
+def _chunked_weighted_ce_kl(h_prime, h_base, lm_w, lm_b, y, weights, beta):
+ """Compute weighted CE + KL in chunks."""
+ seq_len = h_prime.shape[0]
+ total_ce = 0.0
+ total_kl = 0.0
+ total_weight = 0.0
+
+ for start in range(0, seq_len, CHUNK_SIZE):
+ end = min(start + CHUNK_SIZE, seq_len)
+ hp_chunk = h_prime[start:end]
+ hb_chunk = h_base[start:end]
+ y_chunk = y[start:end]
+ w_chunk = weights[start:end]
+
+ logits = F.linear(hp_chunk, lm_w, lm_b)
+ base_logits = F.linear(hb_chunk, lm_w, lm_b)
+
+ # Weighted CE: per-token CE * weight
+ ce_per_tok = F.cross_entropy(logits, y_chunk, reduction='none') # (chunk,)
+ total_ce = total_ce + (ce_per_tok * w_chunk).sum()
+ total_weight += w_chunk.sum().item()
+
+ if beta > 0:
+ log_p = F.log_softmax(logits, dim=-1)
+ p0 = F.softmax(base_logits.detach(), dim=-1)
+ total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum')
+
+ del logits, base_logits
+ if beta > 0:
+ del log_p, p0
+
+ return total_ce, total_kl, total_weight
+
+
+def fit_theta_weighted(
+ cached_h: list, # list of (h_states, label_ids)
+ lm_head_weight: torch.Tensor,
+ lm_head_bias: torch.Tensor | None,
+ head_module,
+ tokenizer,
+ d: int = 64,
+ lr: float = 0.05,
+ steps: int = 30,
+ beta: float = 0.05,
+ lam: float = 1e-4,
+ max_grad_norm: float = 5.0,
+ device: str = "cuda:1",
+ max_tokens_per_item: int = 128,
+ verbose: bool = False,
+) -> torch.Tensor:
+ """Fit theta with style-weighted CE loss and per-item token budget."""
+ theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32)
+ optimizer = torch.optim.Adam([theta], lr=lr)
+
+ lm_w = lm_head_weight.float()
+ lm_b = lm_head_bias.float() if lm_head_bias is not None else None
+
+ # Pre-build token weights and truncate to max_tokens_per_item
+ processed_items = []
+ for h_cpu, y_cpu in cached_h:
+ T = min(h_cpu.shape[0], max_tokens_per_item)
+ h_trunc = h_cpu[:T]
+ y_trunc = y_cpu[:T]
+ w = build_token_weights(y_trunc, tokenizer, device)
+ processed_items.append((h_trunc, y_trunc, w))
+
+ for step in range(steps):
+ total_loss = 0.0
+ total_weight = 0.0
+
+ for h_cpu, y_cpu, w in processed_items:
+ h = h_cpu.to(device).float()
+ y = y_cpu.to(device)
+
+ h_prime = head_module(h, theta)
+ ce, kl, w_sum = _chunked_weighted_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, w, beta)
+
+ total_loss = total_loss + ce + beta * kl
+ total_weight += w_sum
+
+ del h, y, h_prime
+
+ loss = total_loss / max(total_weight, 1.0) + lam * theta.square().sum()
+
+ optimizer.zero_grad()
+ loss.backward()
+ torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm)
+ optimizer.step()
+
+        # Project theta back onto an L2 ball (radius reuses max_grad_norm)
+        with torch.no_grad():
+ norm = theta.norm()
+ if norm > max_grad_norm:
+ theta.mul_(max_grad_norm / norm)
+
+ if verbose and (step % 10 == 0 or step == steps - 1):
+ print(f" Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}")
+
+ del total_loss, loss
+
+ return theta.detach()
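
A quick sanity check of build_token_weights, not part of the commit: the model name below is illustrative (any Hugging Face tokenizer with a .decode method should do), and the exact weights depend on how that tokenizer segments the text.

import torch
from transformers import AutoTokenizer
from adapt.fit_theta_weighted import build_token_weights

tok = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")  # illustrative choice
ids = torch.tensor(tok("I really love this!\n")["input_ids"])
w = build_token_weights(ids, tok, device="cpu")
for i, wt in zip(ids.tolist(), w.tolist()):
    print(repr(tok.decode([i])), wt)
# Plausible result: 'I' -> 3.0, ' really' -> 2.0, ' love' -> 2.0,
# ' this' -> 1.0, '!' -> 3.0, '\n' -> 4.0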