summaryrefslogtreecommitdiff
path: root/hag/metrics.py
blob: 6a196df49bd167730fd2bdc69cab08a8c46c04be (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""

import logging
import re
import string
from collections import Counter
from typing import Dict, List

from hag.datatypes import PipelineResult

logger = logging.getLogger(__name__)


def _normalize_answer(text: str) -> str:
    """Normalize answer text: lowercase, strip, remove articles and punctuation."""
    text = text.lower().strip()
    # Remove articles
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    # Remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Collapse whitespace
    text = " ".join(text.split())
    return text


def exact_match(prediction: str, ground_truth: str) -> float:
    """Score a prediction by equality after answer normalization.

    Both strings are passed through ``_normalize_answer`` first, so the
    differences that function removes (case, punctuation, articles,
    extra whitespace) do not affect the score.

    Args:
        prediction: predicted answer string
        ground_truth: gold answer string

    Returns:
        1.0 when the normalized forms are identical, otherwise 0.0.
    """
    normalized_pred = _normalize_answer(prediction)
    normalized_gold = _normalize_answer(ground_truth)
    if normalized_pred == normalized_gold:
        return 1.0
    return 0.0


def f1_score(prediction: str, ground_truth: str) -> float:
    """Compute bag-of-tokens F1 overlap between a prediction and a gold answer.

    Both strings are normalized with ``_normalize_answer`` and split on
    whitespace. Shared tokens are counted with multiplicity: a token that
    occurs twice in both answers contributes two matches.

    Args:
        prediction: predicted answer string
        ground_truth: gold answer string

    Returns:
        F1 in [0.0, 1.0]. Two empty answers score 1.0. Exactly one empty
        answer, or no shared tokens, scores 0.0.
    """
    predicted = _normalize_answer(prediction).split()
    reference = _normalize_answer(ground_truth).split()

    # Empty-answer conventions: both empty is a perfect match, one empty a miss.
    if not (predicted or reference):
        return 1.0
    if not (predicted and reference):
        return 0.0

    overlap = sum((Counter(predicted) & Counter(reference)).values())
    if overlap == 0:
        return 0.0

    precision = overlap / len(predicted)
    recall = overlap / len(reference)
    return 2 * precision * recall / (precision + recall)


def retrieval_recall_at_k(
    retrieved_indices: List[int], gold_indices: List[int], k: int
) -> float:
    """What fraction of gold passages appear in the retrieved top-k?

    Both lists are compared as sets, so duplicate indices in either one
    are counted once.

    Args:
        retrieved_indices: retrieved passage indices in rank order; only
            the first ``k`` entries are considered
        gold_indices: list of gold/relevant passage indices
        k: number of retrieved passages to consider; must be non-negative.
            ``k == 0`` considers no passages.

    Returns:
        Recall score between 0.0 and 1.0. Returns 1.0 when ``gold_indices``
        is empty (there is nothing to recall).

    Raises:
        ValueError: if ``k`` is negative.
    """
    # A negative k would slice from the end of the ranking (k=-1 keeps all
    # but the last hit) rather than take a top-k prefix, so reject it.
    if k < 0:
        raise ValueError(f"k must be non-negative, got {k}")
    if not gold_indices:
        return 1.0
    gold_set = set(gold_indices)
    top_k = set(retrieved_indices[:k])
    return len(top_k & gold_set) / len(gold_set)


def evaluate_dataset(
    results: List[PipelineResult], gold_answers: List[str]
) -> Dict[str, float]:
    """Compute aggregate metrics over a dataset.

    Each ``results[i].answer`` is scored against ``gold_answers[i]`` with
    ``exact_match`` and ``f1_score``, and the per-example scores are averaged.

    Args:
        results: list of PipelineResult from the pipeline; only the
            ``answer`` attribute of each item is read
        gold_answers: list of gold answer strings, aligned index-by-index
            with ``results``

    Returns:
        Dict with keys 'em', 'f1' containing averaged scores. Both are 0.0
        for an empty dataset.

    Raises:
        ValueError: if ``results`` and ``gold_answers`` differ in length.
    """
    # Explicit check instead of ``assert``: asserts are stripped under
    # ``python -O``, after which zip() would silently truncate to the shorter
    # list and report metrics over only part of the dataset.
    if len(results) != len(gold_answers):
        raise ValueError(
            "results and gold_answers must have the same length, "
            f"got {len(results)} and {len(gold_answers)}"
        )

    if not results:
        return {"em": 0.0, "f1": 0.0}

    em_total = 0.0
    f1_total = 0.0
    for result, gold in zip(results, gold_answers):
        em_total += exact_match(result.answer, gold)
        f1_total += f1_score(result.answer, gold)

    count = len(results)
    return {"em": em_total / count, "f1": f1_total / count}