summaryrefslogtreecommitdiff
path: root/utils_math_eval.py
diff options
context:
space:
mode:
Diffstat (limited to 'utils_math_eval.py')
-rw-r--r--utils_math_eval.py367
1 files changed, 367 insertions, 0 deletions
diff --git a/utils_math_eval.py b/utils_math_eval.py
new file mode 100644
index 0000000..d4a1db2
--- /dev/null
+++ b/utils_math_eval.py
@@ -0,0 +1,367 @@
+# utils_math_eval.py
+"""
+Math Evaluation Utilities for RLVR Experiments.
+
+This module provides utilities for:
+- Extracting answers from model responses
+- Verifying mathematical answers
+- Computing accuracy metrics
+"""
+
+import re
+from typing import Optional, List, Dict, Any, Tuple
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Answer Extraction
+# ============================================================================
+
def extract_boxed_content(text: str) -> List[str]:
    """
    Return the contents of every \\boxed{...} occurrence in *text*.

    Brace nesting inside a box is tracked so inner groups do not end
    the match early; boxes left unterminated are silently skipped.
    """
    found: List[str] = []
    marker = "\\boxed{"
    cursor = 0

    while True:
        hit = text.find(marker, cursor)
        if hit < 0:
            break

        body_start = hit + len(marker)
        level = 1
        pos = body_start

        # Scan forward until the brace that opened this box is closed.
        while pos < len(text) and level:
            ch = text[pos]
            if ch == "{":
                level += 1
            elif ch == "}":
                level -= 1
            pos += 1

        if level == 0:
            # pos is one past the closing brace; exclude the brace itself.
            found.append(text[body_start:pos - 1].strip())

        cursor = pos

    return found
+
+
def extract_answer_from_boxed(text: str) -> Optional[str]:
    """Return the content of the final \\boxed{} in *text*, or None."""
    contents = extract_boxed_content(text)
    return contents[-1] if contents else None
+
+
def extract_answer_from_patterns(text: str) -> Optional[str]:
    """
    Extract an answer using common natural-language patterns.

    Patterns are tried in priority order: explicit "answer is" phrasing,
    then conclusion words (therefore/hence/so/thus), then a trailing
    "= value".  A '.' or ',' terminates the captured answer only when it
    is NOT followed by a digit, so decimals ("3.5") and thousands
    separators ("1,000") are kept intact (the previous pattern cut the
    capture at the first '.' or ',', truncating such answers).

    Returns:
        The first non-empty match with trailing punctuation stripped,
        or None if nothing matches.
    """
    # Terminators: '.'/',' not starting a decimal/group-of-digits, or EOL.
    dot_end = r"(?:\.(?!\d)|$)"
    dot_comma_end = r"(?:\.(?!\d)|,(?!\d)|$)"

    # Patterns in order of priority
    patterns = [
        # Explicit answer statements
        (rf"[Tt]he\s+(?:final\s+)?answer\s+is\s*[:\s]*(.+?){dot_comma_end}", 1),
        (rf"[Aa]nswer\s*[:\s]+(.+?){dot_comma_end}", 1),

        # Conclusion patterns
        (rf"[Tt]herefore\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?){dot_end}", 1),
        (rf"[Hh]ence\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?){dot_end}", 1),
        (rf"[Ss]o\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?){dot_end}", 1),
        (rf"[Tt]hus\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?){dot_end}", 1),

        # Equation result at end of a line
        (r"=\s*(\S+)\s*$", 1),
    ]

    for pattern, group in patterns:
        match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
        if match:
            answer = match.group(group).strip()
            # Clean up trailing punctuation
            answer = re.sub(r"[.,;:!?]+$", "", answer).strip()
            if answer:
                return answer

    return None
+
+
def extract_final_answer(text: str) -> Optional[str]:
    """
    Pull the final answer out of a model response.

    Tries the \\boxed{} format first, then natural-language patterns;
    returns None when neither extractor yields a non-empty answer.
    """
    for extractor in (extract_answer_from_boxed, extract_answer_from_patterns):
        candidate = extractor(text)
        if candidate:
            return candidate
    return None
+
+
+# ============================================================================
+# Answer Normalization
+# ============================================================================
+
def normalize_numeric_answer(answer: str) -> Optional[float]:
    """
    Parse *answer* into a float for numeric comparison.

    Handles:
    - Integers, decimals and scientific notation (spaces/commas stripped)
    - Simple fractions ("a/b") and LaTeX fractions ("\\frac{a}{b}")
    - Percentages ("12%" -> 0.12)
    - Surrounding LaTeX math delimiters ("$42$" -> 42.0)

    Returns:
        The parsed float, or None when the string is empty or cannot be
        interpreted as a number.
    """
    if not answer:
        return None

    # Clean up the string
    cleaned = answer.strip().lower()
    cleaned = cleaned.replace(" ", "")
    cleaned = cleaned.replace(",", "")
    # Strip LaTeX math-mode delimiters so "$42$" parses like "42".
    cleaned = cleaned.replace("$", "")

    # Handle percentages
    if cleaned.endswith("%"):
        cleaned = cleaned[:-1]
        try:
            return float(cleaned) / 100
        except ValueError:
            pass

    # Rewrite LaTeX fractions \frac{a}{b} into "a/b" form for the
    # generic fraction branch below.
    frac = re.fullmatch(r"\\frac\{(-?[\d.]+)\}\{(-?[\d.]+)\}", cleaned)
    if frac:
        cleaned = f"{frac.group(1)}/{frac.group(2)}"

    # Handle fractions (a/b)
    if "/" in cleaned:
        parts = cleaned.split("/")
        if len(parts) == 2:
            try:
                num = float(parts[0])
                denom = float(parts[1])
                if denom != 0:
                    return num / denom
            except ValueError:
                pass

    # Handle scientific notation and regular numbers
    try:
        return float(cleaned)
    except ValueError:
        pass

    return None
