summaryrefslogtreecommitdiff
path: root/tests/test_memory_bank.py
blob: 0087bbd866cb4b4dbd456a90592f95835ad1c270 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Unit tests for the memory bank module."""

from pathlib import Path

import torch

from hag.config import MemoryBankConfig
from hag.memory_bank import MemoryBank


class TestMemoryBank:
    """Tests for MemoryBank construction, lookup, and persistence.

    NOTE(review): the norm checks below reduce ``mb.embeddings`` over
    ``dim=0`` and expect one value per passage, which implies the bank
    stores embeddings transposed, presumably as
    ``(embedding_dim, num_passages)`` even though ``build_from_embeddings``
    is fed ``(num_passages, embedding_dim)``. Confirm against
    ``MemoryBank.build_from_embeddings`` if that layout ever changes.
    """

    def test_build_and_size(self) -> None:
        """Building from 100 x 64 embeddings yields size 100 and dim 64."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(100, 64)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(embeddings, passages)
        assert mb.size == 100
        assert mb.dim == 64

    def test_normalization(self) -> None:
        """If normalize=True, stored embeddings should have unit norm."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(50, 64) * 5  # scale so inputs are clearly non-unit
        mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
        # Reduce over dim=0 to get one norm per stored passage (50 values);
        # this relies on the transposed storage layout noted on the class.
        norms = torch.norm(mb.embeddings, dim=0)
        assert torch.allclose(norms, torch.ones(50), atol=1e-5)

    def test_no_normalization(self) -> None:
        """If normalize=False, stored embeddings keep original norms."""
        config = MemoryBankConfig(embedding_dim=64, normalize=False)
        mb = MemoryBank(config)
        embeddings = torch.randn(50, 64) * 5
        # Input rows are passages, so per-passage norms come from dim=-1 here...
        original_norms = torch.norm(embeddings, dim=-1)
        mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
        # ...while the stored matrix holds passages along dim=1, hence dim=0.
        stored_norms = torch.norm(mb.embeddings, dim=0)
        assert torch.allclose(stored_norms, original_norms, atol=1e-5)

    def test_get_passages_by_indices(self) -> None:
        """A tensor of indices maps back to the passage texts, in order."""
        config = MemoryBankConfig(embedding_dim=64, normalize=False)
        mb = MemoryBank(config)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(torch.randn(100, 64), passages)
        result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
        assert result == ["passage 0", "passage 50", "passage 99"]

    def test_save_and_load(self, tmp_path: Path) -> None:
        """Save and reload memory bank, verify contents match.

        Args:
            tmp_path: pytest's per-test temporary directory fixture.
        """
        config = MemoryBankConfig(embedding_dim=32, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(20, 32)
        passages = [f"text {i}" for i in range(20)]
        mb.build_from_embeddings(embeddings, passages)

        # MemoryBank.save/load are called with a str path, so keep the str().
        save_path = str(tmp_path / "mb.pt")
        mb.save(save_path)

        mb2 = MemoryBank(config)
        mb2.load(save_path)
        assert mb2.size == 20
        assert mb2.passages == passages
        assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)