author    YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
committer YurenHao0426 <blackhao0426@gmail.com>  2026-02-04 18:59:35 -0600
commit    f1c2cc22d46a6976df3555391e667c7e61592fad (patch)
tree      0b37b52c8ff91042a742d3b3ec54542cb6d6e2f6 /eval_policy.py

Initial commit: RL floating-point noise project (HEAD -> main)

Diffstat (limited to 'eval_policy.py')
-rw-r--r--  eval_policy.py  621
1 file changed, 621 insertions, 0 deletions
diff --git a/eval_policy.py b/eval_policy.py
new file mode 100644
index 0000000..cc30209
--- /dev/null
+++ b/eval_policy.py
@@ -0,0 +1,621 @@
+#!/usr/bin/env python3
+# eval_policy.py
+"""
+Policy Evaluation Script for RLVR Experiments.
+
+This script evaluates trained models on multiple tasks, computing:
+- J_k: Task performance (pass@1 accuracy for verifiable tasks)
+- KL_k: KL divergence from base model
+
+Usage:
+ python eval_policy.py \
+ --base_ckpt Qwen/Qwen2.5-Math-7B \
+ --ft_ckpt results/train_logs/fp32_seed1/final_model \
+ --eval_tasks_config configs/eval_tasks_config.json \
+ --output_path results/eval_metrics/fp32_seed1.json
+"""
+
+import argparse
+import json
+import os
+import logging
+import re
+from typing import Dict, Any, List, Optional, Tuple
+from dataclasses import dataclass
+
+import numpy as np
+import torch
+from torch.amp import autocast
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from tqdm import tqdm
+
+from config import EvalTaskConfig
+
+# Configure logging
+logging.basicConfig(
+ level=logging.INFO,
+ format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
+)
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Data Loading
+# ============================================================================
+
+def load_eval_tasks(eval_config_path: str) -> List[EvalTaskConfig]:
+ """Load evaluation task configurations from JSON file."""
+ with open(eval_config_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ tasks: List[EvalTaskConfig] = []
+ for task_item in data:
+ task = EvalTaskConfig(
+ name=task_item.get("name", ""),
+ task_type=task_item.get("task_type", "math"),
+ dataset_path=task_item.get("dataset_path", ""),
+ is_verifiable=task_item.get("is_verifiable", True),
+ metric_type=task_item.get("metric_type", "accuracy"),
+ num_samples=task_item.get("num_samples", -1),
+ max_gen_len=task_item.get("max_gen_len", 2048),
+ temperature=task_item.get("temperature", 0.7),
+ top_p=task_item.get("top_p", 0.8),
+ num_samples_per_prompt=task_item.get("num_samples_per_prompt", 1)
+ )
+ tasks.append(task)
+
+ logger.info(f"Loaded {len(tasks)} evaluation tasks from {eval_config_path}")
+ return tasks
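+
+# Illustrative shape of the eval tasks JSON consumed above. The keys mirror the
+# .get() calls in load_eval_tasks; the name and path values are placeholders:
+# [
+#   {"name": "math_eval", "task_type": "math", "dataset_path": "data/math_eval.json",
+#    "is_verifiable": true, "metric_type": "accuracy", "num_samples": 200,
+#    "max_gen_len": 2048, "temperature": 0.7, "top_p": 0.8, "num_samples_per_prompt": 1}
+# ]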
+
+
+def load_dataset(dataset_path: str, num_samples: int = -1) -> List[Dict[str, Any]]:
+ """Load evaluation dataset from JSON file."""
+ with open(dataset_path, "r", encoding="utf-8") as f:
+ data = json.load(f)
+
+ if num_samples > 0 and num_samples < len(data):
+ data = data[:num_samples]
+
+ logger.info(f"Loaded {len(data)} examples from {dataset_path}")
+ return data
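+
+# The evaluation loops below read "prompt"/"question" and "answer"/"solution"
+# fields, so a dataset file is expected to look roughly like (illustrative):
+# [{"question": "What is 2 + 2?", "answer": "4"}, ...]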
+
+
+# ============================================================================
+# Answer Verification
+# ============================================================================
+
+def extract_boxed_answer(text: str) -> Optional[str]:
+ """Extract answer from \\boxed{} format."""
+    # Find all \boxed{...} patterns
+ pattern = r"\\boxed\{([^}]*)\}"
+ matches = re.findall(pattern, text)
+
+ if matches:
+ return matches[-1].strip() # Return last match
+
+ return None
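+
+# Illustrative behaviour of the pattern above (the last match wins; nested
+# braces such as \boxed{\frac{1}{2}} are not handled by this simple regex):
+#   extract_boxed_answer(r"first \boxed{12}, then \boxed{42}")  ->  "42"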
+
+
+def extract_final_answer(text: str) -> Optional[str]:
+ """Extract final answer using various heuristics."""
+ # Try boxed format first
+ boxed = extract_boxed_answer(text)
+ if boxed:
+ return boxed
+
+ # Common answer patterns
+ patterns = [
+ r"[Tt]he (?:final )?answer is[:\s]+(.+?)(?:\.|$)",
+ r"[Tt]herefore[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
+ r"[Ss]o[,:\s]+(?:the answer is[:\s]+)?(.+?)(?:\.|$)",
+ r"[Hh]ence[,:\s]+(.+?)(?:\.|$)",
+ r"=\s*(.+?)$",
+ ]
+
+    for pattern in patterns:
+ match = re.search(pattern, text, re.MULTILINE)
+ if match:
+ return match.group(1).strip()
+
+ return None
+
+
+def normalize_answer(answer: Optional[str]) -> str:
+ """Normalize answer for comparison."""
+ if answer is None:
+ return ""
+
+ # Convert to lowercase, remove whitespace
+ normalized = answer.lower().strip()
+
+ # Remove common formatting
+ normalized = normalized.replace(",", "")
+ normalized = normalized.replace("$", "")
+ normalized = normalized.replace("%", "")
+
+ # Try to extract numeric value
+ numeric_match = re.search(r"-?\d+\.?\d*", normalized)
+ if numeric_match:
+ return numeric_match.group()
+
+ return normalized
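+
+# Illustrative normalizations produced by the rules above:
+#   normalize_answer("$1,234")  ->  "1234"
+#   normalize_answer("x = 5")   ->  "5"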
+
+
+def verify_math_answer(
+ response: str,
+ ground_truth: str
+) -> bool:
+ """
+ Verify if the response contains the correct answer.
+
+ This is a simplified verifier. For production use, replace with
+ Eval-Chemy or a more sophisticated verification system.
+ """
+ # Extract answers
+ predicted = extract_final_answer(response)
+
+ if predicted is None:
+ return False
+
+ # Normalize for comparison
+ pred_normalized = normalize_answer(predicted)
+ gt_normalized = normalize_answer(ground_truth)
+
+ # Direct comparison
+ if pred_normalized == gt_normalized:
+ return True
+
+ # Try numeric comparison
+ try:
+ pred_num = float(pred_normalized)
+ gt_num = float(gt_normalized)
+ if abs(pred_num - gt_num) < 1e-6:
+ return True
+ except ValueError:
+ pass
+
+ return False
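+
+# Worked example of the full verification path (traced by hand, illustrative):
+#   verify_math_answer("Therefore, the answer is 3,600.", "3600")  ->  True
+#   extract_final_answer captures "3,600"; normalize_answer strips the comma and
+#   keeps the numeric part "3600", which equals the normalized ground truth.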
+
+
+# ============================================================================
+# KL Divergence Computation
+# ============================================================================
+
+def compute_sequence_kl(
+ finetuned_model: torch.nn.Module,
+ base_model: torch.nn.Module,
+ input_ids: torch.Tensor,
+ attention_mask: torch.Tensor,
+ response_start_idx: int,
+ device: torch.device
+) -> Tuple[float, int]:
+ """
+    Compute a single-sample estimate of the sequence-level KL divergence.
+
+    For a response y sampled from π_ft,
+    KL(π_ft || π_base) ≈ Σ_t [log π_ft(y_t|x,y_{<t}) - log π_base(y_t|x,y_{<t})]
+
+ Returns:
+ Tuple of (kl_sum, num_tokens)
+ """
+ with torch.no_grad():
+ # Get logits from both models
+ ft_outputs = finetuned_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask
+ )
+ base_outputs = base_model(
+ input_ids=input_ids,
+ attention_mask=attention_mask
+ )
+
+ ft_logits = ft_outputs.logits
+ base_logits = base_outputs.logits
+
+ # Compute log probabilities
+ ft_log_probs = torch.log_softmax(ft_logits, dim=-1)
+ base_log_probs = torch.log_softmax(base_logits, dim=-1)
+
+ # Get log probs for actual tokens (shifted for autoregressive)
+ shift_ft_log_probs = ft_log_probs[:, :-1, :]
+ shift_base_log_probs = base_log_probs[:, :-1, :]
+ shift_labels = input_ids[:, 1:]
+
+ ft_token_log_probs = torch.gather(
+ shift_ft_log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ base_token_log_probs = torch.gather(
+ shift_base_log_probs,
+ dim=-1,
+ index=shift_labels.unsqueeze(-1)
+ ).squeeze(-1)
+
+ # Compute KL only for response tokens
+ kl_per_token = ft_token_log_probs - base_token_log_probs
+
+ # Mask for response tokens only
+ response_mask = torch.zeros_like(kl_per_token)
+ response_mask[:, response_start_idx-1:] = 1.0
+
+ # Apply attention mask
+ valid_mask = attention_mask[:, 1:].float() * response_mask
+
+ kl_sum = (kl_per_token * valid_mask).sum().item()
+ num_tokens = valid_mask.sum().item()
+
+ return kl_sum, int(num_tokens)
+
+
+# ============================================================================
+# Evaluation Functions
+# ============================================================================
+
+@dataclass
+class TaskResult:
+ """Results for a single evaluation task."""
+ task_name: str
+ task_type: str
+ num_examples: int
+ avg_score: float
+ std_score: float
+ avg_kl: float
+ std_kl: float
+ avg_response_length: float
+ scores: List[float]
+ kl_values: List[float]
+
+
+def evaluate_task(
+ base_model: torch.nn.Module,
+ base_tokenizer,
+ finetuned_model: torch.nn.Module,
+ finetuned_tokenizer,
+ task_config: EvalTaskConfig,
+ device: torch.device,
+ use_amp: bool = True
+) -> TaskResult:
+ """
+ Evaluate a single task.
+
+ Computes:
+ - avg_score: Mean accuracy (for verifiable tasks)
+ - avg_kl: Mean KL divergence from base model
+ """
+ dataset = load_dataset(task_config.dataset_path, task_config.num_samples)
+
+ scores: List[float] = []
+ kl_values: List[float] = []
+ response_lengths: List[int] = []
+
+ finetuned_model.eval()
+ base_model.eval()
+
+ amp_dtype = torch.bfloat16 if use_amp else torch.float32
+
+ for example in tqdm(dataset, desc=f"Evaluating {task_config.name}"):
+ prompt = example.get("prompt", example.get("question", ""))
+ ground_truth = example.get("answer", example.get("solution", None))
+
+ # Tokenize prompt
+ inputs = finetuned_tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=4096
+ )
+ input_ids = inputs["input_ids"].to(device)
+ attention_mask = inputs["attention_mask"].to(device)
+ prompt_len = input_ids.shape[1]
+
+ # Generate response
+ with torch.no_grad():
+            with autocast(device_type=device.type, enabled=use_amp, dtype=amp_dtype):
+ generated_ids = finetuned_model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=task_config.max_gen_len,
+ do_sample=True,
+ temperature=task_config.temperature,
+ top_p=task_config.top_p,
+ pad_token_id=finetuned_tokenizer.eos_token_id
+ )
+
+ # Decode response
+ response_ids = generated_ids[:, prompt_len:]
+ response_text = finetuned_tokenizer.batch_decode(
+ response_ids,
+ skip_special_tokens=True
+ )[0]
+
+ response_lengths.append(len(response_ids[0]))
+
+ # Compute score (accuracy for verifiable tasks)
+ if task_config.is_verifiable and ground_truth is not None:
+ is_correct = verify_math_answer(response_text, str(ground_truth))
+ score = 1.0 if is_correct else 0.0
+ else:
+ # For non-verifiable tasks, use placeholder
+ score = 0.0
+
+ scores.append(score)
+
+ # Compute KL divergence
+ full_ids = generated_ids
+ full_attention = torch.ones_like(full_ids, device=device)
+
+ kl_sum, num_tokens = compute_sequence_kl(
+ finetuned_model=finetuned_model,
+ base_model=base_model,
+ input_ids=full_ids,
+ attention_mask=full_attention,
+ response_start_idx=prompt_len,
+ device=device
+ )
+
+        # Record the summed (sequence-level) KL; divide by num_tokens here
+        # instead if a per-token average is preferred.
+        kl_values.append(kl_sum)
+
+ # Compute statistics
+ result = TaskResult(
+ task_name=task_config.name,
+ task_type=task_config.task_type,
+ num_examples=len(dataset),
+ avg_score=float(np.mean(scores)) if scores else 0.0,
+ std_score=float(np.std(scores)) if scores else 0.0,
+ avg_kl=float(np.mean(kl_values)) if kl_values else 0.0,
+ std_kl=float(np.std(kl_values)) if kl_values else 0.0,
+ avg_response_length=float(np.mean(response_lengths)) if response_lengths else 0.0,
+ scores=scores,
+ kl_values=kl_values
+ )
+
+ logger.info(
+ f"Task {task_config.name}: "
+ f"Score={result.avg_score:.4f} (±{result.std_score:.4f}), "
+ f"KL={result.avg_kl:.4f} (±{result.std_kl:.4f})"
+ )
+
+ return result
+
+
+def evaluate_base_model(
+ base_model: torch.nn.Module,
+ base_tokenizer,
+ task_config: EvalTaskConfig,
+ device: torch.device,
+ use_amp: bool = True
+) -> Dict[str, float]:
+ """Evaluate the base model (for computing ΔJ)."""
+ dataset = load_dataset(task_config.dataset_path, task_config.num_samples)
+
+ scores: List[float] = []
+ base_model.eval()
+
+ amp_dtype = torch.bfloat16 if use_amp else torch.float32
+
+ for example in tqdm(dataset, desc=f"Evaluating base on {task_config.name}"):
+ prompt = example.get("prompt", example.get("question", ""))
+ ground_truth = example.get("answer", example.get("solution", None))
+
+ inputs = base_tokenizer(
+ prompt,
+ return_tensors="pt",
+ truncation=True,
+ max_length=4096
+ )
+ input_ids = inputs["input_ids"].to(device)
+ attention_mask = inputs["attention_mask"].to(device)
+
+ with torch.no_grad():
+            with autocast(device_type=device.type, enabled=use_amp, dtype=amp_dtype):
+ generated_ids = base_model.generate(
+ input_ids=input_ids,
+ attention_mask=attention_mask,
+ max_new_tokens=task_config.max_gen_len,
+ do_sample=True,
+ temperature=task_config.temperature,
+ top_p=task_config.top_p,
+ pad_token_id=base_tokenizer.eos_token_id
+ )
+
+ response_ids = generated_ids[:, input_ids.shape[1]:]
+ response_text = base_tokenizer.batch_decode(
+ response_ids,
+ skip_special_tokens=True
+ )[0]
+
+ if task_config.is_verifiable and ground_truth is not None:
+ is_correct = verify_math_answer(response_text, str(ground_truth))
+ score = 1.0 if is_correct else 0.0
+ else:
+ score = 0.0
+
+ scores.append(score)
+
+ return {
+ "avg_score": float(np.mean(scores)) if scores else 0.0,
+ "std_score": float(np.std(scores)) if scores else 0.0,
+ "num_examples": len(scores)
+ }
+
+
+# ============================================================================
+# Main Evaluation Pipeline
+# ============================================================================
+
+def parse_args() -> argparse.Namespace:
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Evaluate RLVR trained models on multiple tasks"
+ )
+ parser.add_argument(
+ "--base_ckpt",
+ type=str,
+ required=True,
+ help="Path to base model checkpoint"
+ )
+ parser.add_argument(
+ "--ft_ckpt",
+ type=str,
+ required=True,
+ help="Path to finetuned model checkpoint"
+ )
+ parser.add_argument(
+ "--eval_tasks_config",
+ type=str,
+ required=True,
+ help="Path to evaluation tasks configuration JSON"
+ )
+ parser.add_argument(
+ "--output_path",
+ type=str,
+ required=True,
+ help="Path to save evaluation results"
+ )
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="cuda",
+ help="Device to use for evaluation"
+ )
+ parser.add_argument(
+ "--eval_base",
+ action="store_true",
+ help="Also evaluate base model (for computing delta J)"
+ )
+ parser.add_argument(
+ "--use_amp",
+ action="store_true",
+ default=True,
+ help="Use automatic mixed precision"
+ )
+ return parser.parse_args()
+
+
+def main() -> None:
+ """Main evaluation function."""
+ args = parse_args()
+
+ device = torch.device(args.device if torch.cuda.is_available() else "cpu")
+ logger.info(f"Using device: {device}")
+
+ # Load tokenizers
+ logger.info(f"Loading base tokenizer from {args.base_ckpt}")
+ base_tokenizer = AutoTokenizer.from_pretrained(
+ args.base_ckpt,
+ use_fast=True,
+ trust_remote_code=True
+ )
+ if base_tokenizer.pad_token is None:
+ base_tokenizer.pad_token = base_tokenizer.eos_token
+
+ logger.info(f"Loading finetuned tokenizer from {args.ft_ckpt}")
+ ft_tokenizer = AutoTokenizer.from_pretrained(
+ args.ft_ckpt,
+ use_fast=True,
+ trust_remote_code=True
+ )
+ if ft_tokenizer.pad_token is None:
+ ft_tokenizer.pad_token = ft_tokenizer.eos_token
+
+ # Load models
+ logger.info(f"Loading base model from {args.base_ckpt}")
+ base_model = AutoModelForCausalLM.from_pretrained(
+ args.base_ckpt,
+ torch_dtype=torch.bfloat16,
+ device_map=None,
+ trust_remote_code=True
+ ).to(device)
+ base_model.eval()
+
+ logger.info(f"Loading finetuned model from {args.ft_ckpt}")
+ ft_model = AutoModelForCausalLM.from_pretrained(
+ args.ft_ckpt,
+ torch_dtype=torch.bfloat16,
+ device_map=None,
+ trust_remote_code=True
+ ).to(device)
+ ft_model.eval()
+
+ # Load evaluation tasks
+ eval_tasks = load_eval_tasks(args.eval_tasks_config)
+
+ # Evaluate on all tasks
+ all_results: Dict[str, Any] = {
+ "base_ckpt": args.base_ckpt,
+ "ft_ckpt": args.ft_ckpt,
+ "tasks": {}
+ }
+
+ for task in eval_tasks:
+ logger.info(f"\n{'='*60}")
+ logger.info(f"Evaluating task: {task.name}")
+ logger.info(f"{'='*60}")
+
+ # Evaluate finetuned model
+ result = evaluate_task(
+ base_model=base_model,
+ base_tokenizer=base_tokenizer,
+ finetuned_model=ft_model,
+ finetuned_tokenizer=ft_tokenizer,
+ task_config=task,
+ device=device,
+ use_amp=args.use_amp
+ )
+
+ task_results = {
+ "ft_avg_score": result.avg_score,
+ "ft_std_score": result.std_score,
+ "avg_kl": result.avg_kl,
+ "std_kl": result.std_kl,
+ "avg_response_length": result.avg_response_length,
+ "num_examples": result.num_examples,
+ }
+
+ # Optionally evaluate base model
+ if args.eval_base:
+ base_result = evaluate_base_model(
+ base_model=base_model,
+ base_tokenizer=base_tokenizer,
+ task_config=task,
+ device=device,
+ use_amp=args.use_amp
+ )
+ task_results["base_avg_score"] = base_result["avg_score"]
+ task_results["base_std_score"] = base_result["std_score"]
+ task_results["delta_j"] = result.avg_score - base_result["avg_score"]
+
+ all_results["tasks"][task.name] = task_results
+
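+    # Shape of the saved JSON (illustrative; "delta_j" keys appear only when
+    # --eval_base is set):
+    # {"base_ckpt": "...", "ft_ckpt": "...",
+    #  "tasks": {"<task_name>": {"ft_avg_score": ..., "avg_kl": ..., ...}}}
+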
+ # Save results
+    output_dir = os.path.dirname(args.output_path)
+    if output_dir:
+        os.makedirs(output_dir, exist_ok=True)
+ with open(args.output_path, "w", encoding="utf-8") as f:
+ json.dump(all_results, f, indent=2)
+
+ logger.info(f"\nResults saved to {args.output_path}")
+
+ # Print summary
+ print("\n" + "="*80)
+ print("EVALUATION SUMMARY")
+ print("="*80)
+ for task_name, task_result in all_results["tasks"].items():
+ print(f"\n{task_name}:")
+ print(f" Score: {task_result['ft_avg_score']:.4f} (±{task_result['ft_std_score']:.4f})")
+ print(f" KL: {task_result['avg_kl']:.4f} (±{task_result['std_kl']:.4f})")
+ if "delta_j" in task_result:
+ print(f" ΔJ: {task_result['delta_j']:+.4f}")
+ print("="*80)
+
+
+if __name__ == "__main__":
+ main()
+