# Provenance (recovered from the patch header this file was extracted from):
#   Commit:  c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3
#   Author:  YurenHao0426
#   Date:    Sun, 15 Feb 2026 18:19:50 +0000
#   Subject: Initial implementation of HAG (Hopfield-Augmented Generation)
#
#   Core Hopfield retrieval module with energy-based convergence guarantees,
#   memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
#   pipeline. All 45 tests passing on CPU with synthetic data.
#
#   Co-Authored-By: Claude Opus 4.6
#
#   File: tests/test_memory_bank.py (new file, 65 lines)

"""Unit tests for the memory bank module."""

import torch

from hag.config import MemoryBankConfig
from hag.memory_bank import MemoryBank


class TestMemoryBank:
    """Exercise MemoryBank building, norm handling, text lookup and save/load."""

    def test_build_and_size(self) -> None:
        """A bank built from 100 vectors of width 64 reports size 100, dim 64."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=True))
        vectors = torch.randn(100, 64)
        texts = [f"passage {i}" for i in range(100)]

        bank.build_from_embeddings(vectors, texts)

        assert bank.size == 100
        assert bank.dim == 64

    def test_normalization(self) -> None:
        """With normalize=True every stored vector must come out at unit norm."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=True))
        # Scale the inputs so they are clearly not unit-length beforehand.
        scaled = torch.randn(50, 64) * 5
        bank.build_from_embeddings(scaled, [f"p{i}" for i in range(50)])

        # Reducing over dim 0 is expected to give one norm per passage (50).
        # NOTE(review): presumably the bank stores (dim, num_passages) -- confirm
        # against MemoryBank.
        stored_norms = torch.norm(bank.embeddings, dim=0)

        assert torch.allclose(stored_norms, torch.ones(50), atol=1e-5)

    def test_no_normalization(self) -> None:
        """With normalize=False the stored vectors keep their input norms."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=False))
        scaled = torch.randn(50, 64) * 5
        # Capture the per-row norms of the input before handing it to the bank.
        input_norms = torch.norm(scaled, dim=-1)
        bank.build_from_embeddings(scaled, [f"p{i}" for i in range(50)])

        kept_norms = torch.norm(bank.embeddings, dim=0)

        assert torch.allclose(kept_norms, input_norms, atol=1e-5)

    def test_get_passages_by_indices(self) -> None:
        """An index tensor is mapped to the passage texts at those positions."""
        bank = MemoryBank(MemoryBankConfig(embedding_dim=64, normalize=False))
        texts = [f"passage {i}" for i in range(100)]
        bank.build_from_embeddings(torch.randn(100, 64), texts)

        wanted = torch.tensor([0, 50, 99])
        found = bank.get_passages_by_indices(wanted)

        assert found == ["passage 0", "passage 50", "passage 99"]

    def test_save_and_load(self, tmp_path) -> None:  # type: ignore[no-untyped-def]
        """A bank written to disk and read back has the same size and contents."""
        cfg = MemoryBankConfig(embedding_dim=32, normalize=True)
        source_bank = MemoryBank(cfg)
        vectors = torch.randn(20, 32)
        texts = [f"text {i}" for i in range(20)]
        source_bank.build_from_embeddings(vectors, texts)

        # tmp_path is pytest's per-test temporary directory fixture.
        checkpoint = str(tmp_path / "mb.pt")
        source_bank.save(checkpoint)

        # A second bank created from the same config is filled purely via load().
        restored_bank = MemoryBank(cfg)
        restored_bank.load(checkpoint)

        assert restored_bank.size == 20
        assert restored_bank.passages == texts
        assert torch.allclose(
            source_bank.embeddings, restored_bank.embeddings, atol=1e-6
        )