#!/usr/bin/env python3
# eval_policy.py
"""
Policy Evaluation Script for RLVR Experiments.

This script evaluates trained models on multiple tasks, computing:
- J_k: Task performance (pass@1 accuracy for verifiable tasks)
- KL_k: KL divergence from base model

Usage:
    python eval_policy.py \
        --base_ckpt Qwen/Qwen2.5-Math-7B \
        --ft_ckpt results/train_logs/fp32_seed1/final_model \
        --eval_tasks_config configs/eval_tasks_config.json \
        --output_path results/eval_metrics/fp32_seed1.json
"""

import argparse
import json
import logging
import os
import re
from dataclasses import dataclass, asdict
from typing import Any, Dict, List, Optional, Tuple

import numpy as np
import torch
from torch.cuda.amp import autocast
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

from config import EvalTaskConfig

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)


# ============================================================================
# Data Loading
# ============================================================================

def load_eval_tasks(eval_config_path: str) -> List[EvalTaskConfig]:
    """Load evaluation task configurations from JSON file."""
    with open(eval_config_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    tasks: List[EvalTaskConfig] = []
    for task_item in data:
        task = EvalTaskConfig(
            name=task_item.get("name", ""),
            task_type=task_item.get("task_type", "math"),
            dataset_path=task_item.get("dataset_path", ""),
            is_verifiable=task_item.get("is_verifiable", True),
            metric_type=task_item.get("metric_type", "accuracy"),
            num_samples=task_item.get("num_samples", -1),
            max_gen_len=task_item.get("max_gen_len", 2048),
            temperature=task_item.get("temperature", 0.7),
            top_p=task_item.get("top_p", 0.8),
            num_samples_per_prompt=task_item.get("num_samples_per_prompt", 1)
        )
        tasks.append(task)

    logger.info(f"Loaded {len(tasks)} evaluation tasks from {eval_config_path}")
    return tasks


def load_dataset(dataset_path: str, num_samples: int = -1) -> List[Dict[str, Any]]:
    """Load evaluation dataset from JSON file."""
    with open(dataset_path, "r", encoding="utf-8") as f:
        data = json.load(f)

    if num_samples > 0 and num_samples < len(data):
        data = data[:num_samples]

    logger.info(f"Loaded {len(data)} examples from {dataset_path}")
    return data
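

# Illustrative schema sketch, not shipped defaults: the --eval_tasks_config file
# is a JSON list of objects whose keys mirror the EvalTaskConfig fields read in
# load_eval_tasks(), and each dataset file is a JSON list of records carrying a
# "prompt"/"question" and an "answer"/"solution" field, as consumed below in
# evaluate_task(). The concrete values ("math500", the dataset path, the toy
# question) are placeholders invented for this example.
_EXAMPLE_EVAL_TASK: Dict[str, Any] = {
    "name": "math500",
    "task_type": "math",
    "dataset_path": "data/eval/math500.json",
    "is_verifiable": True,
    "metric_type": "accuracy",
    "num_samples": -1,
    "max_gen_len": 2048,
    "temperature": 0.7,
    "top_p": 0.8,
    "num_samples_per_prompt": 1,
}

_EXAMPLE_DATASET_RECORD: Dict[str, Any] = {
    "question": "What is 2 + 2?",
    "answer": "4",
}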


# ============================================================================
# Answer Verification
# ============================================================================

def extract_boxed_answer(text: str) -> Optional[str]:
    """Extract answer from \\boxed{} format."""
    # Find all \boxed{...} patterns and keep the last one
    pattern = r"\\boxed\{([^}]*)\}"
    matches = re.findall(pattern, text)
    if matches:
        return matches[-1].strip()  # Return last match
    return None


def extract_final_answer(text: str) -> Optional[str]:
    """Extract final answer using various heuristics."""
    # Try boxed format first
    boxed = extract_boxed_answer(text)
    if boxed:
        return boxed

    # Common answer patterns, tried in order
    patterns = [
        r"[Tt]he (?:final )?answer is[:\s]+(.+?)(?:\.|$)",
        r"[Tt]herefore[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
        r"[Ss]o[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
        r"[Hh]ence[,:\s]+(.+?)(?:\.|$)",
        r"=\s*(.+?)$",
    ]
    for pattern in patterns:
        match = re.search(pattern, text, re.MULTILINE)
        if match:
            return match.group(1).strip()

    return None


def normalize_answer(answer: str) -> str:
    """Normalize answer for comparison."""
    if answer is None:
        return ""

    # Convert to lowercase, remove whitespace
    normalized = answer.lower().strip()

    # Remove common formatting
    normalized = normalized.replace(",", "")
    normalized = normalized.replace("$", "")
    normalized = normalized.replace("%", "")

    # Try to extract numeric value
    numeric_match = re.search(r"-?\d+\.?\d*", normalized)
    if numeric_match:
        return numeric_match.group()

    return normalized


def verify_math_answer(response: str, ground_truth: str) -> bool:
    """
    Verify if the response contains the correct answer.

    This is a simplified verifier. For production use, replace with
    Eval-Chemy or a more sophisticated verification system.
    """
    # Extract answers
    predicted = extract_final_answer(response)
    if predicted is None:
        return False

    # Normalize for comparison
    pred_normalized = normalize_answer(predicted)
    gt_normalized = normalize_answer(ground_truth)

    # Direct comparison
    if pred_normalized == gt_normalized:
        return True

    # Try numeric comparison with a small tolerance
    try:
        pred_num = float(pred_normalized)
        gt_num = float(gt_normalized)
        if abs(pred_num - gt_num) < 1e-6:
            return True
    except ValueError:
        pass

    return False
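

# A minimal sanity check for the regex heuristics above. The sample strings are
# assumptions chosen only for illustration, and this helper is never called by
# the evaluation pipeline; invoke it manually (e.g. from a REPL) if you want to
# confirm the extraction behavior.
def _demo_answer_extraction() -> None:
    boxed = r"Adding the parts gives a total of \boxed{1,234}."
    assert extract_boxed_answer(boxed) == "1,234"
    assert normalize_answer(extract_boxed_answer(boxed)) == "1234"
    assert verify_math_answer(boxed, "1234")

    unboxed = "Therefore, the answer is 42."
    assert extract_final_answer(unboxed) == "42"
    assert verify_math_answer(unboxed, "42.0")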


# ============================================================================
# KL Divergence Computation
# ============================================================================

def compute_sequence_kl(
    finetuned_model: torch.nn.Module,
    base_model: torch.nn.Module,
    input_ids: torch.Tensor,
    attention_mask: torch.Tensor,
    response_start_idx: int,
    device: torch.device
) -> Tuple[float, int]:
    """
    Compute KL divergence for a single sequence.

    KL(π_ft || π_base) ≈ Σ_t [log π_ft(y_t | x, y_{<t}) - log π_base(y_t | x, y_{<t})],
    summed over the response tokens y_t (a single-sample estimate that scores the
    realized response rather than the full token distribution).

    Returns:
        (kl_sum, num_response_tokens)
    """
    input_ids = input_ids.to(device)
    attention_mask = attention_mask.to(device)

    with torch.no_grad():
        ft_logits = finetuned_model(input_ids=input_ids, attention_mask=attention_mask).logits
        base_logits = base_model(input_ids=input_ids, attention_mask=attention_mask).logits

    # Logits at position t predict the token at position t + 1, so the response
    # tokens (positions >= response_start_idx) are scored by the logits starting
    # at response_start_idx - 1. Restrict to that window before the softmax.
    start = max(response_start_idx - 1, 0)
    ft_logprobs = torch.log_softmax(ft_logits[:, start:-1, :].float(), dim=-1)
    base_logprobs = torch.log_softmax(base_logits[:, start:-1, :].float(), dim=-1)
    response_tokens = input_ids[:, start + 1:].unsqueeze(-1)

    # Per-token log-probability ratio on the realized response tokens
    ft_token_logprobs = ft_logprobs.gather(-1, response_tokens).squeeze(-1)
    base_token_logprobs = base_logprobs.gather(-1, response_tokens).squeeze(-1)
    kl_per_token = ft_token_logprobs - base_token_logprobs

    kl_sum = float(kl_per_token.sum().item())
    num_tokens = int(kl_per_token.shape[1])
    return kl_sum, num_tokens
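

# Toy illustration with made-up tensors, not part of the pipeline: on hand-built
# logits, the realized-token estimate used in compute_sequence_kl() reduces to
# summing log p_ft(token) - log p_base(token) over the sampled positions.
def _demo_kl_estimate() -> None:
    torch.manual_seed(0)
    ft_logits = torch.randn(1, 4, 8)      # (batch, positions, vocab)
    base_logits = torch.randn(1, 4, 8)
    tokens = torch.randint(0, 8, (1, 4))  # pretend these were sampled from π_ft

    ft_lp = torch.log_softmax(ft_logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)
    base_lp = torch.log_softmax(base_logits, dim=-1).gather(-1, tokens.unsqueeze(-1)).squeeze(-1)

    kl_estimate = (ft_lp - base_lp).sum().item()
    print(f"Toy sequence-level KL estimate: {kl_estimate:.4f}")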


# ============================================================================
# Task Evaluation
# ============================================================================

@dataclass
class TaskResult:
    """Aggregated evaluation results for a single task."""
    task_name: str
    task_type: str
    num_examples: int
    avg_score: float
    std_score: float
    avg_kl: float
    std_kl: float
    avg_response_length: float
    scores: List[float]
    kl_values: List[float]


def evaluate_task(
    base_model: torch.nn.Module,
    base_tokenizer,
    finetuned_model: torch.nn.Module,
    finetuned_tokenizer,
    task_config: EvalTaskConfig,
    device: torch.device,
    use_amp: bool = True
) -> TaskResult:
    """
    Evaluate a single task.

    Computes:
    - avg_score: Mean accuracy (for verifiable tasks)
    - avg_kl: Mean KL divergence from base model

    Assumes the base and finetuned models share a tokenizer/vocabulary, so the
    KL term is computed on the finetuned tokenization of each sequence.
    """
    dataset = load_dataset(task_config.dataset_path, task_config.num_samples)

    scores: List[float] = []
    kl_values: List[float] = []
    response_lengths: List[int] = []

    finetuned_model.eval()
    base_model.eval()
    amp_dtype = torch.bfloat16 if use_amp else torch.float32

    for example in tqdm(dataset, desc=f"Evaluating {task_config.name}"):
        prompt = example.get("prompt", example.get("question", ""))
        ground_truth = example.get("answer", example.get("solution", None))

        # Tokenize prompt
        inputs = finetuned_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)
        prompt_len = input_ids.shape[1]

        # Generate response
        with torch.no_grad():
            with autocast(enabled=use_amp, dtype=amp_dtype):
                generated_ids = finetuned_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=task_config.max_gen_len,
                    do_sample=True,
                    temperature=task_config.temperature,
                    top_p=task_config.top_p,
                    pad_token_id=finetuned_tokenizer.eos_token_id
                )

        # Decode response
        response_ids = generated_ids[:, prompt_len:]
        response_text = finetuned_tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=True
        )[0]
        response_lengths.append(len(response_ids[0]))

        # Compute score (accuracy for verifiable tasks)
        if task_config.is_verifiable and ground_truth is not None:
            is_correct = verify_math_answer(response_text, str(ground_truth))
            score = 1.0 if is_correct else 0.0
        else:
            # For non-verifiable tasks, use placeholder
            score = 0.0
        scores.append(score)

        # Compute KL divergence over the generated response
        full_ids = generated_ids
        full_attention = torch.ones_like(full_ids, device=device)
        kl_sum, num_tokens = compute_sequence_kl(
            finetuned_model=finetuned_model,
            base_model=base_model,
            input_ids=full_ids,
            attention_mask=full_attention,
            response_start_idx=prompt_len,
            device=device
        )
        if num_tokens > 0:
            avg_kl_per_token = kl_sum / num_tokens
        else:
            avg_kl_per_token = 0.0
        kl_values.append(kl_sum)  # Total KL for sequence

    # Compute statistics
    result = TaskResult(
        task_name=task_config.name,
        task_type=task_config.task_type,
        num_examples=len(dataset),
        avg_score=float(np.mean(scores)) if scores else 0.0,
        std_score=float(np.std(scores)) if scores else 0.0,
        avg_kl=float(np.mean(kl_values)) if kl_values else 0.0,
        std_kl=float(np.std(kl_values)) if kl_values else 0.0,
        avg_response_length=float(np.mean(response_lengths)) if response_lengths else 0.0,
        scores=scores,
        kl_values=kl_values
    )

    logger.info(
        f"Task {task_config.name}: "
        f"Score={result.avg_score:.4f} (±{result.std_score:.4f}), "
        f"KL={result.avg_kl:.4f} (±{result.std_kl:.4f})"
    )
    return result


def evaluate_base_model(
    base_model: torch.nn.Module,
    base_tokenizer,
    task_config: EvalTaskConfig,
    device: torch.device,
    use_amp: bool = True
) -> Dict[str, float]:
    """Evaluate the base model (for computing ΔJ)."""
    dataset = load_dataset(task_config.dataset_path, task_config.num_samples)
    scores: List[float] = []

    base_model.eval()
    amp_dtype = torch.bfloat16 if use_amp else torch.float32

    for example in tqdm(dataset, desc=f"Evaluating base on {task_config.name}"):
        prompt = example.get("prompt", example.get("question", ""))
        ground_truth = example.get("answer", example.get("solution", None))

        inputs = base_tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=4096
        )
        input_ids = inputs["input_ids"].to(device)
        attention_mask = inputs["attention_mask"].to(device)

        with torch.no_grad():
            with autocast(enabled=use_amp, dtype=amp_dtype):
                generated_ids = base_model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=task_config.max_gen_len,
                    do_sample=True,
                    temperature=task_config.temperature,
                    top_p=task_config.top_p,
                    pad_token_id=base_tokenizer.eos_token_id
                )

        response_ids = generated_ids[:, input_ids.shape[1]:]
        response_text = base_tokenizer.batch_decode(
            response_ids,
            skip_special_tokens=True
        )[0]

        if task_config.is_verifiable and ground_truth is not None:
            is_correct = verify_math_answer(response_text, str(ground_truth))
            score = 1.0 if is_correct else 0.0
        else:
            score = 0.0
        scores.append(score)

    return {
        "avg_score": float(np.mean(scores)) if scores else 0.0,
        "std_score": float(np.std(scores)) if scores else 0.0,
        "num_examples": len(scores)
    }


# ============================================================================
# Main Evaluation Pipeline
# ============================================================================

def parse_args() -> argparse.Namespace:
    """Parse command line arguments."""
    parser = argparse.ArgumentParser(
        description="Evaluate RLVR trained models on multiple tasks"
    )
    parser.add_argument(
        "--base_ckpt",
        type=str,
        required=True,
        help="Path to base model checkpoint"
    )
    parser.add_argument(
        "--ft_ckpt",
        type=str,
        required=True,
        help="Path to finetuned model checkpoint"
    )
    parser.add_argument(
        "--eval_tasks_config",
        type=str,
        required=True,
        help="Path to evaluation tasks configuration JSON"
    )
    parser.add_argument(
        "--output_path",
        type=str,
        required=True,
        help="Path to save evaluation results"
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for evaluation"
    )
    parser.add_argument(
        "--eval_base",
        action="store_true",
        help="Also evaluate base model (for computing delta J)"
    )
    parser.add_argument(
        "--use_amp",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Use automatic mixed precision (pass --no-use_amp to disable)"
    )
    return parser.parse_args()


def main() -> None:
    """Main evaluation function."""
    args = parse_args()

    device = torch.device(args.device if torch.cuda.is_available() else "cpu")
    logger.info(f"Using device: {device}")

    # Load tokenizers
    logger.info(f"Loading base tokenizer from {args.base_ckpt}")
    base_tokenizer = AutoTokenizer.from_pretrained(
        args.base_ckpt,
        use_fast=True,
        trust_remote_code=True
    )
    if base_tokenizer.pad_token is None:
        base_tokenizer.pad_token = base_tokenizer.eos_token

    logger.info(f"Loading finetuned tokenizer from {args.ft_ckpt}")
    ft_tokenizer = AutoTokenizer.from_pretrained(
        args.ft_ckpt,
        use_fast=True,
        trust_remote_code=True
    )
    if ft_tokenizer.pad_token is None:
        ft_tokenizer.pad_token = ft_tokenizer.eos_token

    # Load models
    logger.info(f"Loading base model from {args.base_ckpt}")
    base_model = AutoModelForCausalLM.from_pretrained(
        args.base_ckpt,
        torch_dtype=torch.bfloat16,
        device_map=None,
        trust_remote_code=True
    ).to(device)
    base_model.eval()

    logger.info(f"Loading finetuned model from {args.ft_ckpt}")
    ft_model = AutoModelForCausalLM.from_pretrained(
        args.ft_ckpt,
        torch_dtype=torch.bfloat16,
        device_map=None,
        trust_remote_code=True
    ).to(device)
    ft_model.eval()

    # Load evaluation tasks
    eval_tasks = load_eval_tasks(args.eval_tasks_config)

    # Evaluate on all tasks
    all_results: Dict[str, Any] = {
        "base_ckpt": args.base_ckpt,
        "ft_ckpt": args.ft_ckpt,
        "tasks": {}
    }

    for task in eval_tasks:
        logger.info(f"\n{'='*60}")
        logger.info(f"Evaluating task: {task.name}")
        logger.info(f"{'='*60}")

        # Evaluate finetuned model
        result = evaluate_task(
            base_model=base_model,
            base_tokenizer=base_tokenizer,
            finetuned_model=ft_model,
            finetuned_tokenizer=ft_tokenizer,
            task_config=task,
            device=device,
            use_amp=args.use_amp
        )

        task_results = {
            "ft_avg_score": result.avg_score,
            "ft_std_score": result.std_score,
            "avg_kl": result.avg_kl,
            "std_kl": result.std_kl,
            "avg_response_length": result.avg_response_length,
            "num_examples": result.num_examples,
        }

        # Optionally evaluate base model
        if args.eval_base:
            base_result = evaluate_base_model(
                base_model=base_model,
                base_tokenizer=base_tokenizer,
                task_config=task,
                device=device,
                use_amp=args.use_amp
            )
            task_results["base_avg_score"] = base_result["avg_score"]
            task_results["base_std_score"] = base_result["std_score"]
            task_results["delta_j"] = result.avg_score - base_result["avg_score"]

        all_results["tasks"][task.name] = task_results

    # Save results
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        json.dump(all_results, f, indent=2)
    logger.info(f"\nResults saved to {args.output_path}")

    # Print summary
    print("\n" + "=" * 80)
    print("EVALUATION SUMMARY")
    print("=" * 80)
    for task_name, task_result in all_results["tasks"].items():
        print(f"\n{task_name}:")
        print(f"  Score: {task_result['ft_avg_score']:.4f} (±{task_result['ft_std_score']:.4f})")
        print(f"  KL: {task_result['avg_kl']:.4f} (±{task_result['std_kl']:.4f})")
        if "delta_j" in task_result:
            print(f"  ΔJ: {task_result['delta_j']:+.4f}")
    print("=" * 80)


if __name__ == "__main__":
    main()