From f1c2cc22d46a6976df3555391e667c7e61592fad Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Wed, 4 Feb 2026 18:59:35 -0600
Subject: Initial commit: RL floating-point noise project

---
 utils_kl.py | 419 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 419 insertions(+)
 create mode 100644 utils_kl.py

diff --git a/utils_kl.py b/utils_kl.py
new file mode 100644
index 0000000..2be50a0
--- /dev/null
+++ b/utils_kl.py
@@ -0,0 +1,419 @@
+# utils_kl.py
+"""
+KL Divergence Utilities for RLVR Experiments.
+
+This module provides utilities for computing KL divergence between
+policy distributions, including:
+- Token-level KL computation
+- Sequence-level KL aggregation
+- Dataset-level KL estimation
+"""
+
+import torch
+import torch.nn.functional as F
+from typing import Dict, Any, List, Optional
+import numpy as np
+import logging
+from tqdm import tqdm
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Token-Level KL Computation
+# ============================================================================
+
+def compute_token_log_probs(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    labels: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+    """
+    Compute token-level log probabilities.
+
+    Args:
+        model: Language model
+        input_ids: Input token IDs [batch, seq_len]
+        attention_mask: Attention mask [batch, seq_len]
+        labels: Token labels for which to compute log probs (default: input_ids)
+
+    Returns:
+        Token log probabilities [batch, seq_len - 1]
+    """
+    if labels is None:
+        labels = input_ids
+
+    with torch.no_grad():
+        outputs = model(
+            input_ids=input_ids,
+            attention_mask=attention_mask,
+            use_cache=False
+        )
+
+    logits = outputs.logits  # [batch, seq_len, vocab]
+
+    # Shift for autoregressive prediction: token t is predicted from tokens 0..t-1
+    shift_logits = logits[:, :-1, :]  # [batch, seq_len-1, vocab]
+    shift_labels = labels[:, 1:]      # [batch, seq_len-1]
+
+    # Compute log probabilities
+    log_probs = F.log_softmax(shift_logits, dim=-1)
+
+    # Gather log probs for the actual tokens
+    token_log_probs = torch.gather(
+        log_probs,
+        dim=-1,
+        index=shift_labels.unsqueeze(-1)
+    ).squeeze(-1)  # [batch, seq_len-1]
+
+    return token_log_probs
+
+
+def compute_kl_per_token(
+    policy_log_probs: torch.Tensor,
+    ref_log_probs: torch.Tensor
+) -> torch.Tensor:
+    """
+    Compute a per-token KL divergence estimate.
+
+    KL(π || π_ref) at token t ≈ log π(y_t) - log π_ref(y_t)
+
+    Note: this is a single-sample Monte Carlo estimate of KL(π || π_ref),
+    evaluated only at the observed token; it is unbiased when y_t is sampled
+    from π, and individual values can be negative. For the exact per-token
+    KL over the full vocabulary, see compute_reverse_kl_per_token.
+    """
+    return policy_log_probs - ref_log_probs
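+
+
+# ----------------------------------------------------------------------------
+# Example (sketch): a minimal sanity check of the two helpers above. With the
+# same model used as both policy and reference, the per-token KL estimate must
+# be identically zero. The checkpoint "sshleifer/tiny-gpt2" is an assumption
+# chosen only because it is small; substitute any causal LM.
+# ----------------------------------------------------------------------------
+def _example_token_kl_sanity_check() -> None:
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    name = "sshleifer/tiny-gpt2"  # assumed tiny checkpoint, for illustration
+    tok = AutoTokenizer.from_pretrained(name)
+    model = AutoModelForCausalLM.from_pretrained(name)
+    model.eval()
+
+    batch = tok("The quick brown fox", return_tensors="pt")
+    log_probs = compute_token_log_probs(
+        model, batch["input_ids"], batch["attention_mask"]
+    )
+    # Same model on both sides => the KL estimate is exactly zero.
+    kl = compute_kl_per_token(log_probs, log_probs)
+    assert torch.allclose(kl, torch.zeros_like(kl))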
+
+
+def compute_reverse_kl_per_token(
+    policy_logits: torch.Tensor,
+    ref_logits: torch.Tensor,
+    temperature: float = 1.0
+) -> torch.Tensor:
+    """
+    Compute the exact per-token KL divergence from the full distributions.
+
+    KL(π || π_ref) = Σ_y π(y) [log π(y) - log π_ref(y)]
+
+    This is more expensive than the single-sample estimate (it materializes
+    the full vocabulary distribution) but gives the exact per-token KL. In
+    RLHF terminology, KL(π || π_ref) is often called the reverse KL with
+    respect to the reference model.
+    """
+    policy_probs = F.softmax(policy_logits / temperature, dim=-1)
+    policy_log_probs = F.log_softmax(policy_logits / temperature, dim=-1)
+    ref_log_probs = F.log_softmax(ref_logits / temperature, dim=-1)
+
+    # KL = Σ p(x) log(p(x)/q(x)) = Σ p(x) [log p(x) - log q(x)]
+    kl = (policy_probs * (policy_log_probs - ref_log_probs)).sum(dim=-1)
+
+    return kl
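+
+
+# ----------------------------------------------------------------------------
+# Example (sketch): the exact per-token KL above can be cross-checked against
+# torch's built-in F.kl_div on random logits. The shapes
+# [batch=2, seq=5, vocab=11] are arbitrary, chosen purely for illustration.
+# ----------------------------------------------------------------------------
+def _example_reverse_kl_check() -> None:
+    torch.manual_seed(0)
+    policy_logits = torch.randn(2, 5, 11)
+    ref_logits = torch.randn(2, 5, 11)
+
+    kl = compute_reverse_kl_per_token(policy_logits, ref_logits)
+
+    # F.kl_div(input=log q, target=log p, log_target=True) computes
+    # p * (log p - log q) pointwise; summing over the vocabulary gives
+    # KL(p || q), matching the formula above.
+    expected = F.kl_div(
+        F.log_softmax(ref_logits, dim=-1),
+        F.log_softmax(policy_logits, dim=-1),
+        log_target=True,
+        reduction="none",
+    ).sum(dim=-1)
+    assert torch.allclose(kl, expected, atol=1e-5)
+    assert (kl >= 0).all()  # the exact KL is non-negative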
+
+
+# ============================================================================
+# Sequence-Level KL
+# ============================================================================
+
+def compute_sequence_kl(
+    policy_model: torch.nn.Module,
+    ref_model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor,
+    response_start_idx: int = 0,
+    normalize_by_length: bool = False
+) -> Dict[str, float]:
+    """
+    Compute KL divergence for a single sequence.
+
+    Args:
+        policy_model: Finetuned policy model
+        ref_model: Reference model
+        input_ids: Full sequence (prompt + response) [1, seq_len]
+        attention_mask: Attention mask [1, seq_len]
+        response_start_idx: Index where the response starts
+        normalize_by_length: If True, report the average KL per token
+
+    Returns:
+        Dictionary with KL metrics
+    """
+    # Get log probs from both models
+    policy_log_probs = compute_token_log_probs(
+        policy_model, input_ids, attention_mask
+    )
+    ref_log_probs = compute_token_log_probs(
+        ref_model, input_ids, attention_mask
+    )
+
+    # Compute per-token KL
+    kl_per_token = compute_kl_per_token(policy_log_probs, ref_log_probs)
+
+    # Create mask for response tokens only. Position t in kl_per_token
+    # corresponds to token t+1 in input_ids (autoregressive shift), so a
+    # response starting at input index response_start_idx maps to index
+    # response_start_idx - 1 here.
+    seq_len = kl_per_token.shape[1]
+    response_mask = torch.zeros(1, seq_len, device=input_ids.device)
+    if response_start_idx > 0:
+        response_mask[:, response_start_idx - 1:] = 1.0
+    else:
+        response_mask[:, :] = 1.0
+
+    # Apply attention mask
+    valid_mask = attention_mask[:, 1:].float() * response_mask
+
+    # Compute statistics. Per-token KL estimates can be negative, so max/min
+    # must index the valid positions directly; multiplying by the mask would
+    # let masked-out zeros dominate negative values.
+    masked_kl = kl_per_token * valid_mask
+    num_tokens = valid_mask.sum().item()
+    total_kl = masked_kl.sum().item()
+    bool_mask = valid_mask.bool()
+
+    result = {
+        "total_kl": total_kl,
+        "num_tokens": int(num_tokens),
+        "mean_kl": total_kl / num_tokens if num_tokens > 0 else 0.0,
+        "max_kl": kl_per_token[bool_mask].max().item() if num_tokens > 0 else 0.0,
+        "min_kl": kl_per_token[bool_mask].min().item() if num_tokens > 0 else 0.0,
+    }
+
+    if normalize_by_length:
+        result["kl"] = result["mean_kl"]
+    else:
+        result["kl"] = result["total_kl"]
+
+    return result
+
+
+# ============================================================================
+# Dataset-Level KL Estimation
+# ============================================================================
+
+def estimate_dataset_kl(
+    policy_model: torch.nn.Module,
+    ref_model: torch.nn.Module,
+    tokenizer,
+    prompts: List[str],
+    responses: List[str],
+    device: torch.device,
+    max_seq_len: int = 4096,
+    normalize_by_length: bool = False,
+    show_progress: bool = True
+) -> Dict[str, Any]:
+    """
+    Estimate KL divergence over a dataset.
+
+    Args:
+        policy_model: Finetuned policy model
+        ref_model: Reference model
+        tokenizer: Tokenizer shared by both models
+        prompts: List of prompts
+        responses: List of corresponding responses
+        device: Device to use
+        max_seq_len: Maximum sequence length
+        normalize_by_length: If True, use mean KL per token
+        show_progress: Show a progress bar
+
+    Returns:
+        Dictionary with dataset-level KL statistics
+    """
+    assert len(prompts) == len(responses), \
+        "Number of prompts must match number of responses"
+
+    policy_model.eval()
+    ref_model.eval()
+
+    all_kl_values: List[float] = []
+    all_num_tokens: List[int] = []
+
+    iterator = zip(prompts, responses)
+    if show_progress:
+        iterator = tqdm(
+            list(iterator),
+            desc="Computing KL"
+        )
+
+    for prompt, response in iterator:
+        # Tokenize the prompt alone to locate where the response starts.
+        # NOTE: this assumes the prompt tokenization is a prefix of the
+        # full-sequence tokenization; BPE merges across the boundary can
+        # shift prompt_len by a token.
+        prompt_tokens = tokenizer(
+            prompt,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_seq_len // 2
+        )
+        prompt_len = prompt_tokens["input_ids"].shape[1]
+
+        # Tokenize the full sequence
+        full_text = prompt + response
+        full_tokens = tokenizer(
+            full_text,
+            return_tensors="pt",
+            truncation=True,
+            max_length=max_seq_len
+        )
+
+        input_ids = full_tokens["input_ids"].to(device)
+        attention_mask = full_tokens["attention_mask"].to(device)
+
+        # Compute sequence KL
+        with torch.no_grad():
+            kl_result = compute_sequence_kl(
+                policy_model=policy_model,
+                ref_model=ref_model,
+                input_ids=input_ids,
+                attention_mask=attention_mask,
+                response_start_idx=prompt_len,
+                normalize_by_length=normalize_by_length
+            )
+
+        all_kl_values.append(kl_result["kl"])
+        all_num_tokens.append(kl_result["num_tokens"])
+
+    # Aggregate statistics
+    kl_array = np.array(all_kl_values)
+
+    result = {
+        "mean_kl": float(np.mean(kl_array)),
+        "std_kl": float(np.std(kl_array)),
+        "median_kl": float(np.median(kl_array)),
+        "min_kl": float(np.min(kl_array)),
+        "max_kl": float(np.max(kl_array)),
+        "total_samples": len(prompts),
+        "total_tokens": sum(all_num_tokens),
+        "kl_values": all_kl_values,
+    }
+
+    return result
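+
+
+# ----------------------------------------------------------------------------
+# Example (sketch): wiring estimate_dataset_kl together end to end. The model
+# name and the two-sample "dataset" are placeholders for illustration; in the
+# experiments, the policy would be an RL-finetuned checkpoint and the
+# reference its pretrained base. With the same weights on both sides, the
+# mean KL should be zero up to floating-point noise.
+# ----------------------------------------------------------------------------
+def _example_dataset_kl() -> None:
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    name = "sshleifer/tiny-gpt2"  # placeholder: same weights on both sides
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+    tokenizer = AutoTokenizer.from_pretrained(name)
+    policy_model = AutoModelForCausalLM.from_pretrained(name).to(device)
+    ref_model = AutoModelForCausalLM.from_pretrained(name).to(device)
+
+    prompts = ["Q: What is 2 + 2?\nA:", "Q: Name a prime number.\nA:"]
+    responses = [" 4", " 7"]
+
+    stats = estimate_dataset_kl(
+        policy_model=policy_model,
+        ref_model=ref_model,
+        tokenizer=tokenizer,
+        prompts=prompts,
+        responses=responses,
+        device=device,
+        show_progress=False,
+    )
+    logger.info("mean KL: %.6f over %d tokens",
+                stats["mean_kl"], stats["total_tokens"])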
+
+
+# ============================================================================
+# On-Task vs Off-Task KL Analysis
+# ============================================================================
+
+def analyze_kl_by_task(
+    kl_results: Dict[str, Dict[str, Any]],
+    on_task_names: List[str],
+    off_task_names: List[str]
+) -> Dict[str, Any]:
+    """
+    Analyze KL divergence patterns for on-task vs. off-task data.
+
+    Args:
+        kl_results: Dictionary mapping task names to KL results
+        on_task_names: List of on-task (training distribution) names
+        off_task_names: List of off-task names
+
+    Returns:
+        Analysis of KL patterns, including the off-task to on-task ratio
+    """
+    on_task_kl = []
+    off_task_kl = []
+
+    for name in on_task_names:
+        if name in kl_results:
+            on_task_kl.append(kl_results[name]["mean_kl"])
+
+    for name in off_task_names:
+        if name in kl_results:
+            off_task_kl.append(kl_results[name]["mean_kl"])
+
+    analysis = {
+        "on_task": {
+            "mean": float(np.mean(on_task_kl)) if on_task_kl else 0.0,
+            "std": float(np.std(on_task_kl)) if on_task_kl else 0.0,
+            "values": on_task_kl,
+        },
+        "off_task": {
+            "mean": float(np.mean(off_task_kl)) if off_task_kl else 0.0,
+            "std": float(np.std(off_task_kl)) if off_task_kl else 0.0,
+            "values": off_task_kl,
+        },
+    }
+
+    # Ratio > 1 means the policy drifted further from the reference on
+    # off-task data than on the training distribution.
+    if analysis["on_task"]["mean"] > 0:
+        analysis["off_to_on_ratio"] = (
+            analysis["off_task"]["mean"] / analysis["on_task"]["mean"]
+        )
+    else:
+        analysis["off_to_on_ratio"] = float("inf")
+
+    return analysis
+
+
+# ============================================================================
+# KL Contribution Analysis
+# ============================================================================
+
+def analyze_kl_contribution_by_layer(
+    model: torch.nn.Module,
+    input_ids: torch.Tensor,
+    attention_mask: torch.Tensor
+) -> Dict[str, float]:
+    """
+    Analyze which layers contribute most to the final prediction.
+
+    This is a placeholder: full KL attribution would require layer-wise
+    probing, i.e. modifying the model to expose intermediate
+    representations.
+    """
+    return {
+        "note": "Layer-wise KL attribution not implemented",
+    }
+
+
+def compute_kl_trajectory(
+    checkpoints: List[str],
+    ref_model_path: str,
+    tokenizer_path: str,
+    prompts: List[str],
+    responses: List[str],
+    device: torch.device
+) -> List[Dict[str, Any]]:
+    """
+    Compute the KL divergence trajectory over training checkpoints.
+
+    Useful for understanding how KL evolves during training.
+    """
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    # Load tokenizer
+    tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, use_fast=True)
+    if tokenizer.pad_token is None:
+        tokenizer.pad_token = tokenizer.eos_token
+
+    # Load the reference model once; it is shared across all checkpoints
+    ref_model = AutoModelForCausalLM.from_pretrained(
+        ref_model_path,
+        torch_dtype=torch.bfloat16,
+        device_map=None
+    ).to(device)
+    ref_model.eval()
+
+    trajectory = []
+
+    for ckpt_path in tqdm(checkpoints, desc="Computing KL trajectory"):
+        # Load checkpoint
+        policy_model = AutoModelForCausalLM.from_pretrained(
+            ckpt_path,
+            torch_dtype=torch.bfloat16,
+            device_map=None
+        ).to(device)
+        policy_model.eval()
+
+        # Estimate KL
+        kl_result = estimate_dataset_kl(
+            policy_model=policy_model,
+            ref_model=ref_model,
+            tokenizer=tokenizer,
+            prompts=prompts,
+            responses=responses,
+            device=device,
+            show_progress=False
+        )
+
+        trajectory.append({
+            "checkpoint": ckpt_path,
+            "mean_kl": kl_result["mean_kl"],
+            "std_kl": kl_result["std_kl"],
+        })
+
+        # Free memory before loading the next checkpoint
+        del policy_model
+        torch.cuda.empty_cache()
+
+    return trajectory
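+
+
+# ----------------------------------------------------------------------------
+# Example (sketch): tracking KL growth over a training run. The checkpoint
+# directories and the base-model path are hypothetical placeholders; the
+# one-sample dataset is for illustration only. Substitute the actual save
+# paths and evaluation prompts from the run.
+# ----------------------------------------------------------------------------
+def _example_kl_trajectory() -> None:
+    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+    checkpoints = [  # hypothetical paths, one per saved training step
+        "runs/grpo/checkpoint-100",
+        "runs/grpo/checkpoint-200",
+        "runs/grpo/checkpoint-300",
+    ]
+    trajectory = compute_kl_trajectory(
+        checkpoints=checkpoints,
+        ref_model_path="Qwen/Qwen2.5-0.5B",  # assumed base model
+        tokenizer_path="Qwen/Qwen2.5-0.5B",
+        prompts=["Q: What is 2 + 2?\nA:"],
+        responses=[" 4"],
+        device=device,
+    )
+    for point in trajectory:
+        logger.info("%s: mean KL %.4f (std %.4f)",
+                    point["checkpoint"], point["mean_kl"], point["std_kl"])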