Diffstat (limited to 'hag/metrics.py')
-rw-r--r--  hag/metrics.py  113
1 file changed, 113 insertions, 0 deletions
diff --git a/hag/metrics.py b/hag/metrics.py
new file mode 100644
index 0000000..6a196df
--- /dev/null
+++ b/hag/metrics.py
@@ -0,0 +1,113 @@
+"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""
+
+import logging
+import re
+import string
+from collections import Counter
+from typing import Dict, List
+
+from hag.datatypes import PipelineResult
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_answer(text: str) -> str:
+ """Normalize answer text: lowercase, strip, remove articles and punctuation."""
+ text = text.lower().strip()
+ # Remove articles
+ text = re.sub(r"\b(a|an|the)\b", " ", text)
+ # Remove punctuation
+ text = text.translate(str.maketrans("", "", string.punctuation))
+ # Collapse whitespace
+ text = " ".join(text.split())
+ return text
+
+
+def exact_match(prediction: str, ground_truth: str) -> float:
+ """Normalized exact match.
+
+ Args:
+ prediction: predicted answer string
+ ground_truth: gold answer string
+
+ Returns:
+ 1.0 if normalized strings match, 0.0 otherwise.
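+
+    Example (illustrative; the article and trailing punctuation are stripped
+    by normalization before comparison):
+        >>> exact_match("The cat.", "cat")
+        1.0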
+ """
+ return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))
+
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+ """Token-level F1 between prediction and ground truth.
+
+ Args:
+ prediction: predicted answer string
+ ground_truth: gold answer string
+
+ Returns:
+ F1 score between 0.0 and 1.0.
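+
+    Example (illustrative; articles are dropped by normalization, and the
+    result is rounded here only to keep the example output stable):
+        >>> round(f1_score("the quick brown fox", "a quick fox"), 2)
+        0.8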
+ """
+ pred_tokens = _normalize_answer(prediction).split()
+ gold_tokens = _normalize_answer(ground_truth).split()
+
+ if not pred_tokens and not gold_tokens:
+ return 1.0
+ if not pred_tokens or not gold_tokens:
+ return 0.0
+
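+    # Multiset (bag-of-words) intersection: each shared token counts at most
+    # min(count in prediction, count in gold) times.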
+ common = Counter(pred_tokens) & Counter(gold_tokens)
+ num_same = sum(common.values())
+
+ if num_same == 0:
+ return 0.0
+
+ precision = num_same / len(pred_tokens)
+ recall = num_same / len(gold_tokens)
+ f1 = 2 * precision * recall / (precision + recall)
+ return f1
+
+
+def retrieval_recall_at_k(
+ retrieved_indices: List[int], gold_indices: List[int], k: int
+) -> float:
+    """Fraction of gold passages that appear among the top-k retrieved passages.
+
+    Args:
+        retrieved_indices: ranked list of retrieved passage indices
+        gold_indices: list of gold/relevant passage indices
+        k: number of top-ranked retrieved passages to consider
+
+ Returns:
+ Recall score between 0.0 and 1.0.
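+
+    Example (illustrative; one of the two gold passages appears in the top-3):
+        >>> retrieval_recall_at_k([3, 7, 1], [7, 9], k=3)
+        0.5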
+ """
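+    # No gold passages to find: treat recall as vacuously perfect.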
+ if not gold_indices:
+ return 1.0
+ retrieved_set = set(retrieved_indices[:k])
+ gold_set = set(gold_indices)
+ return len(retrieved_set & gold_set) / len(gold_set)
+
+
+def evaluate_dataset(
+ results: List[PipelineResult], gold_answers: List[str]
+) -> Dict[str, float]:
+ """Compute aggregate metrics over a dataset.
+
+ Args:
+ results: list of PipelineResult from the pipeline
+ gold_answers: list of gold answer strings
+
+ Returns:
+ Dict with keys 'em', 'f1' containing averaged scores.
+ """
+    assert len(results) == len(gold_answers), "results and gold_answers must have the same length"
+
+ em_scores = []
+ f1_scores = []
+
+ for result, gold in zip(results, gold_answers):
+ em_scores.append(exact_match(result.answer, gold))
+ f1_scores.append(f1_score(result.answer, gold))
+
+ return {
+ "em": sum(em_scores) / len(em_scores) if em_scores else 0.0,
+ "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
+ }