# utils_kl.py
"""
KL Divergence Utilities for RLVR Experiments.

This module provides utilities for computing KL divergence between
policy distributions, including:
- Token-level KL computation
- Sequence-level KL aggregation
- Dataset-level KL estimation
"""

import torch
import torch.nn.functional as F
from typing import Dict, Any, List, Tuple, Optional
import numpy as np
import logging
from tqdm import tqdm

logger = logging.getLogger(__name__)


# ============================================================================
# Token-Level KL Computation
# ============================================================================

def compute_token_log_probs(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    labels: Optional[torch.Tensor] = None
) -> torch.Tensor:
    """
    Compute token-level log probabilities.

    Args:
        model: Language model
        input_ids: Input token IDs [batch, seq_len]
        attention_mask: Attention mask [batch, seq_len]
        labels: Token labels for which to compute log probs (default: input_ids)

    Returns:
        Token log probabilities [batch, seq_len - 1]
    """
    if labels is None:
        labels = input_ids

    with torch.no_grad():
        outputs = model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            use_cache=False
        )
        logits = outputs.logits  # [batch, seq_len, vocab]

        # Shift for autoregressive prediction: token t is predicted from tokens 0..t-1
        shift_logits = logits[:, :-1, :]  # [batch, seq_len-1, vocab]
        shift_labels = labels[:, 1:]      # [batch, seq_len-1]

        # Compute log probabilities
        log_probs = F.log_softmax(shift_logits, dim=-1)

        # Gather log probs for the actual tokens
        token_log_probs = torch.gather(
            log_probs, dim=-1, index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)  # [batch, seq_len-1]

    return token_log_probs


def compute_kl_per_token(
    policy_log_probs: torch.Tensor,
    ref_log_probs: torch.Tensor
) -> torch.Tensor:
    """
    Compute a per-token KL divergence estimate from sampled tokens.

    KL(π || π_ref) at token t ≈ log π(y_t) - log π_ref(y_t)

    Note: this is a single-sample (Monte Carlo) estimate of KL(π || π_ref),
    valid when y_t is drawn from the policy; compute_reverse_kl_per_token
    computes the same divergence exactly over the full vocabulary.
    """
    return policy_log_probs - ref_log_probs


def compute_reverse_kl_per_token(
    policy_logits: torch.Tensor,
    ref_logits: torch.Tensor,
    temperature: float = 1.0
) -> torch.Tensor:
    """
    Compute the exact per-token KL divergence using the full distributions.

    KL(π || π_ref) = Σ_y π(y) [log π(y) - log π_ref(y)]

    Here KL(π || π_ref), with the reference as the target distribution, is
    referred to as the reverse KL. Summing over the full vocabulary is more
    expensive than the sampled estimate above but gives the exact KL at
    each position.
    """
    policy_probs = F.softmax(policy_logits / temperature, dim=-1)
    policy_log_probs = F.log_softmax(policy_logits / temperature, dim=-1)
    ref_log_probs = F.log_softmax(ref_logits / temperature, dim=-1)

    # KL = Σ p(x) log(p(x)/q(x)) = Σ p(x) [log p(x) - log q(x)]
    kl = (policy_probs * (policy_log_probs - ref_log_probs)).sum(dim=-1)
    return kl
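
# The sketch below is illustrative only and is not used by the experiments:
# it checks, on synthetic logits, that the single-sample estimator in
# compute_kl_per_token agrees on average with the exact per-token KL from
# compute_reverse_kl_per_token. The vocabulary size, sample count, and seed
# are arbitrary choices.

def _demo_kl_estimators(vocab_size: int = 32, num_samples: int = 50_000,
                        seed: int = 0) -> Tuple[float, float]:
    """Compare the Monte Carlo KL estimate against the exact per-token KL."""
    generator = torch.Generator().manual_seed(seed)
    policy_logits = torch.randn(1, vocab_size, generator=generator)
    ref_logits = torch.randn(1, vocab_size, generator=generator)

    # Exact KL(π || π_ref) summed over the full vocabulary
    exact_kl = compute_reverse_kl_per_token(policy_logits, ref_logits).item()

    # Monte Carlo estimate: sample tokens from the policy and average the
    # per-token log-ratio log π(y) - log π_ref(y)
    policy_log_probs = F.log_softmax(policy_logits, dim=-1)
    ref_log_probs = F.log_softmax(ref_logits, dim=-1)
    samples = torch.multinomial(
        policy_log_probs.exp(), num_samples, replacement=True, generator=generator
    )
    mc_kl = compute_kl_per_token(
        policy_log_probs.gather(-1, samples), ref_log_probs.gather(-1, samples)
    ).mean().item()

    logger.info("exact KL=%.4f, Monte Carlo KL=%.4f", exact_kl, mc_kl)
    return exact_kl, mc_kl
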
# ============================================================================
# Sequence-Level KL
# ============================================================================

def compute_sequence_kl(
    policy_model: torch.nn.Module,
    ref_model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    response_start_idx: int = 0,
    normalize_by_length: bool = False
) -> Dict[str, float]:
    """
    Compute KL divergence for a single sequence.

    Args:
        policy_model: Finetuned policy model
        ref_model: Reference model
        input_ids: Full sequence (prompt + response) [1, seq_len]
        attention_mask: Attention mask [1, seq_len]
        response_start_idx: Index where the response starts
        normalize_by_length: If True, return the average KL per token

    Returns:
        Dictionary with KL metrics
    """
    # Get log probs from both models
    policy_log_probs = compute_token_log_probs(
        policy_model, input_ids, attention_mask
    )
    ref_log_probs = compute_token_log_probs(
        ref_model, input_ids, attention_mask
    )

    # Compute per-token KL
    kl_per_token = compute_kl_per_token(policy_log_probs, ref_log_probs)

    # Create a mask for response tokens only (positions are shifted by one
    # relative to input_ids because of the autoregressive shift)
    seq_len = kl_per_token.shape[1]
    response_mask = torch.zeros(1, seq_len, device=input_ids.device)
    if response_start_idx > 0:
        response_mask[:, response_start_idx - 1:] = 1.0
    else:
        response_mask[:, :] = 1.0

    # Apply attention mask
    valid_mask = attention_mask[:, 1:].float() * response_mask

    # Compute statistics over valid (response) positions only; masked-out
    # positions are excluded so they cannot skew the min/max
    masked_kl = kl_per_token * valid_mask
    num_tokens = valid_mask.sum().item()
    total_kl = masked_kl.sum().item()
    valid_kl = kl_per_token[valid_mask.bool()]

    result = {
        "total_kl": total_kl,
        "num_tokens": int(num_tokens),
        "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0,
        "max_kl": valid_kl.max().item() if num_tokens > 0 else 0.0,
        "min_kl": valid_kl.min().item() if num_tokens > 0 else 0.0,
    }

    if normalize_by_length:
        result["kl"] = result["mean_kl"]
    else:
        result["kl"] = result["total_kl"]

    return result
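
# Illustrative sketch of how compute_sequence_kl is expected to be called.
# It is not part of the experiment pipeline: "gpt2" is only a small stand-in
# checkpoint (loaded twice, so the reported KL should be close to zero), and
# the prompt/response pair is a made-up example.

def _demo_sequence_kl() -> Dict[str, float]:
    from transformers import AutoModelForCausalLM, AutoTokenizer

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    policy_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()
    ref_model = AutoModelForCausalLM.from_pretrained("gpt2").to(device).eval()

    prompt = "Q: What is 2 + 2?\nA:"
    response = " 4"
    prompt_len = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]
    encoded = tokenizer(prompt + response, return_tensors="pt").to(device)

    result = compute_sequence_kl(
        policy_model=policy_model,
        ref_model=ref_model,
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        response_start_idx=prompt_len,
        normalize_by_length=True,
    )
    logger.info("Sequence KL for identical models (expect ~0): %s", result)
    return result
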
# ============================================================================
# Dataset-Level KL Estimation
# ============================================================================

def estimate_dataset_kl(
    policy_model: torch.nn.Module,
    ref_model: torch.nn.Module,
    tokenizer,
    prompts: List[str],
    responses: List[str],
    device: torch.device,
    max_seq_len: int = 4096,
    normalize_by_length: bool = False,
    show_progress: bool = True
) -> Dict[str, Any]:
    """
    Estimate KL divergence over a dataset.

    Args:
        policy_model: Finetuned policy model
        ref_model: Reference model
        tokenizer: Tokenizer shared by both models
        prompts: List of prompts
        responses: List of corresponding responses
        device: Device to use
        max_seq_len: Maximum sequence length
        normalize_by_length: If True, use mean KL per token
        show_progress: Show a progress bar

    Returns:
        Dictionary with dataset-level KL statistics
    """
    assert len(prompts) == len(responses), \
        "Number of prompts must match number of responses"

    policy_model.eval()
    ref_model.eval()

    all_kl_values: List[float] = []
    all_num_tokens: List[int] = []

    iterator = zip(prompts, responses)
    if show_progress:
        iterator = tqdm(list(iterator), desc="Computing KL")

    for prompt, response in iterator:
        # Tokenize the prompt alone to locate where the response starts.
        # Note: this is an approximation, since tokenizing prompt and
        # response together can merge tokens at the boundary.
        prompt_tokens = tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=max_seq_len // 2
        )
        prompt_len = prompt_tokens["input_ids"].shape[1]

        # Tokenize the full sequence
        full_text = prompt + response
        full_tokens = tokenizer(
            full_text,
            return_tensors="pt",
            truncation=True,
            max_length=max_seq_len
        )
        input_ids = full_tokens["input_ids"].to(device)
        attention_mask = full_tokens["attention_mask"].to(device)

        # Compute sequence-level KL
        with torch.no_grad():
            kl_result = compute_sequence_kl(
                policy_model=policy_model,
                ref_model=ref_model,
                input_ids=input_ids,
                attention_mask=attention_mask,
                response_start_idx=prompt_len,
                normalize_by_length=normalize_by_length
            )

        all_kl_values.append(kl_result["kl"])
        all_num_tokens.append(kl_result["num_tokens"])

    # Aggregate statistics
    kl_array = np.array(all_kl_values)
    result = {
        "mean_kl": float(np.mean(kl_array)),
        "std_kl": float(np.std(kl_array)),
        "median_kl": float(np.median(kl_array)),
        "min_kl": float(np.min(kl_array)),
        "max_kl": float(np.max(kl_array)),
        "total_samples": len(prompts),
        "total_tokens": sum(all_num_tokens),
        "kl_values": all_kl_values,
    }

    return result
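
# Illustrative sketch of a dataset-level KL call. The prompt/response pairs
# below are toy placeholders, not data from the experiments; pass in whatever
# policy/reference models and tokenizer your setup actually uses.

def _demo_dataset_kl(
    policy_model: torch.nn.Module,
    ref_model: torch.nn.Module,
    tokenizer,
    device: torch.device
) -> Dict[str, Any]:
    prompts = [
        "Q: What is 7 * 8?\nA:",
        "Q: Name the capital of France.\nA:",
    ]
    responses = [" 56", " Paris"]
    stats = estimate_dataset_kl(
        policy_model=policy_model,
        ref_model=ref_model,
        tokenizer=tokenizer,
        prompts=prompts,
        responses=responses,
        device=device,
        normalize_by_length=True,   # report mean KL per response token
        show_progress=False,
    )
    logger.info("Dataset KL: mean=%.4f, std=%.4f over %d samples",
                stats["mean_kl"], stats["std_kl"], stats["total_samples"])
    return stats
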
# ============================================================================
# On-Task vs Off-Task KL Analysis
# ============================================================================

def analyze_kl_by_task(
    kl_results: Dict[str, Dict[str, Any]],
    on_task_names: List[str],
    off_task_names: List[str]
) -> Dict[str, Any]:
    """
    Analyze KL divergence patterns for on-task vs off-task evaluations.

    Args:
        kl_results: Dictionary mapping task names to KL results
        on_task_names: List of on-task (training distribution) names
        off_task_names: List of off-task names

    Returns:
        Analysis of KL patterns
    """
    on_task_kl = []
    off_task_kl = []

    for name in on_task_names:
        if name in kl_results:
            on_task_kl.append(kl_results[name]["mean_kl"])
    for name in off_task_names:
        if name in kl_results:
            off_task_kl.append(kl_results[name]["mean_kl"])

    analysis = {
        "on_task": {
            "mean": float(np.mean(on_task_kl)) if on_task_kl else 0.0,
            "std": float(np.std(on_task_kl)) if on_task_kl else 0.0,
            "values": on_task_kl,
        },
        "off_task": {
            "mean": float(np.mean(off_task_kl)) if off_task_kl else 0.0,
            "std": float(np.std(off_task_kl)) if off_task_kl else 0.0,
            "values": off_task_kl,
        },
    }

    # Compute the off-task to on-task ratio
    if analysis["on_task"]["mean"] > 0:
        analysis["off_to_on_ratio"] = (
            analysis["off_task"]["mean"] / analysis["on_task"]["mean"]
        )
    else:
        analysis["off_to_on_ratio"] = float("inf")

    return analysis


# ============================================================================
# KL Contribution Analysis
# ============================================================================

def analyze_kl_contribution_by_layer(
    model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor
) -> Dict[str, str]:
    """
    Analyze which layers contribute most to the final prediction.

    This is a simplified analysis; full KL attribution would require
    layer-wise probing.
    """
    # Placeholder for more sophisticated analysis. A full implementation
    # would require modifying the model to output intermediate
    # representations.
    return {
        "note": "Layer-wise KL attribution not implemented",
    }


def compute_kl_trajectory(
    checkpoints: List[str],
    ref_model_path: str,
    tokenizer_path: str,
    prompts: List[str],
    responses: List[str],
    device: torch.device
) -> List[Dict[str, Any]]:
    """
    Compute the KL divergence trajectory over training checkpoints.

    Useful for understanding how KL evolves during training.
    """
    from transformers import AutoModelForCausalLM, AutoTokenizer

    # Load tokenizer
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
    if tokenizer.pad_token is None:
        tokenizer.pad_token = tokenizer.eos_token

    # Load reference model
    ref_model = AutoModelForCausalLM.from_pretrained(
        ref_model_path,
        torch_dtype=torch.bfloat16,
        device_map=None
    ).to(device)
    ref_model.eval()

    trajectory = []
    for ckpt_path in tqdm(checkpoints, desc="Computing KL trajectory"):
        # Load checkpoint
        policy_model = AutoModelForCausalLM.from_pretrained(
            ckpt_path,
            torch_dtype=torch.bfloat16,
            device_map=None
        ).to(device)
        policy_model.eval()

        # Estimate KL
        kl_result = estimate_dataset_kl(
            policy_model=policy_model,
            ref_model=ref_model,
            tokenizer=tokenizer,
            prompts=prompts,
            responses=responses,
            device=device,
            show_progress=False
        )

        trajectory.append({
            "checkpoint": ckpt_path,
            "mean_kl": kl_result["mean_kl"],
            "std_kl": kl_result["std_kl"],
        })

        # Free memory
        del policy_model
        torch.cuda.empty_cache()

    return trajectory
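

if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    # Quick self-contained checks that need no model downloads. The task
    # names and mean-KL values passed to analyze_kl_by_task are made up
    # purely to illustrate the expected input/output structure.
    _demo_kl_estimators()

    toy_results = {
        "math_train": {"mean_kl": 0.12},
        "math_heldout": {"mean_kl": 0.10},
        "general_chat": {"mean_kl": 0.45},
    }
    analysis = analyze_kl_by_task(
        kl_results=toy_results,
        on_task_names=["math_train", "math_heldout"],
        off_task_names=["general_chat"],
    )
    logger.info("On-task vs off-task KL analysis: %s", analysis)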