"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""

import logging
import re
import string
from collections import Counter
from typing import Dict, List

from hag.datatypes import PipelineResult

logger = logging.getLogger(__name__)
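

# Note: the helper below mirrors the answer normalization commonly used for
# SQuAD-style EM/F1 evaluation: lowercase, drop articles and punctuation,
# and collapse whitespace, so answers are compared on content words only.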
def _normalize_answer(text: str) -> str:
    """Normalize answer text: lowercase, strip, remove articles and punctuation."""
    text = text.lower().strip()
    # Remove articles
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    # Remove punctuation
    text = text.translate(str.maketrans("", "", string.punctuation))
    # Collapse whitespace
    text = " ".join(text.split())
    return text


def exact_match(prediction: str, ground_truth: str) -> float:
    """Normalized exact match.

    Args:
        prediction: predicted answer string
        ground_truth: gold answer string

    Returns:
        1.0 if normalized strings match, 0.0 otherwise.
    """
    return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))


def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between prediction and ground truth.

    Args:
        prediction: predicted answer string
        ground_truth: gold answer string

    Returns:
        F1 score between 0.0 and 1.0.
    """
    pred_tokens = _normalize_answer(prediction).split()
    gold_tokens = _normalize_answer(ground_truth).split()
    if not pred_tokens and not gold_tokens:
        return 1.0
    if not pred_tokens or not gold_tokens:
        return 0.0
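    # Counter intersection is a multiset minimum: each shared token counts at
    # most as many times as it appears in both answers.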
    common = Counter(pred_tokens) & Counter(gold_tokens)
    num_same = sum(common.values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    f1 = 2 * precision * recall / (precision + recall)
    return f1


def retrieval_recall_at_k(
    retrieved_indices: List[int], gold_indices: List[int], k: int
) -> float:
    """Compute the fraction of gold passages that appear in the retrieved top-k.

    Args:
        retrieved_indices: list of retrieved passage indices (top-k)
        gold_indices: list of gold/relevant passage indices
        k: number of retrieved passages to consider

    Returns:
        Recall score between 0.0 and 1.0.
    """
    # By convention, an empty gold set counts as fully recalled.
    if not gold_indices:
        return 1.0
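    # Only the first k retrieved indices are considered; converting to sets
    # means duplicate indices are counted once.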
    retrieved_set = set(retrieved_indices[:k])
    gold_set = set(gold_indices)
    return len(retrieved_set & gold_set) / len(gold_set)


def evaluate_dataset(
    results: List[PipelineResult], gold_answers: List[str]
) -> Dict[str, float]:
    """Compute aggregate metrics over a dataset.

    Args:
        results: list of PipelineResult objects produced by the pipeline
        gold_answers: list of gold answer strings

    Returns:
        Dict with keys 'em' and 'f1' containing averaged scores.
    """
    if len(results) != len(gold_answers):
        raise ValueError("results and gold_answers must have the same length")
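    # Retrieval recall is not aggregated here: gold passage indices are not part
    # of this function's inputs, so call retrieval_recall_at_k per example instead.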
    em_scores = []
    f1_scores = []
    for result, gold in zip(results, gold_answers):
        em_scores.append(exact_match(result.answer, gold))
        f1_scores.append(f1_score(result.answer, gold))
    return {
        "em": sum(em_scores) / len(em_scores) if em_scores else 0.0,
        "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
    }
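

if __name__ == "__main__":
    # Minimal smoke test of the string and retrieval metrics; illustrative only.
    # evaluate_dataset is not exercised here because constructing PipelineResult
    # objects depends on the rest of the pipeline.
    print(exact_match("The Eiffel Tower", "eiffel tower"))   # 1.0
    print(round(f1_score("Paris, France", "Paris"), 3))      # 0.667
    print(retrieval_recall_at_k([3, 7, 1], [1, 2], k=3))     # 0.5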