summaryrefslogtreecommitdiff
path: root/tests/test_metrics.py
blob: 4a18bec6a3a2625495cf91e606095daea72a67ba (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Unit tests for evaluation metrics."""

import math

from hag.metrics import exact_match, f1_score, retrieval_recall_at_k


class TestMetrics:
    """Tests for EM, F1, and retrieval recall metrics.

    Float results that are not exactly representable (e.g. 2/3) are compared
    with ``math.isclose`` rather than ``==`` so the tests do not depend on the
    exact order of floating-point operations inside the metric code.
    """

    def test_exact_match_basic(self) -> None:
        """Basic exact match: case insensitive."""
        assert exact_match("Paris", "paris") == 1.0
        assert exact_match("Paris", "London") == 0.0

    def test_exact_match_normalization(self) -> None:
        """Should strip whitespace, lowercase, remove articles."""
        assert exact_match("  The Paris  ", "paris") == 1.0
        assert exact_match("A dog", "dog") == 1.0

    def test_exact_match_empty(self) -> None:
        """Empty strings should match."""
        assert exact_match("", "") == 1.0

    def test_f1_score_perfect(self) -> None:
        """Identical strings should have F1 = 1.0."""
        assert f1_score("the cat sat", "the cat sat") == 1.0

    def test_f1_score_partial(self) -> None:
        """Partial overlap should give 0.5 < F1 < 1.

        Every gold token also occurs in the longer prediction, so recall is
        perfect while precision is not; F1 must therefore fall strictly
        between 0.5 and 1.
        """
        score = f1_score("the cat sat on the mat", "the cat sat")
        assert 0.5 < score < 1.0

    def test_f1_score_no_overlap(self) -> None:
        """No common tokens should give F1 = 0."""
        assert f1_score("hello world", "foo bar") == 0.0

    def test_f1_score_empty(self) -> None:
        """Two empty strings should have F1 = 1.0."""
        assert f1_score("", "") == 1.0

    def test_retrieval_recall(self) -> None:
        """Standard retrieval recall: 2 of the 3 gold passages are retrieved."""
        recall = retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3)
        # 2/3 is not exactly representable; avoid exact float equality.
        assert math.isclose(recall, 2 / 3)

    def test_retrieval_recall_respects_k(self) -> None:
        """Only the top-k retrieved passages should count toward recall."""
        retrieved = [4, 1, 2]
        gold = [1, 2]
        # The top-1 passage (4) is not gold.
        assert retrieval_recall_at_k(retrieved, gold, k=1) == 0.0
        # The top-2 passages (4, 1) contain one of the two gold passages.
        assert math.isclose(retrieval_recall_at_k(retrieved, gold, k=2), 0.5)
        # All three passages cover both gold passages.
        assert retrieval_recall_at_k(retrieved, gold, k=3) == 1.0

    def test_retrieval_recall_perfect(self) -> None:
        """All gold passages retrieved."""
        assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0

    def test_retrieval_recall_none(self) -> None:
        """No gold passages retrieved."""
        assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0

    def test_retrieval_recall_empty_gold(self) -> None:
        """No gold passages means perfect recall by convention."""
        assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0