"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""

from __future__ import annotations

import logging
import re
import string
from collections import Counter
from typing import TYPE_CHECKING, Dict, List

if TYPE_CHECKING:
    # Only needed for type annotations; guarding it avoids a runtime
    # dependency (and potential import cycle) on hag.datatypes.
    from hag.datatypes import PipelineResult

logger = logging.getLogger(__name__)

# Hoisted to module level: compiled/built once, used on every call.
_ARTICLES_RE = re.compile(r"\b(a|an|the)\b")
_PUNCT_TABLE = str.maketrans("", "", string.punctuation)


def _normalize_answer(text: str) -> str:
    """Normalize answer text: lowercase, remove punctuation and articles.

    Follows the canonical SQuAD normalization order: lowercase, strip
    punctuation, remove articles, collapse whitespace. Removing
    punctuation *before* articles matters: "a.b" must normalize to
    "ab", not "b" (stripping articles first would match the leading
    "a" as a standalone word).
    """
    text = text.lower().strip()
    text = text.translate(_PUNCT_TABLE)
    text = _ARTICLES_RE.sub(" ", text)
    # Collapse runs of whitespace left by the substitutions.
    return " ".join(text.split())


def exact_match(prediction: str, ground_truth: str) -> float:
    """Normalized exact match.

    Args:
        prediction: predicted answer string
        ground_truth: gold answer string

    Returns:
        1.0 if normalized strings match, 0.0 otherwise.
    """
    return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between prediction and ground truth.

    Args:
        prediction: predicted answer string
        ground_truth: gold answer string

    Returns:
        F1 score between 0.0 and 1.0.
    """
    pred_tokens = _normalize_answer(prediction).split()
    gold_tokens = _normalize_answer(ground_truth).split()
    # Both empty after normalization: trivially identical.
    if not pred_tokens and not gold_tokens:
        return 1.0
    # Exactly one side empty: no overlap possible.
    if not pred_tokens or not gold_tokens:
        return 0.0
    # Multiset intersection counts each shared token at most
    # min(pred count, gold count) times.
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


def retrieval_recall_at_k(
    retrieved_indices: List[int], gold_indices: List[int], k: int
) -> float:
    """What fraction of gold passages appear in the retrieved top-k?

    Args:
        retrieved_indices: list of retrieved passage indices (top-k)
        gold_indices: list of gold/relevant passage indices
        k: number of retrieved passages to consider

    Returns:
        Recall score between 0.0 and 1.0.
    """
    if not gold_indices:
        # No relevant passages: recall is vacuously perfect.
        return 1.0
    retrieved_set = set(retrieved_indices[:k])
    gold_set = set(gold_indices)
    return len(retrieved_set & gold_set) / len(gold_set)


def evaluate_dataset(
    results: List[PipelineResult], gold_answers: List[str]
) -> Dict[str, float]:
    """Compute aggregate metrics over a dataset.

    Args:
        results: list of PipelineResult from the pipeline
        gold_answers: list of gold answer strings

    Returns:
        Dict with keys 'em', 'f1' containing averaged scores
        (both 0.0 for an empty dataset).

    Raises:
        ValueError: if results and gold_answers differ in length.
    """
    # Explicit check rather than `assert`: asserts are stripped when
    # Python runs with -O, which would silently mis-pair items.
    if len(results) != len(gold_answers):
        raise ValueError(
            f"results ({len(results)}) and gold_answers "
            f"({len(gold_answers)}) must have the same length"
        )
    if not results:
        return {"em": 0.0, "f1": 0.0}
    em_total = 0.0
    f1_total = 0.0
    for result, gold in zip(results, gold_answers):
        em_total += exact_match(result.answer, gold)
        f1_total += f1_score(result.answer, gold)
    n = len(results)
    return {"em": em_total / n, "f1": f1_total / n}