Diffstat (limited to 'eval_policy.py')
| -rw-r--r-- | eval_policy.py | 621 |
1 files changed, 621 insertions, 0 deletions
diff --git a/eval_policy.py b/eval_policy.py
new file mode 100644
index 0000000..cc30209
--- /dev/null
+++ b/eval_policy.py
@@ -0,0 +1,621 @@
#!/usr/bin/env python3
# eval_policy.py
"""
Policy Evaluation Script for RLVR Experiments.

This script evaluates trained models on multiple tasks, computing:
- J_k: Task performance (pass@1 accuracy for verifiable tasks)
- KL_k: KL divergence from base model

Usage:
    python eval_policy.py \
        --base_ckpt Qwen/Qwen2.5-Math-7B \
        --ft_ckpt results/train_logs/fp32_seed1/final_model \
        --eval_tasks_config configs/eval_tasks_config.json \
        --output_path results/eval_metrics/fp32_seed1.json
"""

import argparse
import json
import os
import logging
from typing import Dict, Any, List, Optional, Tuple
from dataclasses import dataclass, asdict

import numpy as np
import torch
from torch.cuda.amp import autocast
from transformers import AutoModelForCausalLM, AutoTokenizer
from tqdm import tqdm

from config import EvalTaskConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


# ============================================================================
# Data Loading
# ============================================================================

def load_eval_tasks(eval_config_path: str) -> List[EvalTaskConfig]:
    """Load evaluation task configurations from JSON file."""
    with open(eval_config_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    tasks: List[EvalTaskConfig] = []
    for task_item in data:
        task = EvalTaskConfig(
            name=task_item.get("name", ""),
            task_type=task_item.get("task_type", "math"),
            dataset_path=task_item.get("dataset_path", ""),
            is_verifiable=task_item.get("is_verifiable", True),
            metric_type=task_item.get("metric_type", "accuracy"),
            num_samples=task_item.get("num_samples", -1),
            max_gen_len=task_item.get("max_gen_len", 2048),
            temperature=task_item.get("temperature", 0.7),
            top_p=task_item.get("top_p", 0.8),
            num_samples_per_prompt=task_item.get("num_samples_per_prompt", 1)
        )
        tasks.append(task)

    logger.info(f"Loaded {len(tasks)} evaluation tasks from {eval_config_path}")
    return tasks
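
# NOTE (editorial, illustrative only; not part of the original script):
# load_eval_tasks() only assumes that each entry in the JSON list exposes the
# fields read above, so a single entry in configs/eval_tasks_config.json is
# expected to look roughly like the following (task name and dataset path are
# hypothetical):
#
#     [
#       {
#         "name": "math_eval",
#         "task_type": "math",
#         "dataset_path": "data/eval/math_eval.json",
#         "is_verifiable": true,
#         "metric_type": "accuracy",
#         "num_samples": -1,
#         "max_gen_len": 2048,
#         "temperature": 0.7,
#         "top_p": 0.8,
#         "num_samples_per_prompt": 1
#       }
#     ]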
is[:\s]+)?(.+?)(?:\.|$)", + r"[Hh]ence[,:\s]+(.+?)(?:\.|$)", + r"=\s*(.+?)$", + ] + + import re + for pattern in patterns: + match = re.search(pattern, text, re.MULTILINE) + if match: + return match.group(1).strip() + + return None + + +def normalize_answer(answer: str) -> str: + """Normalize answer for comparison.""" + if answer is None: + return "" + + # Convert to lowercase, remove whitespace + normalized = answer.lower().strip() + + # Remove common formatting + normalized = normalized.replace(",", "") + normalized = normalized.replace("$", "") + normalized = normalized.replace("%", "") + + # Try to extract numeric value + import re + numeric_match = re.search(r"-?\d+\.?\d*", normalized) + if numeric_match: + return numeric_match.group() + + return normalized + + +def verify_math_answer( + response: str, + ground_truth: str +) -> bool: + """ + Verify if the response contains the correct answer. + + This is a simplified verifier. For production use, replace with + Eval-Chemy or a more sophisticated verification system. + """ + # Extract answers + predicted = extract_final_answer(response) + + if predicted is None: + return False + + # Normalize for comparison + pred_normalized = normalize_answer(predicted) + gt_normalized = normalize_answer(ground_truth) + + # Direct comparison + if pred_normalized == gt_normalized: + return True + + # Try numeric comparison + try: + pred_num = float(pred_normalized) + gt_num = float(gt_normalized) + if abs(pred_num - gt_num) < 1e-6: + return True + except ValueError: + pass + + return False + + +# ============================================================================ +# KL Divergence Computation +# ============================================================================ + +def compute_sequence_kl( + finetuned_model: torch.nn.Module, + base_model: torch.nn.Module, + input_ids: torch.Tensor, + attention_mask: torch.Tensor, + response_start_idx: int, + device: torch.device +) -> Tuple[float, int]: + """ + Compute KL divergence for a single sequence. 


# ============================================================================
# KL Divergence Computation
# ============================================================================

def compute_sequence_kl(
    finetuned_model: torch.nn.Module,
    base_model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    response_start_idx: int,
    device: torch.device
) -> Tuple[float, int]:
    """
    Compute KL divergence for a single sequence.

    KL(π_ft || π_base) ≈ Σ_t [log π_ft(y_t|x,y_{<t}) - log π_base(y_t|x,y_{<t})]

    Returns:
        Tuple of (kl_sum, num_tokens)
    """
    with torch.no_grad():
        # Get logits from both models
        ft_outputs = finetuned_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        base_outputs = base_model(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        ft_logits = ft_outputs.logits
        base_logits = base_outputs.logits

        # Compute log probabilities
        ft_log_probs = torch.log_softmax(ft_logits, dim=-1)
        base_log_probs = torch.log_softmax(base_logits, dim=-1)

        # Get log probs for actual tokens (shifted for autoregressive)
        shift_ft_log_probs = ft_log_probs[:, :-1, :]
        shift_base_log_probs = base_log_probs[:, :-1, :]
        shift_labels = input_ids[:, 1:]

        ft_token_log_probs = torch.gather(
            shift_ft_log_probs,
            dim=-1,
            index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)

        base_token_log_probs = torch.gather(
            shift_base_log_probs,
            dim=-1,
            index=shift_labels.unsqueeze(-1)
        ).squeeze(-1)

        # Compute KL only for response tokens
        kl_per_token = ft_token_log_probs - base_token_log_probs

        # Mask for response tokens only
        response_mask = torch.zeros_like(kl_per_token)
        response_mask[:, response_start_idx-1:] = 1.0

        # Apply attention mask
        valid_mask = attention_mask[:, 1:].float() * response_mask

        kl_sum = (kl_per_token * valid_mask).sum().item()
        num_tokens = valid_mask.sum().item()

    return kl_sum, int(num_tokens)
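
# NOTE (editorial): the value returned above is a single-sample Monte Carlo
# estimate of the sequence-level KL. Because the response y is sampled from
# the finetuned policy, the summed log-ratio
#     sum_t [ log pi_ft(y_t | x, y_<t) - log pi_base(y_t | x, y_<t) ]
# equals KL(pi_ft || pi_base) only in expectation, so individual estimates can
# be negative even though the true KL is non-negative. evaluate_task() below
# averages these per-sequence sums over all evaluation prompts, which reduces
# the variance of the reported KL_k.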


# ============================================================================
# Evaluation Functions
# ============================================================================

@dataclass
class TaskResult:
    """Results for a single evaluation task."""
    task_name: str
    task_type: str
    num_examples: int
    avg_score: float
    std_score: float
    avg_kl: float
    std_kl: float
    avg_response_length: float
    scores: List[float]
    kl_values: List[float]


def evaluate_task(
    base_model: torch.nn.Module,
    base_tokenizer,
    finetuned_model: torch.nn.Module,
    finetuned_tokenizer,
    task_config: EvalTaskConfig,
    device: torch.device,
    use_amp: bool = True
) -> TaskResult:
    """
    Evaluate a single task.

    Computes:
    - avg_score: Mean accuracy (for verifiable tasks)
    - avg_kl: Mean KL divergence from base model
    """
    dataset = load_dataset(task_config.dataset_path, task_config.num_samples)

    scores: List[float] = []
    kl_values: List[float] = []
    response_lengths: List[int] = []

    finetuned_model.eval()
    base_model.eval()

    amp_dtype = torch.bfloat16 if use_amp else torch.float32

    for example in tqdm(dataset, desc=f"Evaluating {task_config.name}"):
        prompt = example.get("prompt", example.get("question", ""))
        ground_truth = example.get("answer", example.get("solution", None))

        # Tokenize prompt
        inputs = finetuned_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        prompt_len = input_ids.shape[1]

        # Generate response
        with torch.no_grad():
            with autocast(enabled=use_amp, dtype=amp_dtype):
                generated_ids = finetuned_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=task_config.max_gen_len,
                    do_sample=True,
                    temperature=task_config.temperature,
                    top_p=task_config.top_p,
                    pad_token_id=finetuned_tokenizer.eos_token_id
                )

        # Decode response
        response_ids = generated_ids[:, prompt_len:]
        response_text = finetuned_tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=True
        )[0]

        response_lengths.append(len(response_ids[0]))

        # Compute score (accuracy for verifiable tasks)
        if task_config.is_verifiable and ground_truth is not None:
            is_correct = verify_math_answer(response_text, str(ground_truth))
            score = 1.0 if is_correct else 0.0
        else:
            # For non-verifiable tasks, use placeholder
            score = 0.0

        scores.append(score)

        # Compute KL divergence
        full_ids = generated_ids
        full_attention = torch.ones_like(full_ids, device=device)

        kl_sum, num_tokens = compute_sequence_kl(
            finetuned_model=finetuned_model,
            base_model=base_model,
            input_ids=full_ids,
            attention_mask=full_attention,
            response_start_idx=prompt_len,
            device=device
        )

        if num_tokens > 0:
            avg_kl_per_token = kl_sum / num_tokens
        else:
            avg_kl_per_token = 0.0

        kl_values.append(kl_sum)  # Total KL for sequence

    # Compute statistics
    result = TaskResult(
        task_name=task_config.name,
        task_type=task_config.task_type,
        num_examples=len(dataset),
        avg_score=float(np.mean(scores)) if scores else 0.0,
        std_score=float(np.std(scores)) if scores else 0.0,
        avg_kl=float(np.mean(kl_values)) if kl_values else 0.0,
        std_kl=float(np.std(kl_values)) if kl_values else 0.0,
        avg_response_length=float(np.mean(response_lengths)) if response_lengths else 0.0,
        scores=scores,
        kl_values=kl_values
    )

    logger.info(
        f"Task {task_config.name}: "
        f"Score={result.avg_score:.4f} (±{result.std_score:.4f}), "
        f"KL={result.avg_kl:.4f} (±{result.std_kl:.4f})"
    )

    return result
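
# NOTE (editorial): evaluate_task() scores exactly one sampled completion per
# prompt, i.e. pass@1 under temperature sampling; the num_samples_per_prompt
# field read in load_eval_tasks() is not consumed here. Supporting k > 1
# samples per prompt would require an additional generation loop.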


def evaluate_base_model(
    base_model: torch.nn.Module,
    base_tokenizer,
    task_config: EvalTaskConfig,
    device: torch.device,
    use_amp: bool = True
) -> Dict[str, float]:
    """Evaluate the base model (for computing ΔJ)."""
    dataset = load_dataset(task_config.dataset_path, task_config.num_samples)

    scores: List[float] = []
    base_model.eval()

    amp_dtype = torch.bfloat16 if use_amp else torch.float32

    for example in tqdm(dataset, desc=f"Evaluating base on {task_config.name}"):
        prompt = example.get("prompt", example.get("question", ""))
        ground_truth = example.get("answer", example.get("solution", None))

        inputs = base_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            with autocast(enabled=use_amp, dtype=amp_dtype):
                generated_ids = base_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=task_config.max_gen_len,
                    do_sample=True,
                    temperature=task_config.temperature,
                    top_p=task_config.top_p,
                    pad_token_id=base_tokenizer.eos_token_id
                )

        response_ids = generated_ids[:, input_ids.shape[1]:]
        response_text = base_tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=True
        )[0]

        if task_config.is_verifiable and ground_truth is not None:
            is_correct = verify_math_answer(response_text, str(ground_truth))
            score = 1.0 if is_correct else 0.0
        else:
            score = 0.0

        scores.append(score)

    return {
        "avg_score": float(np.mean(scores)) if scores else 0.0,
        "std_score": float(np.std(scores)) if scores else 0.0,
        "num_examples": len(scores)
    }


# ============================================================================
# Main Evaluation Pipeline
# ============================================================================

def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Evaluate RLVR trained models on multiple tasks"
    )
    parser.add_argument(
        "--base_ckpt",
        type=str,
        required=True,
        help="Path to base model checkpoint"
    )
    parser.add_argument(
        "--ft_ckpt",
        type=str,
        required=True,
        help="Path to finetuned model checkpoint"
    )
    parser.add_argument(
        "--eval_tasks_config",
        type=str,
        required=True,
        help="Path to evaluation tasks configuration JSON"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save evaluation results"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for evaluation"
    )
    parser.add_argument(
        "--eval_base",
        action="store_true",
        help="Also evaluate base model (for computing delta J)"
    )
    parser.add_argument(
        "--use_amp",
        action="store_true",
        default=True,
        help="Use automatic mixed precision"
    )
    return parser.parse_args()


def main() -> None:
    """Main evaluation function."""
    args = parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load tokenizers
    logger.info(f"Loading base tokenizer from {args.base_ckpt}")
    base_tokenizer = AutoTokenizer.from_pretrained(
        args.base_ckpt,
        use_fast=True,
        trust_remote_code=True
    )
    if base_tokenizer.pad_token is None:
        base_tokenizer.pad_token = base_tokenizer.eos_token

    logger.info(f"Loading finetuned tokenizer from {args.ft_ckpt}")
    ft_tokenizer = AutoTokenizer.from_pretrained(
        args.ft_ckpt,
        use_fast=True,
        trust_remote_code=True
    )
    if ft_tokenizer.pad_token is None:
        ft_tokenizer.pad_token = ft_tokenizer.eos_token

    # Load models
    logger.info(f"Loading base model from {args.base_ckpt}")
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_ckpt,
        torch_dtype=torch.bfloat16,
        device_map=None,
        trust_remote_code=True
    ).to(device)
    base_model.eval()

    logger.info(f"Loading finetuned model from {args.ft_ckpt}")
    ft_model = AutoModelForCausalLM.from_pretrained(
        args.ft_ckpt,
        torch_dtype=torch.bfloat16,
        device_map=None,
        trust_remote_code=True
    ).to(device)
    ft_model.eval()

    # Load evaluation tasks
    eval_tasks = load_eval_tasks(args.eval_tasks_config)

    # Evaluate on all tasks
    all_results: Dict[str, Any] = {
        "base_ckpt": args.base_ckpt,
        "ft_ckpt": args.ft_ckpt,
        "tasks": {}
    }

    for task in eval_tasks:
        logger.info(f"\n{'='*60}")
        logger.info(f"Evaluating task: {task.name}")
        logger.info(f"{'='*60}")

        # Evaluate finetuned model
        result = evaluate_task(
            base_model=base_model,
            base_tokenizer=base_tokenizer,
            finetuned_model=ft_model,
            finetuned_tokenizer=ft_tokenizer,
            task_config=task,
            device=device,
            use_amp=args.use_amp
        )

        task_results = {
            "ft_avg_score": result.avg_score,
            "ft_std_score": result.std_score,
            "avg_kl": result.avg_kl,
            "std_kl": result.std_kl,
            "avg_response_length": result.avg_response_length,
            "num_examples": result.num_examples,
        }

        # Optionally evaluate base model
        if args.eval_base:
            base_result = evaluate_base_model(
                base_model=base_model,
                base_tokenizer=base_tokenizer,
                task_config=task,
                device=device,
                use_amp=args.use_amp
            )
            task_results["base_avg_score"] = base_result["avg_score"]
            task_results["base_std_score"] = base_result["std_score"]
            task_results["delta_j"] = result.avg_score - base_result["avg_score"]

        all_results["tasks"][task.name] = task_results

    # Save results (guard against an output path with no directory component)
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)

    logger.info(f"\nResults saved to {args.output_path}")

    # Print summary
    print("\n" + "="*80)
    print("EVALUATION SUMMARY")
    print("="*80)
    for task_name, task_result in all_results["tasks"].items():
        print(f"\n{task_name}:")
        print(f"  Score: {task_result['ft_avg_score']:.4f} (±{task_result['ft_std_score']:.4f})")
        print(f"  KL: {task_result['avg_kl']:.4f} (±{task_result['std_kl']:.4f})")
        if "delta_j" in task_result:
            print(f"  ΔJ: {task_result['delta_j']:+.4f}")
    print("="*80)


if __name__ == "__main__":
    main()
