Diffstat (limited to 'utils_kl.py')
-rw-r--r--  utils_kl.py  419
1 file changed, 419 insertions, 0 deletions
diff --git a/utils_kl.py b/utils_kl.py
new file mode 100644
index 0000000..2be50a0
--- /dev/null
+++ b/utils_kl.py
@@ -0,0 +1,419 @@
+# utils_kl.py
+"""
+KL Divergence Utilities for RLVR Experiments.
+
+This module provides utilities for computing KL divergence between
+policy distributions, including:
+- Token-level KL computation
+- Sequence-level KL aggregation
+- Dataset-level KL estimation
+"""
+
+import torch
+import torch.nn.functional as F
+from typing import Dict, Any, List, Optional
+import numpy as np
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Token-Level KL Computation
+# ============================================================================
+
+def compute_token_log_probs(
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ labels: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ """
+ Compute token-level log probabilities.
+
+ Args:
+ model: Language model
+ input_ids: Input token IDs [batch, seq_len]
+ attention_mask: Attention mask [batch, seq_len]
+ labels: Token labels for which to compute log probs (default: input_ids)
+
+ Returns:
+ Token log probabilities [batch, seq_len - 1]
+ """
+ if labels is None:
+ labels = input_ids
+
+ with torch.no_grad():
+ outputs = model(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ use_cache=False
+ )
+
+ logits = outputs.logits # [batch, seq_len, vocab]
+
+ # Shift for autoregressive: predict token t from tokens 0..t-1
+ shift_logits = logits[:, :-1, :] # [batch, seq_len-1, vocab]
+ shift_labels = labels[:, 1:] # [batch, seq_len-1]
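+    # Alignment example: for input_ids = [t0, t1, t2, t3], shift_logits
+    # holds the distributions over positions 1..3 and shift_labels is
+    # [t1, t2, t3], so token_log_probs[:, i] = log P(t_{i+1} | t_0..t_i).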
+
+ # Compute log probabilities
+ log_probs = F.log_softmax(shift_logits, dim=-1)
+
+ # Gather log probs for actual tokens
+ token_log_probs = torch.gather(
+ log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1) # [batch, seq_len-1]
+
+ return token_log_probs
+
+
+def compute_kl_per_token(
+ policy_log_probs: torch.Tensor,
+ ref_log_probs: torch.Tensor
+) -> torch.Tensor:
+ """
+    Compute a per-token Monte Carlo estimate of KL divergence.
+
+    For a token y_t sampled from π, log π(y_t) - log π_ref(y_t) is an
+    unbiased single-sample estimate of KL(π || π_ref) at step t.
+
+    Note: individual values can be negative; only their expectation
+    under π is non-negative. For the exact per-token KL, see
+    compute_reverse_kl_per_token.
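+
+    Example (toy tensors, illustrative only):
+
+        >>> lp_policy = torch.tensor([[-1.0, -2.0]])
+        >>> lp_ref = torch.tensor([[-1.5, -1.5]])
+        >>> compute_kl_per_token(lp_policy, lp_ref)
+        tensor([[ 0.5000, -0.5000]])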
+ """
+ return policy_log_probs - ref_log_probs
+
+
+def compute_reverse_kl_per_token(
+ policy_logits: torch.Tensor,
+ ref_logits: torch.Tensor,
+ temperature: float = 1.0
+) -> torch.Tensor:
+ """
+    Compute the exact per-token KL divergence from full distributions.
+
+    KL(π || π_ref) = Σ_y π(y) [log π(y) - log π_ref(y)]
+
+    Here "reverse" refers to the direction relative to the reference
+    model. This is the same quantity that compute_kl_per_token
+    estimates from samples; summing over the full vocabulary is more
+    expensive but exact and always non-negative. Logits from both
+    models are divided by `temperature` before the softmax.
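+
+    Example (identical logits give zero KL; illustrative only):
+
+        >>> logits = torch.zeros(1, 1, 5)
+        >>> compute_reverse_kl_per_token(logits, logits)
+        tensor([[0.]])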
+ """
+ policy_probs = F.softmax(policy_logits / temperature, dim=-1)
+ policy_log_probs = F.log_softmax(policy_logits / temperature, dim=-1)
+ ref_log_probs = F.log_softmax(ref_logits / temperature, dim=-1)
+
+ # KL = Σ p(x) log(p(x)/q(x)) = Σ p(x) [log p(x) - log q(x)]
+ kl = (policy_probs * (policy_log_probs - ref_log_probs)).sum(dim=-1)
+
+ return kl
+
+
+# ============================================================================
+# Sequence-Level KL
+# ============================================================================
+
+def compute_sequence_kl(
+ policy_model: torch.nn.Module,
+ ref_model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ response_start_idx: int = 0,
+ normalize_by_length: bool = False
+) -> Dict[str, float]:
+ """
+ Compute KL divergence for a single sequence.
+
+ Args:
+ policy_model: Finetuned policy model
+ ref_model: Reference model
+ input_ids: Full sequence (prompt + response) [1, seq_len]
+ attention_mask: Attention mask [1, seq_len]
+        response_start_idx: Index in input_ids where the response starts
+ normalize_by_length: If True, return average KL per token
+
+ Returns:
+ Dictionary with KL metrics
+ """
+ # Get log probs from both models
+ policy_log_probs = compute_token_log_probs(
+ policy_model, input_ids, attention_mask
+ )
+ ref_log_probs = compute_token_log_probs(
+ ref_model, input_ids, attention_mask
+ )
+
+ # Compute per-token KL
+ kl_per_token = compute_kl_per_token(policy_log_probs, ref_log_probs)
+
+    # Create mask for response tokens only. kl_per_token[:, i] scores
+    # token i+1 of input_ids, hence the -1 shift below.
+    seq_len = kl_per_token.shape[1]
+    response_mask = torch.zeros(1, seq_len, device=input_ids.device)
+    if response_start_idx > 0:
+        response_mask[:, response_start_idx - 1:] = 1.0
+    else:
+        response_mask[:, :] = 1.0
+
+ # Apply attention mask
+ valid_mask = attention_mask[:, 1:].float() * response_mask
+
+    # Compute statistics over valid (response, non-padding) positions
+    masked_kl = kl_per_token * valid_mask
+    num_tokens = valid_mask.sum().item()
+    total_kl = masked_kl.sum().item()
+    # Index with the boolean mask so that masked-out zeros do not
+    # distort the max/min statistics.
+    valid_kl = kl_per_token[valid_mask.bool()]
+
+    result = {
+        "total_kl": total_kl,
+        "num_tokens": int(num_tokens),
+        "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0,
+        "max_kl": valid_kl.max().item() if num_tokens > 0 else 0.0,
+        "min_kl": valid_kl.min().item() if num_tokens > 0 else 0.0,
+    }
+
+ if normalize_by_length:
+ result["kl"] = result["mean_kl"]
+ else:
+ result["kl"] = result["total_kl"]
+
+ return result
+
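+
+def _example_sequence_kl() -> None:
+    """
+    Illustrative sketch of compute_sequence_kl with Hugging Face models.
+
+    The checkpoint name below is a small public test model used purely
+    for illustration; substitute your own policy and reference
+    checkpoints. With identical weights on both sides, every KL metric
+    should be approximately zero.
+    """
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    name = "sshleifer/tiny-gpt2"  # placeholder test checkpoint
+    tokenizer = AutoTokenizer.from_pretrained(name)
+    policy = AutoModelForCausalLM.from_pretrained(name).eval()
+    ref = AutoModelForCausalLM.from_pretrained(name).eval()
+
+    prompt = "Q: What is 2 + 2?\nA:"
+    response = " 4"
+    prompt_len = tokenizer(prompt, return_tensors="pt")["input_ids"].shape[1]
+    enc = tokenizer(prompt + response, return_tensors="pt")
+
+    result = compute_sequence_kl(
+        policy_model=policy,
+        ref_model=ref,
+        input_ids=enc["input_ids"],
+        attention_mask=enc["attention_mask"],
+        response_start_idx=prompt_len,
+    )
+    logger.info("sequence KL (same weights, expect ~0): %s", result)
+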
+
+# ============================================================================
+# Dataset-Level KL Estimation
+# ============================================================================
+
+def estimate_dataset_kl(
+ policy_model: torch.nn.Module,
+ ref_model: torch.nn.Module,
+ tokenizer,
+ prompts: List[str],
+ responses: List[str],
+ device: torch.device,
+ max_seq_len: int = 4096,
+ normalize_by_length: bool = False,
+ show_progress: bool = True
+) -> Dict[str, Any]:
+ """
+ Estimate KL divergence over a dataset.
+
+ Args:
+ policy_model: Finetuned policy model
+ ref_model: Reference model
+ tokenizer: Tokenizer for both models
+ prompts: List of prompts
+ responses: List of corresponding responses
+ device: Device to use
+ max_seq_len: Maximum sequence length
+ normalize_by_length: If True, use mean KL per token
+ show_progress: Show progress bar
+
+ Returns:
+ Dictionary with dataset-level KL statistics
+ """
+ assert len(prompts) == len(responses), \
+ "Number of prompts must match responses"
+
+ policy_model.eval()
+ ref_model.eval()
+
+ all_kl_values: List[float] = []
+ all_num_tokens: List[int] = []
+
+ iterator = zip(prompts, responses)
+ if show_progress:
+ iterator = tqdm(
+ list(iterator),
+ desc="Computing KL"
+ )
+
+    for prompt, response in iterator:
+        # Tokenize prompt alone to locate the response boundary. This
+        # assumes the prompt's tokens are a prefix of the tokens of
+        # prompt + response, which can be off by one token at the
+        # boundary for some BPE tokenizers.
+        prompt_tokens = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_seq_len // 2
+        )
+        prompt_len = prompt_tokens["input_ids"].shape[1]
+
+ # Tokenize full sequence
+ full_text = prompt + response
+ full_tokens = tokenizer(
+ full_text,
+ return_tensors="pt",
+ truncation=True,
+ max_length=max_seq_len
+ )
+
+ input_ids = full_tokens["input_ids"].to(device)
+ attention_mask = full_tokens["attention_mask"].to(device)
+
+ # Compute sequence KL
+ with torch.no_grad():
+ kl_result = compute_sequence_kl(
+ policy_model=policy_model,
+ ref_model=ref_model,
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ response_start_idx=prompt_len,
+ normalize_by_length=normalize_by_length
+ )
+
+ all_kl_values.append(kl_result["kl"])
+ all_num_tokens.append(kl_result["num_tokens"])
+
+ # Aggregate statistics
+ kl_array = np.array(all_kl_values)
+
+ result = {
+ "mean_kl": float(np.mean(kl_array)),
+ "std_kl": float(np.std(kl_array)),
+ "median_kl": float(np.median(kl_array)),
+ "min_kl": float(np.min(kl_array)),
+ "max_kl": float(np.max(kl_array)),
+ "total_samples": len(prompts),
+ "total_tokens": sum(all_num_tokens),
+ "kl_values": all_kl_values,
+ }
+
+ return result
+
+
+# ============================================================================
+# On-Task vs Off-Task KL Analysis
+# ============================================================================
+
+def analyze_kl_by_task(
+ kl_results: Dict[str, Dict[str, Any]],
+ on_task_names: List[str],
+ off_task_names: List[str]
+) -> Dict[str, Any]:
+ """
+ Analyze KL divergence patterns for on-task vs off-task.
+
+ Args:
+ kl_results: Dictionary mapping task names to KL results
+ on_task_names: List of on-task (training distribution) names
+ off_task_names: List of off-task names
+
+ Returns:
+        Dictionary with on-task and off-task KL statistics and their ratio
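+
+    Example (toy inputs, illustrative only):
+
+        >>> res = analyze_kl_by_task(
+        ...     {"math": {"mean_kl": 2.0}, "code": {"mean_kl": 6.0}},
+        ...     on_task_names=["math"],
+        ...     off_task_names=["code"],
+        ... )
+        >>> res["off_to_on_ratio"]
+        3.0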
+ """
+ on_task_kl = []
+ off_task_kl = []
+
+ for name in on_task_names:
+ if name in kl_results:
+ on_task_kl.append(kl_results[name]["mean_kl"])
+
+ for name in off_task_names:
+ if name in kl_results:
+ off_task_kl.append(kl_results[name]["mean_kl"])
+
+ analysis = {
+ "on_task": {
+ "mean": float(np.mean(on_task_kl)) if on_task_kl else 0.0,
+ "std": float(np.std(on_task_kl)) if on_task_kl else 0.0,
+ "values": on_task_kl,
+ },
+ "off_task": {
+ "mean": float(np.mean(off_task_kl)) if off_task_kl else 0.0,
+ "std": float(np.std(off_task_kl)) if off_task_kl else 0.0,
+ "values": off_task_kl,
+ },
+ }
+
+ # Compute ratio
+ if analysis["on_task"]["mean"] > 0:
+ analysis["off_to_on_ratio"] = (
+ analysis["off_task"]["mean"] / analysis["on_task"]["mean"]
+ )
+ else:
+ analysis["off_to_on_ratio"] = float("inf")
+
+ return analysis
+
+
+# ============================================================================
+# KL Contribution Analysis
+# ============================================================================
+
+def analyze_kl_contribution_by_layer(
+ model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor
+) -> Dict[str, float]:
+ """
+    Placeholder for layer-wise KL attribution.
+
+    A full implementation would require probing intermediate hidden
+    states (e.g., forwarding with output_hidden_states=True and
+    projecting each layer's states through the unembedding), which
+    this module does not yet do. The arguments are accepted for
+    interface compatibility but are currently unused.
+    """
+
+ return {
+ "note": "Layer-wise KL attribution not implemented",
+ }
+
+
+def compute_kl_trajectory(
+ checkpoints: List[str],
+ ref_model_path: str,
+ tokenizer_path: str,
+ prompts: List[str],
+ responses: List[str],
+ device: torch.device
+) -> List[Dict[str, Any]]:
+ """
+ Compute KL divergence trajectory over training checkpoints.
+
+ Useful for understanding how KL evolves during training.
+ """
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ # Load tokenizer
+ tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+ if tokenizer.pad_token is None:
+ tokenizer.pad_token = tokenizer.eos_token
+
+ # Load reference model
+ ref_model = AutoModelForCausalLM.from_pretrained(
+ ref_model_path,
+ torch_dtype=torch.bfloat16,
+ device_map=None
+ ).to(device)
+ ref_model.eval()
+
+ trajectory = []
+
+ for ckpt_path in tqdm(checkpoints, desc="Computing KL trajectory"):
+ # Load checkpoint
+ policy_model = AutoModelForCausalLM.from_pretrained(
+ ckpt_path,
+ torch_dtype=torch.bfloat16,
+ device_map=None
+ ).to(device)
+ policy_model.eval()
+
+ # Estimate KL
+ kl_result = estimate_dataset_kl(
+ policy_model=policy_model,
+ ref_model=ref_model,
+ tokenizer=tokenizer,
+ prompts=prompts,
+ responses=responses,
+ device=device,
+ show_progress=False
+ )
+
+ trajectory.append({
+ "checkpoint": ckpt_path,
+ "mean_kl": kl_result["mean_kl"],
+ "std_kl": kl_result["std_kl"],
+ })
+
+        # Free memory before loading the next checkpoint
+        del policy_model
+        if device.type == "cuda":
+            torch.cuda.empty_cache()
+
+ return trajectory
+
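+
+if __name__ == "__main__":
+    # Minimal smoke test (illustrative sketch only). One small public
+    # test checkpoint is used as both policy and reference, so the
+    # estimated dataset KL should be approximately zero. The model
+    # name is a placeholder; swap in real checkpoints for experiments.
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    logging.basicConfig(level=logging.INFO)
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    name = "sshleifer/tiny-gpt2"  # placeholder test checkpoint
+    tok = AutoTokenizer.from_pretrained(name)
+    model = AutoModelForCausalLM.from_pretrained(name).to(device)
+
+    stats = estimate_dataset_kl(
+        policy_model=model,
+        ref_model=model,
+        tokenizer=tok,
+        prompts=["Q: 2 + 2 = ?\nA:", "The capital of France is"],
+        responses=[" 4", " Paris."],
+        device=device,
+    )
+    logger.info("dataset mean KL (same weights, expect ~0): %.6f",
+                stats["mean_kl"])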