"""Style-weighted theta fitting. Weights style-indicative tokens (punctuation, newlines, first-person pronouns, discourse markers, function words) more heavily in the CE loss. """ import torch import torch.nn.functional as F # Maximum chunk size for logit computation CHUNK_SIZE = 128 # Style-indicator token categories and their weights # These are identified by checking if the decoded token matches certain patterns STYLE_INDICATORS = { # Punctuation and formatting '!': 3.0, '?': 3.0, '.': 2.0, ',': 2.0, ';': 2.0, ':': 2.0, '...': 3.0, '-': 2.0, '"': 2.0, "'": 2.0, '\n': 4.0, '\n\n': 4.0, # First-person pronouns 'i': 3.0, 'me': 3.0, 'my': 3.0, 'mine': 3.0, 'myself': 3.0, 'we': 3.0, 'us': 3.0, 'our': 3.0, 'ours': 3.0, # Discourse markers and hedges 'however': 2.0, 'although': 2.0, 'moreover': 2.0, 'furthermore': 2.0, 'basically': 2.0, 'actually': 2.0, 'honestly': 2.0, 'frankly': 2.0, 'definitely': 2.0, 'absolutely': 2.0, 'obviously': 2.0, # Sentiment/intensity 'very': 2.0, 'really': 2.0, 'quite': 2.0, 'extremely': 2.0, 'amazing': 2.0, 'terrible': 2.0, 'awesome': 2.0, 'horrible': 2.0, 'love': 2.0, 'hate': 2.0, 'loved': 2.0, 'hated': 2.0, } def build_token_weights(label_ids: torch.Tensor, tokenizer, device: str) -> torch.Tensor: """Build per-token weights for style-weighted loss. Args: label_ids: (T,) tensor of token ids tokenizer: The tokenizer to decode token ids Returns: weights: (T,) tensor of per-token weights (≥ 1.0) """ weights = torch.ones(label_ids.shape[0], device=device, dtype=torch.float32) for i, tok_id in enumerate(label_ids): tok_str = tokenizer.decode([tok_id.item()]).strip().lower() if tok_str in STYLE_INDICATORS: weights[i] = STYLE_INDICATORS[tok_str] return weights def _chunked_weighted_ce_kl(h_prime, h_base, lm_w, lm_b, y, weights, beta): """Compute weighted CE + KL in chunks.""" seq_len = h_prime.shape[0] total_ce = 0.0 total_kl = 0.0 total_weight = 0.0 for start in range(0, seq_len, CHUNK_SIZE): end = min(start + CHUNK_SIZE, seq_len) hp_chunk = h_prime[start:end] hb_chunk = h_base[start:end] y_chunk = y[start:end] w_chunk = weights[start:end] logits = F.linear(hp_chunk, lm_w, lm_b) base_logits = F.linear(hb_chunk, lm_w, lm_b) # Weighted CE: per-token CE * weight ce_per_tok = F.cross_entropy(logits, y_chunk, reduction='none') # (chunk,) total_ce = total_ce + (ce_per_tok * w_chunk).sum() total_weight += w_chunk.sum().item() if beta > 0: log_p = F.log_softmax(logits, dim=-1) p0 = F.softmax(base_logits.detach(), dim=-1) total_kl = total_kl + F.kl_div(log_p, p0, reduction='sum') del logits, base_logits if beta > 0: del log_p, p0 return total_ce, total_kl, total_weight def fit_theta_weighted( cached_h: list, # list of (h_states, label_ids) lm_head_weight: torch.Tensor, lm_head_bias: torch.Tensor | None, head_module, tokenizer, d: int = 64, lr: float = 0.05, steps: int = 30, beta: float = 0.05, lam: float = 1e-4, max_grad_norm: float = 5.0, device: str = "cuda:1", max_tokens_per_item: int = 128, verbose: bool = False, ) -> torch.Tensor: """Fit theta with style-weighted CE loss and per-item token budget.""" theta = torch.zeros(d, device=device, requires_grad=True, dtype=torch.float32) optimizer = torch.optim.Adam([theta], lr=lr) lm_w = lm_head_weight.float() lm_b = lm_head_bias.float() if lm_head_bias is not None else None # Pre-build token weights and truncate to max_tokens_per_item processed_items = [] for h_cpu, y_cpu in cached_h: T = min(h_cpu.shape[0], max_tokens_per_item) h_trunc = h_cpu[:T] y_trunc = y_cpu[:T] w = build_token_weights(y_trunc, tokenizer, device) processed_items.append((h_trunc, y_trunc, w)) for step in range(steps): total_loss = 0.0 total_weight = 0.0 for h_cpu, y_cpu, w in processed_items: h = h_cpu.to(device).float() y = y_cpu.to(device) h_prime = head_module(h, theta) ce, kl, w_sum = _chunked_weighted_ce_kl(h_prime, h.detach(), lm_w, lm_b, y, w, beta) total_loss = total_loss + ce + beta * kl total_weight += w_sum del h, y, h_prime loss = total_loss / max(total_weight, 1.0) + lam * theta.square().sum() optimizer.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_([theta], max_norm=max_grad_norm) optimizer.step() with torch.no_grad(): norm = theta.norm() if norm > max_grad_norm: theta.mul_(max_grad_norm / norm) if verbose and (step % 10 == 0 or step == steps - 1): print(f" Step {step:3d}: loss={loss.item():.4f}, |theta|={theta.norm().item():.4f}") del total_loss, loss return theta.detach()