+
+
def normalize_text_answer(answer: str) -> str:
    """
    Canonicalize a textual answer for equality comparison.

    Lowercases, strips LaTeX commands and brace/dollar characters,
    collapses whitespace, and drops trailing punctuation.
    """
    if not answer:
        return ""

    text = answer.strip().lower()

    # Drop LaTeX commands (e.g. \text, \mathrm) and grouping characters.
    text = re.sub(r"\\[a-zA-Z]+", "", text)
    text = re.sub(r"[{}$]", "", text)

    # Collapse runs of whitespace into single spaces.
    text = " ".join(text.split())

    # Strip sentence-final punctuation.
    return re.sub(r"[.,;:!?]+$", "", text).strip()
+
+
+# ============================================================================
+# Answer Comparison
+# ============================================================================
+
def compare_numeric_answers(
    predicted: str,
    ground_truth: str,
    tolerance: float = 1e-6
) -> bool:
    """
    Numerically compare two answer strings.

    Both strings must parse via normalize_numeric_answer; near-zero
    ground truths are compared with an absolute tolerance, everything
    else with a relative one.
    """
    pred_val = normalize_numeric_answer(predicted)
    true_val = normalize_numeric_answer(ground_truth)

    if pred_val is None or true_val is None:
        return False

    diff = abs(pred_val - true_val)

    # NOTE(review): the 1e-6 near-zero cutoff is fixed, independent of
    # the *tolerance* argument — confirm that is intentional.
    if abs(true_val) < 1e-6:
        return diff < tolerance

    return diff / abs(true_val) < tolerance
+
+
def compare_text_answers(
    predicted: str,
    ground_truth: str
) -> bool:
    """True when both answers normalize to the same text."""
    return normalize_text_answer(predicted) == normalize_text_answer(ground_truth)
+
+
def verify_answer(
    response: str,
    ground_truth: str,
    task_type: str = "math"
) -> Tuple[bool, Optional[str]]:
    """
    Check whether *response* contains the correct answer.

    Args:
        response: Model's full response text.
        ground_truth: Expected answer.
        task_type: Type of task ("math", "qa", "code"). Currently unused
            by the comparison logic; kept for interface compatibility.

    Returns:
        Tuple of (is_correct, extracted_answer). extracted_answer is
        None when no answer could be found in the response.
    """
    predicted = extract_final_answer(response)
    if predicted is None:
        return False, None

    # Comparison cascade: numeric match, then exact normalized-text
    # match, then containment of the normalized ground truth inside
    # the normalized prediction.
    if compare_numeric_answers(predicted, ground_truth):
        return True, predicted

    if compare_text_answers(predicted, ground_truth):
        return True, predicted

    gt_norm = normalize_text_answer(ground_truth)
    if gt_norm and gt_norm in normalize_text_answer(predicted):
        return True, predicted

    return False, predicted
+
+
+# ============================================================================
+# Batch Evaluation
+# ============================================================================
+
def evaluate_batch(
    responses: List[str],
    ground_truths: List[str],
    task_type: str = "math"
) -> Dict[str, Any]:
    """
    Evaluate a batch of responses against ground-truth answers.

    Args:
        responses: List of model responses.
        ground_truths: List of expected answers (same length as responses).
        task_type: Type of task, forwarded to verify_answer.

    Returns:
        Dict with "accuracy" (float in [0, 1]), "correct" (int),
        "total" (int), and "results" (one dict per example with
        is_correct, extracted_answer and ground_truth).

    Raises:
        ValueError: If the two input lists differ in length.
    """
    # Raise instead of assert: asserts are stripped under `python -O`,
    # which would silently let mismatched batches through.
    if len(responses) != len(ground_truths):
        raise ValueError("Number of responses must match ground truths")

    correct = 0
    total = len(responses)
    results = []

    for response, gt in zip(responses, ground_truths):
        is_correct, extracted = verify_answer(response, gt, task_type)
        correct += int(is_correct)
        results.append({
            "is_correct": is_correct,
            "extracted_answer": extracted,
            "ground_truth": gt,
        })

    return {
        "accuracy": correct / total if total > 0 else 0.0,
        "correct": correct,
        "total": total,
        "results": results,
    }
+
+
+# ============================================================================
+# Answer Format Detection
+# ============================================================================
+
def detect_answer_format(text: str) -> str:
    """
    Classify the format of an answer string.

    Returns one of: "boxed", "numeric", "fraction", "text", "unknown".
    """
    if "\\boxed{" in text:
        return "boxed"

    stripped = text.strip()

    # Integer fraction such as "-3/4".
    if re.match(r"^-?\d+/\d+$", stripped):
        return "fraction"

    # Anything float() accepts, with commas treated as separators.
    try:
        float(stripped.replace(",", ""))
    except ValueError:
        pass
    else:
        return "numeric"

    return "text" if stripped else "unknown"
+
+
def format_answer_for_display(answer: str, detected_format: str) -> str:
    """Render *answer* for display according to its detected format."""
    if detected_format == "fraction":
        value = normalize_numeric_answer(answer)
        if value is not None:
            # Show the fraction next to its decimal approximation.
            return f"{answer} ≈ {value:.6f}"

    if detected_format == "numeric":
        try:
            return f"{float(answer.replace(',', '')):g}"
        except ValueError:
            pass

    return answer
+