From f1c2cc22d46a6976df3555391e667c7e61592fad Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Wed, 4 Feb 2026 18:59:35 -0600
Subject: Initial commit: RL floating-point noise project

---
 utils_math_eval.py | 367 +++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 367 insertions(+)
 create mode 100644 utils_math_eval.py

diff --git a/utils_math_eval.py b/utils_math_eval.py
new file mode 100644
index 0000000..d4a1db2
--- /dev/null
+++ b/utils_math_eval.py
@@ -0,0 +1,367 @@
+# utils_math_eval.py
+"""
+Math Evaluation Utilities for RLVR Experiments.
+
+This module provides utilities for:
+- Extracting answers from model responses
+- Verifying mathematical answers
+- Computing accuracy metrics
+"""
+
+import re
+from typing import Optional, List, Dict, Any, Tuple
+import logging
+
+logger = logging.getLogger(__name__)
+
+
+# ============================================================================
+# Answer Extraction
+# ============================================================================
+
+def extract_boxed_content(text: str) -> List[str]:
+    """
+    Extract all content from \\boxed{} patterns.
+
+    Handles nested braces correctly.
+    """
+    results = []
+    i = 0
+
+    while i < len(text):
+        # Find \boxed{
+        idx = text.find("\\boxed{", i)
+        if idx == -1:
+            break
+
+        # Find matching closing brace
+        start = idx + 7  # len("\\boxed{")
+        depth = 1
+        j = start
+
+        while j < len(text) and depth > 0:
+            if text[j] == "{":
+                depth += 1
+            elif text[j] == "}":
+                depth -= 1
+            j += 1
+
+        if depth == 0:
+            content = text[start:j-1]
+            results.append(content.strip())
+
+        i = j
+
+    return results
+
+
+def extract_answer_from_boxed(text: str) -> Optional[str]:
+    """Extract the last boxed answer from text."""
+    boxed_contents = extract_boxed_content(text)
+    if boxed_contents:
+        return boxed_contents[-1]
+    return None
+
+
+def extract_answer_from_patterns(text: str) -> Optional[str]:
+    """
+    Extract an answer using common natural-language patterns.
+    """
+    # Patterns in order of priority
+    patterns = [
+        # Explicit answer statements
+        (r"[Tt]he\s+(?:final\s+)?answer\s+is\s*[:\s]*(.+?)(?:\.|,|$)", 1),
+        (r"[Aa]nswer\s*[:\s]+(.+?)(?:\.|,|$)", 1),
+
+        # Conclusion patterns (word-boundary anchored, so that e.g. the
+        # "so" rule cannot fire on the substring inside "also")
+        (r"\b[Tt]herefore\b\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+        (r"\b[Hh]ence\b\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+        (r"\b[Ss]o\b\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+        (r"\b[Tt]hus\b\s*,?\s*(?:the\s+answer\s+is\s*)?(.+?)(?:\.|$)", 1),
+
+        # Equation result
+        (r"=\s*(\S+)\s*$", 1),
+    ]
+
+    for pattern, group in patterns:
+        match = re.search(pattern, text, re.MULTILINE | re.IGNORECASE)
+        if match:
+            answer = match.group(group).strip()
+            # Clean up trailing punctuation
+            answer = re.sub(r"[.,;:!?]+$", "", answer).strip()
+            if answer:
+                return answer
+
+    return None
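+
+
+# Illustrative examples for the extractors above, written as doctest-style
+# comments. These are a sketch hand-traced through the code rather than
+# captured from a live run; the input strings are invented for illustration:
+#
+#   >>> extract_boxed_content(r"hence \boxed{\frac{1}{2}} of the total")
+#   ['\\frac{1}{2}']
+#   >>> extract_answer_from_boxed("first \\boxed{1}, then \\boxed{2}")
+#   '2'
+#   >>> extract_answer_from_patterns("Therefore, the answer is 42.")
+#   '42'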
+
+
+def extract_final_answer(text: str) -> Optional[str]:
+    """
+    Extract the final answer from a model response.
+
+    Priority:
+    1. \\boxed{} format
+    2. Natural language patterns
+    """
+    # Try boxed format first
+    boxed = extract_answer_from_boxed(text)
+    if boxed:
+        return boxed
+
+    # Try natural language patterns
+    pattern_answer = extract_answer_from_patterns(text)
+    if pattern_answer:
+        return pattern_answer
+
+    return None
+
+
+# ============================================================================
+# Answer Normalization
+# ============================================================================
+
+def normalize_numeric_answer(answer: str) -> Optional[float]:
+    """
+    Normalize a numeric answer for comparison.
+
+    Handles:
+    - Integers and decimals
+    - Fractions (simple forms like a/b)
+    - Scientific notation
+    - Percentages
+    """
+    if not answer:
+        return None
+
+    # Clean up the string
+    cleaned = answer.strip().lower()
+    cleaned = cleaned.replace(" ", "")
+    cleaned = cleaned.replace(",", "")
+
+    # Handle percentages
+    if cleaned.endswith("%"):
+        cleaned = cleaned[:-1]
+        try:
+            return float(cleaned) / 100
+        except ValueError:
+            pass
+
+    # Handle fractions (a/b)
+    if "/" in cleaned:
+        parts = cleaned.split("/")
+        if len(parts) == 2:
+            try:
+                num = float(parts[0])
+                denom = float(parts[1])
+                if denom != 0:
+                    return num / denom
+            except ValueError:
+                pass
+
+    # Handle scientific notation and regular numbers
+    try:
+        return float(cleaned)
+    except ValueError:
+        pass
+
+    return None
+
+
+def normalize_text_answer(answer: str) -> str:
+    """
+    Normalize a text answer for comparison.
+
+    - Lowercase
+    - Remove extra whitespace
+    - Remove common formatting
+    """
+    if not answer:
+        return ""
+
+    normalized = answer.strip().lower()
+
+    # Remove LaTeX formatting
+    normalized = re.sub(r"\\[a-zA-Z]+", "", normalized)
+    normalized = re.sub(r"[{}$]", "", normalized)
+
+    # Normalize whitespace
+    normalized = " ".join(normalized.split())
+
+    # Remove common punctuation
+    normalized = re.sub(r"[.,;:!?]+$", "", normalized).strip()
+
+    return normalized
+
+
+# ============================================================================
+# Answer Comparison
+# ============================================================================
+
+def compare_numeric_answers(
+    predicted: str,
+    ground_truth: str,
+    tolerance: float = 1e-6
+) -> bool:
+    """
+    Compare two answers numerically.
+
+    Returns True if both can be parsed as numbers and are within tolerance.
+    """
+    pred_num = normalize_numeric_answer(predicted)
+    gt_num = normalize_numeric_answer(ground_truth)
+
+    if pred_num is None or gt_num is None:
+        return False
+
+    # Absolute tolerance for small numbers
+    if abs(gt_num) < 1e-6:
+        return abs(pred_num - gt_num) < tolerance
+
+    # Relative tolerance for larger numbers
+    rel_diff = abs(pred_num - gt_num) / abs(gt_num)
+    return rel_diff < tolerance
+
+
+def compare_text_answers(
+    predicted: str,
+    ground_truth: str
+) -> bool:
+    """Compare two text answers after normalization."""
+    pred_norm = normalize_text_answer(predicted)
+    gt_norm = normalize_text_answer(ground_truth)
+
+    return pred_norm == gt_norm
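+
+
+# Worked examples for the comparison helpers, again hand-traced through the
+# code above rather than captured from a run. Note that the default
+# tolerance of 1e-6 acts as a *relative* tolerance away from zero, so two
+# values must agree to roughly six significant figures to compare equal:
+#
+#   >>> compare_numeric_answers("0.5", "1/2")
+#   True
+#   >>> compare_numeric_answers("3.14159", "3.14158")  # rel. diff ~3e-6
+#   False
+#   >>> compare_text_answers("$x + 1$", "x + 1")
+#   True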
+
+
+def verify_answer(
+    response: str,
+    ground_truth: str,
+    task_type: str = "math"
+) -> Tuple[bool, Optional[str]]:
+    """
+    Verify if the response contains the correct answer.
+
+    Args:
+        response: Model's full response
+        ground_truth: Expected answer
+        task_type: Type of task ("math", "qa", "code")
+
+    Returns:
+        Tuple of (is_correct, extracted_answer)
+    """
+    # Extract predicted answer
+    predicted = extract_final_answer(response)
+
+    if predicted is None:
+        return False, None
+
+    # Try numeric comparison first
+    if compare_numeric_answers(predicted, ground_truth):
+        return True, predicted
+
+    # Try text comparison
+    if compare_text_answers(predicted, ground_truth):
+        return True, predicted
+
+    # Check if ground truth is contained in predicted
+    gt_norm = normalize_text_answer(ground_truth)
+    pred_norm = normalize_text_answer(predicted)
+
+    if gt_norm and gt_norm in pred_norm:
+        return True, predicted
+
+    return False, predicted
+
+
+# ============================================================================
+# Batch Evaluation
+# ============================================================================
+
+def evaluate_batch(
+    responses: List[str],
+    ground_truths: List[str],
+    task_type: str = "math"
+) -> Dict[str, Any]:
+    """
+    Evaluate a batch of responses.
+
+    Args:
+        responses: List of model responses
+        ground_truths: List of expected answers
+        task_type: Type of task
+
+    Returns:
+        Dictionary with evaluation metrics
+    """
+    assert len(responses) == len(ground_truths), \
+        "Number of responses must match ground truths"
+
+    correct = 0
+    total = len(responses)
+    results = []
+
+    for response, gt in zip(responses, ground_truths):
+        is_correct, extracted = verify_answer(response, gt, task_type)
+        correct += int(is_correct)
+        results.append({
+            "is_correct": is_correct,
+            "extracted_answer": extracted,
+            "ground_truth": gt
+        })
+
+    accuracy = correct / total if total > 0 else 0.0
+
+    return {
+        "accuracy": accuracy,
+        "correct": correct,
+        "total": total,
+        "results": results
+    }
+
+
+# ============================================================================
+# Answer Format Detection
+# ============================================================================
+
+def detect_answer_format(text: str) -> str:
+    """
+    Detect the format of an answer.
+
+    Returns one of: "boxed", "numeric", "fraction", "text", "unknown"
+    """
+    if "\\boxed{" in text:
+        return "boxed"
+
+    # Check for fraction
+    if re.match(r"^-?\d+/\d+$", text.strip()):
+        return "fraction"
+
+    # Check for numeric
+    try:
+        float(text.strip().replace(",", ""))
+        return "numeric"
+    except ValueError:
+        pass
+
+    if text.strip():
+        return "text"
+
+    return "unknown"
+
+
+def format_answer_for_display(answer: str, detected_format: str) -> str:
+    """Format answer for display based on detected format."""
+    if detected_format == "fraction":
+        num = normalize_numeric_answer(answer)
+        if num is not None:
+            return f"{answer} ≈ {num:.6f}"
+
+    if detected_format == "numeric":
+        try:
+            num = float(answer.replace(",", ""))
+            return f"{num:g}"
+        except ValueError:
+            pass
+
+    return answer
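+
+
+if __name__ == "__main__":
+    # Minimal smoke test of the full pipeline above. This is an illustrative
+    # sketch: the demo strings and the expected 1/2 accuracy are invented
+    # for demonstration, not project data.
+    demo_responses = [
+        "We have 2 + 2 = 4, so the final answer is \\boxed{4}.",
+        "Therefore, the answer is 9.",
+    ]
+    demo_truths = ["4", "10"]
+
+    report = evaluate_batch(demo_responses, demo_truths)
+    print(f"accuracy: {report['accuracy']:.2f} "
+          f"({report['correct']}/{report['total']})")
+    for item in report["results"]:
+        print(item["extracted_answer"], "vs", item["ground_truth"],
+              "->", item["is_correct"])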