diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /tests/test_metrics.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'tests/test_metrics.py')
| -rw-r--r-- | tests/test_metrics.py | 54 |
1 file changed, 54 insertions, 0 deletions
diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..4a18bec --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,54 @@ +"""Unit tests for evaluation metrics.""" + +from hag.metrics import exact_match, f1_score, retrieval_recall_at_k + + +class TestMetrics: + """Tests for EM, F1, and retrieval recall metrics.""" + + def test_exact_match_basic(self) -> None: + """Basic exact match: case insensitive.""" + assert exact_match("Paris", "paris") == 1.0 + assert exact_match("Paris", "London") == 0.0 + + def test_exact_match_normalization(self) -> None: + """Should strip whitespace, lowercase, remove articles.""" + assert exact_match(" The Paris ", "paris") == 1.0 + assert exact_match("A dog", "dog") == 1.0 + + def test_exact_match_empty(self) -> None: + """Empty strings should match.""" + assert exact_match("", "") == 1.0 + + def test_f1_score_perfect(self) -> None: + """Identical strings should have F1 = 1.0.""" + assert f1_score("the cat sat", "the cat sat") == 1.0 + + def test_f1_score_partial(self) -> None: + """Partial overlap should give 0 < F1 < 1.""" + score = f1_score("the cat sat on the mat", "the cat sat") + assert 0.5 < score < 1.0 + + def test_f1_score_no_overlap(self) -> None: + """No common tokens should give F1 = 0.""" + assert f1_score("hello world", "foo bar") == 0.0 + + def test_f1_score_empty(self) -> None: + """Two empty strings should have F1 = 1.0.""" + assert f1_score("", "") == 1.0 + + def test_retrieval_recall(self) -> None: + """Standard retrieval recall computation.""" + assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2 / 3 + + def test_retrieval_recall_perfect(self) -> None: + """All gold passages retrieved.""" + assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0 + + def test_retrieval_recall_none(self) -> None: + """No gold passages retrieved.""" + assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0 + + def test_retrieval_recall_empty_gold(self) -> None: + """No gold 
passages means perfect recall by convention.""" + assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0 |
