summaryrefslogtreecommitdiff
path: root/tests/test_metrics.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_metrics.py')
-rw-r--r--tests/test_metrics.py54
1 files changed, 54 insertions, 0 deletions
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..4a18bec
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,54 @@
+"""Unit tests for evaluation metrics."""
+
+from hag.metrics import exact_match, f1_score, retrieval_recall_at_k
+
+
+class TestMetrics:
+ """Tests for EM, F1, and retrieval recall metrics."""
+
+ def test_exact_match_basic(self) -> None:
+ """Basic exact match: case insensitive."""
+ assert exact_match("Paris", "paris") == 1.0
+ assert exact_match("Paris", "London") == 0.0
+
+ def test_exact_match_normalization(self) -> None:
+ """Should strip whitespace, lowercase, remove articles."""
+ assert exact_match(" The Paris ", "paris") == 1.0
+ assert exact_match("A dog", "dog") == 1.0
+
+ def test_exact_match_empty(self) -> None:
+ """Empty strings should match."""
+ assert exact_match("", "") == 1.0
+
+ def test_f1_score_perfect(self) -> None:
+ """Identical strings should have F1 = 1.0."""
+ assert f1_score("the cat sat", "the cat sat") == 1.0
+
+ def test_f1_score_partial(self) -> None:
+ """Partial overlap should give 0 < F1 < 1."""
+ score = f1_score("the cat sat on the mat", "the cat sat")
+ assert 0.5 < score < 1.0
+
+ def test_f1_score_no_overlap(self) -> None:
+ """No common tokens should give F1 = 0."""
+ assert f1_score("hello world", "foo bar") == 0.0
+
+ def test_f1_score_empty(self) -> None:
+ """Two empty strings should have F1 = 1.0."""
+ assert f1_score("", "") == 1.0
+
+ def test_retrieval_recall(self) -> None:
+ """Standard retrieval recall computation."""
+ assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2 / 3
+
+ def test_retrieval_recall_perfect(self) -> None:
+ """All gold passages retrieved."""
+ assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0
+
+ def test_retrieval_recall_none(self) -> None:
+ """No gold passages retrieved."""
+ assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0
+
+ def test_retrieval_recall_empty_gold(self) -> None:
+ """No gold passages means perfect recall by convention."""
+ assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0