# utils_math_eval.py
"""
Math Evaluation Utilities for RLVR Experiments.

This module provides utilities for:
- Extracting answers from model responses
- Verifying mathematical answers
- Computing accuracy metrics
"""

import re
from typing import Optional, List, Dict, Any, Tuple
import logging

logger = logging.getLogger(__name__)


# ============================================================================
# Answer Extraction
# ============================================================================

def extract_boxed_content(text: str) -> List[str]:
    """
    Extract all content from \\boxed{} patterns.

    Handles nested braces correctly.
    """
    results = []
    i = 0
    while i < len(text):
        # Find \boxed{
        idx = text.find("\\boxed{", i)
        if idx == -1:
            break

        # Find matching closing brace
        start = idx + 7  # len("\\boxed{")
        depth = 1
        j = start
        while j < len(text) and depth > 0:
            if text[j] == "{":
                depth += 1
            elif text[j] == "}":
                depth -= 1
            j += 1

        if depth == 0:
            content = text[start:j - 1]
            results.append(content.strip())
        i = j

    return results


def extract_answer_from_boxed(text: str) -> Optional[str]:
    """Extract the last boxed answer from text."""
    boxed_contents = extract_boxed_content(text)
    if boxed_contents:
        return boxed_contents[-1]
    return None


def extract_answer_from_patterns(text: str) -> Optional[str]:
    """
    Extract answer using common natural language patterns.
    """
    # Patterns in order of priority
    patterns = [
        # Explicit answer statements
        (r"[Tt]he\s+(?:final\s+)?answer\s+is\s*[:\s]*(.+?)(?:\.|,|$)", 1),
        (r"[Aa]nswer\s*[:\s]+(.+?)(?:\.|,|$)", 1),
        # Conclusion patterns
        (r"[Tt]herefore\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
        (r"[Hh]ence\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
        (r"[Ss]o\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
        (r"[Tt]hus\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
        # Equation result
        (r"=\s*(\S+)\s*$", 1),
    ]

    for pattern, group in patterns:
        match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
        if match:
            answer = match.group(group).strip()
            # Clean up trailing punctuation
            answer = re.sub(r"[.,;:!?]+$", "", answer).strip()
            if answer:
                return answer

    return None


def extract_final_answer(text: str) -> Optional[str]:
    """
    Extract the final answer from a model response.

    Priority:
    1. \\boxed{} format
    2. Natural language patterns
    """
    # Try boxed format first
    boxed = extract_answer_from_boxed(text)
    if boxed:
        return boxed

    # Try natural language patterns
    pattern_answer = extract_answer_from_patterns(text)
    if pattern_answer:
        return pattern_answer

    return None
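

# Illustrative behaviour of the extraction helpers (a minimal sketch; the
# sample strings below are made up for demonstration, not taken from any
# dataset):
#
#   >>> extract_boxed_content("First \\boxed{x^2 + 1}, then \\boxed{42}.")
#   ['x^2 + 1', '42']
#   >>> extract_final_answer("Therefore, the answer is 42.")
#   '42'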


# ============================================================================
# Answer Normalization
# ============================================================================

def normalize_numeric_answer(answer: str) -> Optional[float]:
    """
    Normalize a numeric answer for comparison.

    Handles:
    - Integers and decimals
    - Fractions (simple forms like a/b)
    - Scientific notation
    - Percentages
    """
    if not answer:
        return None

    # Clean up the string
    cleaned = answer.strip().lower()
    cleaned = cleaned.replace(" ", "")
    cleaned = cleaned.replace(",", "")

    # Handle percentages
    if cleaned.endswith("%"):
        cleaned = cleaned[:-1]
        try:
            return float(cleaned) / 100
        except ValueError:
            pass

    # Handle fractions (a/b)
    if "/" in cleaned:
        parts = cleaned.split("/")
        if len(parts) == 2:
            try:
                num = float(parts[0])
                denom = float(parts[1])
                if denom != 0:
                    return num / denom
            except ValueError:
                pass

    # Handle scientific notation and regular numbers
    try:
        return float(cleaned)
    except ValueError:
        pass

    return None


def normalize_text_answer(answer: str) -> str:
    """
    Normalize a text answer for comparison.

    - Lowercase
    - Remove extra whitespace
    - Remove common formatting
    """
    if not answer:
        return ""

    normalized = answer.strip().lower()

    # Remove LaTeX formatting
    normalized = re.sub(r"\\[a-zA-Z]+", "", normalized)
    normalized = re.sub(r"[{}$]", "", normalized)

    # Normalize whitespace
    normalized = " ".join(normalized.split())

    # Remove common punctuation
    normalized = re.sub(r"[.,;:!?]+$", "", normalized).strip()

    return normalized


# ============================================================================
# Answer Comparison
# ============================================================================

def compare_numeric_answers(
    predicted: str,
    ground_truth: str,
    tolerance: float = 1e-6
) -> bool:
    """
    Compare two answers numerically.

    Returns True if both can be parsed as numbers and are within tolerance.
    """
    pred_num = normalize_numeric_answer(predicted)
    gt_num = normalize_numeric_answer(ground_truth)

    if pred_num is None or gt_num is None:
        return False

    # Absolute tolerance for small numbers
    if abs(gt_num) < 1e-6:
        return abs(pred_num - gt_num) < tolerance

    # Relative tolerance for larger numbers
    rel_diff = abs(pred_num - gt_num) / abs(gt_num)
    return rel_diff < tolerance


def compare_text_answers(
    predicted: str,
    ground_truth: str
) -> bool:
    """Compare two text answers after normalization."""
    pred_norm = normalize_text_answer(predicted)
    gt_norm = normalize_text_answer(ground_truth)
    return pred_norm == gt_norm


def verify_answer(
    response: str,
    ground_truth: str,
    task_type: str = "math"
) -> Tuple[bool, Optional[str]]:
    """
    Verify if the response contains the correct answer.

    Args:
        response: Model's full response
        ground_truth: Expected answer
        task_type: Type of task ("math", "qa", "code"); currently unused,
            reserved for task-specific verification logic

    Returns:
        Tuple of (is_correct, extracted_answer)
    """
    # Extract predicted answer
    predicted = extract_final_answer(response)
    if predicted is None:
        return False, None

    # Try numeric comparison first
    if compare_numeric_answers(predicted, ground_truth):
        return True, predicted

    # Try text comparison
    if compare_text_answers(predicted, ground_truth):
        return True, predicted

    # Check if ground truth is contained in predicted
    gt_norm = normalize_text_answer(ground_truth)
    pred_norm = normalize_text_answer(predicted)
    if gt_norm and gt_norm in pred_norm:
        return True, predicted

    return False, predicted
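

# Illustrative behaviour of the normalization and verification helpers
# (a minimal sketch with made-up inputs):
#
#   >>> normalize_numeric_answer("3/4")
#   0.75
#   >>> normalize_numeric_answer("50%")
#   0.5
#   >>> verify_answer("Thus, the answer is 1/2.", "0.5")
#   (True, '1/2')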


# ============================================================================
# Batch Evaluation
# ============================================================================

def evaluate_batch(
    responses: List[str],
    ground_truths: List[str],
    task_type: str = "math"
) -> Dict[str, Any]:
    """
    Evaluate a batch of responses.

    Args:
        responses: List of model responses
        ground_truths: List of expected answers
        task_type: Type of task

    Returns:
        Dictionary with evaluation metrics
    """
    assert len(responses) == len(ground_truths), \
        "Number of responses must match ground truths"

    correct = 0
    total = len(responses)
    results = []

    for response, gt in zip(responses, ground_truths):
        is_correct, extracted = verify_answer(response, gt, task_type)
        correct += int(is_correct)
        results.append({
            "is_correct": is_correct,
            "extracted_answer": extracted,
            "ground_truth": gt
        })

    accuracy = correct / total if total > 0 else 0.0

    return {
        "accuracy": accuracy,
        "correct": correct,
        "total": total,
        "results": results
    }


# ============================================================================
# Answer Format Detection
# ============================================================================

def detect_answer_format(text: str) -> str:
    """
    Detect the format of an answer.

    Returns one of: "boxed", "numeric", "fraction", "text", "unknown"
    """
    if "\\boxed{" in text:
        return "boxed"

    # Check for fraction
    if re.match(r"^-?\d+/\d+$", text.strip()):
        return "fraction"

    # Check for numeric
    try:
        float(text.strip().replace(",", ""))
        return "numeric"
    except ValueError:
        pass

    if text.strip():
        return "text"

    return "unknown"


def format_answer_for_display(answer: str, detected_format: str) -> str:
    """Format answer for display based on detected format."""
    if detected_format == "fraction":
        num = normalize_numeric_answer(answer)
        if num is not None:
            return f"{answer} ≈ {num:.6f}"

    if detected_format == "numeric":
        try:
            num = float(answer.replace(",", ""))
            return f"{num:g}"
        except ValueError:
            pass

    return answer
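

# A minimal, self-contained usage sketch: running this module directly
# evaluates two hypothetical responses end to end with evaluate_batch and
# prints the resulting metrics. The demo strings are illustrative only, not
# real model outputs.
if __name__ == "__main__":
    demo_responses = [
        "Adding the two terms gives \\boxed{42}.",
        "Therefore, the answer is 3/4.",
    ]
    demo_ground_truths = ["42", "0.75"]

    metrics = evaluate_batch(demo_responses, demo_ground_truths)
    print(f"accuracy={metrics['accuracy']:.2f} "
          f"({metrics['correct']}/{metrics['total']})")

    for result in metrics["results"]:
        extracted = result["extracted_answer"] or ""
        fmt = detect_answer_format(extracted)
        print(f"{extracted!r} ({fmt}) -> {format_answer_for_display(extracted, fmt)}")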