summary | refs | log | tree | commit | diff
path: root/tests/test_memory_bank.py
diff options
context:
space:
mode:
Diffstat (limited to 'tests/test_memory_bank.py')
-rw-r--r-- tests/test_memory_bank.py 65
1 files changed, 65 insertions, 0 deletions
diff --git a/tests/test_memory_bank.py b/tests/test_memory_bank.py
new file mode 100644
index 0000000..0087bbd
--- /dev/null
+++ b/tests/test_memory_bank.py
@@ -0,0 +1,65 @@
+"""Unit tests for the memory bank module."""
+
+import torch
+
+from hag.config import MemoryBankConfig
+from hag.memory_bank import MemoryBank
+
+
+class TestMemoryBank:
+ """Tests for MemoryBank construction, lookup, and persistence."""
+
+ def test_build_and_size(self) -> None:
+ """Build memory bank and verify size."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(100, 64)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(embeddings, passages)
+ assert mb.size == 100
+ assert mb.dim == 64
+
+ def test_normalization(self) -> None:
+ """If normalize=True, stored embeddings should have unit norm."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(50, 64) * 5 # non-unit norm
+ mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
+ norms = torch.norm(mb.embeddings, dim=0)
+ assert torch.allclose(norms, torch.ones(50), atol=1e-5)
+
+ def test_no_normalization(self) -> None:
+ """If normalize=False, stored embeddings keep original norms."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=False)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(50, 64) * 5
+ original_norms = torch.norm(embeddings, dim=-1)
+ mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
+ stored_norms = torch.norm(mb.embeddings, dim=0)
+ assert torch.allclose(stored_norms, original_norms, atol=1e-5)
+
+ def test_get_passages_by_indices(self) -> None:
+ """Index -> passage text lookup."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=False)
+ mb = MemoryBank(config)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(torch.randn(100, 64), passages)
+ result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
+ assert result == ["passage 0", "passage 50", "passage 99"]
+
+ def test_save_and_load(self, tmp_path) -> None: # type: ignore[no-untyped-def]
+ """Save and reload memory bank, verify contents match."""
+ config = MemoryBankConfig(embedding_dim=32, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(20, 32)
+ passages = [f"text {i}" for i in range(20)]
+ mb.build_from_embeddings(embeddings, passages)
+
+ save_path = str(tmp_path / "mb.pt")
+ mb.save(save_path)
+
+ mb2 = MemoryBank(config)
+ mb2.load(save_path)
+ assert mb2.size == 20
+ assert mb2.passages == passages
+ assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)