1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
|
"""Unit tests for the memory bank module."""
import torch
from hag.config import MemoryBankConfig
from hag.memory_bank import MemoryBank
class TestMemoryBank:
    """Tests for MemoryBank construction, lookup, and persistence.

    NOTE(review): the norm checks in this class reduce over ``dim=0`` of
    ``mb.embeddings`` and compare against one value per passage. That
    implies MemoryBank stores embeddings column-wise as (dim, size) --
    confirm against ``hag.memory_bank.MemoryBank``.
    """

    def test_build_and_size(self) -> None:
        """Build a memory bank and verify its reported size and dimension."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(100, 64)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(embeddings, passages)
        assert mb.size == 100
        assert mb.dim == 64

    def test_normalization(self) -> None:
        """If normalize=True, every stored embedding should have unit norm."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(50, 64) * 5  # deliberately non-unit norm
        mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
        norms = torch.norm(mb.embeddings, dim=0)
        # torch.allclose broadcasts, so pin the shape first: exactly one
        # norm per stored passage (a scalar 1.0 would otherwise pass).
        assert norms.shape == (50,)
        assert torch.allclose(norms, torch.ones(50), atol=1e-5)

    def test_no_normalization(self) -> None:
        """If normalize=False, stored embeddings keep their original norms."""
        config = MemoryBankConfig(embedding_dim=64, normalize=False)
        mb = MemoryBank(config)
        embeddings = torch.randn(50, 64) * 5
        # Input is (passages, dim), so per-passage norms reduce the last axis.
        original_norms = torch.norm(embeddings, dim=-1)
        mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
        stored_norms = torch.norm(mb.embeddings, dim=0)
        assert stored_norms.shape == original_norms.shape
        assert torch.allclose(stored_norms, original_norms, atol=1e-5)

    def test_get_passages_by_indices(self) -> None:
        """Map a tensor of indices back to the corresponding passage texts."""
        config = MemoryBankConfig(embedding_dim=64, normalize=False)
        mb = MemoryBank(config)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(torch.randn(100, 64), passages)
        result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
        assert result == ["passage 0", "passage 50", "passage 99"]

    def test_save_and_load(self, tmp_path) -> None:  # type: ignore[no-untyped-def]
        """Save and reload a memory bank; size, dim, texts and vectors match."""
        config = MemoryBankConfig(embedding_dim=32, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(20, 32)
        passages = [f"text {i}" for i in range(20)]
        mb.build_from_embeddings(embeddings, passages)
        save_path = str(tmp_path / "mb.pt")
        mb.save(save_path)

        mb2 = MemoryBank(config)
        mb2.load(save_path)
        assert mb2.size == 20
        assert mb2.dim == 32
        assert mb2.passages == passages
        # allclose broadcasts; require identical shapes so a truncated or
        # transposed reload cannot slip through.
        assert mb2.embeddings.shape == mb.embeddings.shape
        assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)
|