"""Unit tests for the memory bank module.""" import torch from hag.config import MemoryBankConfig from hag.memory_bank import MemoryBank class TestMemoryBank: """Tests for MemoryBank construction, lookup, and persistence.""" def test_build_and_size(self) -> None: """Build memory bank and verify size.""" config = MemoryBankConfig(embedding_dim=64, normalize=True) mb = MemoryBank(config) embeddings = torch.randn(100, 64) passages = [f"passage {i}" for i in range(100)] mb.build_from_embeddings(embeddings, passages) assert mb.size == 100 assert mb.dim == 64 def test_normalization(self) -> None: """If normalize=True, stored embeddings should have unit norm.""" config = MemoryBankConfig(embedding_dim=64, normalize=True) mb = MemoryBank(config) embeddings = torch.randn(50, 64) * 5 # non-unit norm mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) norms = torch.norm(mb.embeddings, dim=0) assert torch.allclose(norms, torch.ones(50), atol=1e-5) def test_no_normalization(self) -> None: """If normalize=False, stored embeddings keep original norms.""" config = MemoryBankConfig(embedding_dim=64, normalize=False) mb = MemoryBank(config) embeddings = torch.randn(50, 64) * 5 original_norms = torch.norm(embeddings, dim=-1) mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) stored_norms = torch.norm(mb.embeddings, dim=0) assert torch.allclose(stored_norms, original_norms, atol=1e-5) def test_get_passages_by_indices(self) -> None: """Index -> passage text lookup.""" config = MemoryBankConfig(embedding_dim=64, normalize=False) mb = MemoryBank(config) passages = [f"passage {i}" for i in range(100)] mb.build_from_embeddings(torch.randn(100, 64), passages) result = mb.get_passages_by_indices(torch.tensor([0, 50, 99])) assert result == ["passage 0", "passage 50", "passage 99"] def test_save_and_load(self, tmp_path) -> None: # type: ignore[no-untyped-def] """Save and reload memory bank, verify contents match.""" config = MemoryBankConfig(embedding_dim=32, normalize=True) mb = MemoryBank(config) embeddings = torch.randn(20, 32) passages = [f"text {i}" for i in range(20)] mb.build_from_embeddings(embeddings, passages) save_path = str(tmp_path / "mb.pt") mb.save(save_path) mb2 = MemoryBank(config) mb2.load(save_path) assert mb2.size == 20 assert mb2.passages == passages assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)