| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/metrics.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/metrics.py')
| -rw-r--r-- | hag/metrics.py | 113 |
1 files changed, 113 insertions, 0 deletions
diff --git a/hag/metrics.py b/hag/metrics.py
new file mode 100644
index 0000000..6a196df
--- /dev/null
+++ b/hag/metrics.py
@@ -0,0 +1,113 @@
+"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""
+
+import logging
+import re
+import string
+from collections import Counter
+from typing import Dict, List
+
+from hag.datatypes import PipelineResult
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_answer(text: str) -> str:
+    """Normalize answer text: lowercase, strip, remove articles and punctuation."""
+    text = text.lower().strip()
+    # Remove articles
+    text = re.sub(r"\b(a|an|the)\b", " ", text)
+    # Remove punctuation
+    text = text.translate(str.maketrans("", "", string.punctuation))
+    # Collapse whitespace
+    text = " ".join(text.split())
+    return text
+
+
+def exact_match(prediction: str, ground_truth: str) -> float:
+    """Normalized exact match.
+
+    Args:
+        prediction: predicted answer string
+        ground_truth: gold answer string
+
+    Returns:
+        1.0 if normalized strings match, 0.0 otherwise.
+    """
+    return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))
+
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+    """Token-level F1 between prediction and ground truth.
+
+    Args:
+        prediction: predicted answer string
+        ground_truth: gold answer string
+
+    Returns:
+        F1 score between 0.0 and 1.0.
+    """
+    pred_tokens = _normalize_answer(prediction).split()
+    gold_tokens = _normalize_answer(ground_truth).split()
+
+    if not pred_tokens and not gold_tokens:
+        return 1.0
+    if not pred_tokens or not gold_tokens:
+        return 0.0
+
+    common = Counter(pred_tokens) & Counter(gold_tokens)
+    num_same = sum(common.values())
+
+    if num_same == 0:
+        return 0.0
+
+    precision = num_same / len(pred_tokens)
+    recall = num_same / len(gold_tokens)
+    f1 = 2 * precision * recall / (precision + recall)
+    return f1
+
+
+def retrieval_recall_at_k(
+    retrieved_indices: List[int], gold_indices: List[int], k: int
+) -> float:
+    """What fraction of gold passages appear in the retrieved top-k?
+
+    Args:
+        retrieved_indices: list of retrieved passage indices (top-k)
+        gold_indices: list of gold/relevant passage indices
+        k: number of retrieved passages to consider
+
+    Returns:
+        Recall score between 0.0 and 1.0.
+    """
+    if not gold_indices:
+        return 1.0
+    retrieved_set = set(retrieved_indices[:k])
+    gold_set = set(gold_indices)
+    return len(retrieved_set & gold_set) / len(gold_set)
+
+
+def evaluate_dataset(
+    results: List[PipelineResult], gold_answers: List[str]
+) -> Dict[str, float]:
+    """Compute aggregate metrics over a dataset.
+
+    Args:
+        results: list of PipelineResult from the pipeline
+        gold_answers: list of gold answer strings
+
+    Returns:
+        Dict with keys 'em', 'f1' containing averaged scores.
+    """
+    assert len(results) == len(gold_answers)
+
+    em_scores = []
+    f1_scores = []
+
+    for result, gold in zip(results, gold_answers):
+        em_scores.append(exact_match(result.answer, gold))
+        f1_scores.append(f1_score(result.answer, gold))
+
+    return {
+        "em": sum(em_scores) / len(em_scores) if em_scores else 0.0,
+        "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
+    }
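For reference, here is a minimal usage sketch of the functions added in this file. It is illustrative only: `_FakeResult` is a hypothetical stand-in for `hag.datatypes.PipelineResult` (the `evaluate_dataset` function only reads the `answer` attribute), and the example strings and indices are made up.

```python
# Minimal usage sketch for hag/metrics.py (illustrative, not part of the commit).
from dataclasses import dataclass

from hag.metrics import (
    evaluate_dataset,
    exact_match,
    f1_score,
    retrieval_recall_at_k,
)


@dataclass
class _FakeResult:
    """Hypothetical stand-in for PipelineResult; only `answer` is used here."""
    answer: str


# Per-example answer metrics.
print(exact_match("The Eiffel Tower", "eiffel tower"))            # 1.0 after normalization
print(round(f1_score("tower in Paris", "the Eiffel Tower"), 3))   # 0.4 (one shared token)

# Retrieval recall@k: one of the two gold passages appears in the top-3.
print(retrieval_recall_at_k([4, 7, 9], gold_indices=[7, 12], k=3))  # 0.5

# Aggregate over a tiny "dataset" of two examples.
results = [_FakeResult("eiffel tower"), _FakeResult("paris")]
gold = ["The Eiffel Tower", "London"]
print(evaluate_dataset(results, gold))  # {'em': 0.5, 'f1': 0.5}
```

The answer normalization (lowercasing, stripping articles and punctuation, collapsing whitespace) mirrors the usual SQuAD-style EM/F1 convention, so the reported scores should be comparable to that family of QA benchmarks.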
