summaryrefslogtreecommitdiff
path: root/hag/memory_bank.py
blob: 42dcc73743e58063b1d104e7096e45066b3dd45f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
"""Memory bank construction and management for passage embeddings."""

import logging
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F

from hag.config import MemoryBankConfig

logger = logging.getLogger(__name__)


class MemoryBank:
    """Stores passage embeddings and provides lookup from indices back to text.

    The memory bank is M in R^{d x N} where each column is a passage embedding.
    Also maintains a mapping from column index to passage text for final retrieval.

    Attributes:
        config: the MemoryBankConfig this bank was created with; its
            ``normalize`` flag controls L2 normalization in
            ``build_from_embeddings``.
        embeddings: (d, N) tensor once built or loaded, otherwise None.
        passages: passage texts; ``passages[i]`` corresponds to column ``i``
            of ``embeddings``.
    """

    def __init__(self, config: MemoryBankConfig) -> None:
        self.config = config
        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
        self.passages: List[str] = []

    def build_from_embeddings(
        self, embeddings: torch.Tensor, passages: List[str]
    ) -> None:
        """Build memory bank from precomputed embeddings.

        Args:
            embeddings: (N, d) — passage embeddings (note: input is N x d)
            passages: list of N passage strings

        Raises:
            ValueError: if ``embeddings`` is not 2-D, or its number of rows
                does not match ``len(passages)``.
        """
        # Explicit checks instead of ``assert``: asserts are stripped under
        # ``python -O``, which would let a malformed bank through silently.
        if embeddings.ndim != 2:
            raise ValueError(
                f"Expected 2-D (N, d) embeddings, got shape "
                f"{tuple(embeddings.shape)}"
            )
        if embeddings.shape[0] != len(passages):
            raise ValueError(
                f"Number of embeddings ({embeddings.shape[0]}) must match "
                f"number of passages ({len(passages)})"
            )
        if self.config.normalize:
            # Row-wise L2 normalization: each passage vector gets unit norm.
            embeddings = F.normalize(embeddings, dim=-1)
        # Store as (d, N) for efficient matmul. ``.T`` is a view, so when
        # normalization is off the bank shares storage with the caller's tensor.
        self.embeddings = embeddings.T
        self.passages = list(passages)
        logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)

    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
        """Given top-k indices, return corresponding passage texts.

        Args:
            indices: (k,) or (batch, k) tensor of integer indices

        Returns:
            Flat list of passage strings in row-major order of ``indices``
            (a (batch, k) input yields batch * k strings). Indexing follows
            Python list semantics: negative values wrap around, and
            out-of-range values raise IndexError.
        """
        flat_indices = indices.flatten().tolist()
        return [self.passages[i] for i in flat_indices]

    def save(self, path: str) -> None:
        """Save memory bank to disk.

        Writes a dict with the (d, N) embeddings (or None if the bank is
        empty), the passage list, and the relevant config values.

        Args:
            path: file path for saving (e.g., 'memory_bank.pt')
        """
        data: Dict = {
            "embeddings": self.embeddings,
            "passages": self.passages,
            "config": {
                "embedding_dim": self.config.embedding_dim,
                "normalize": self.config.normalize,
            },
        }
        torch.save(data, path)
        logger.info("Saved memory bank to %s", path)

    def load(self, path: str) -> None:
        """Load memory bank from disk.

        Restores ``embeddings`` and ``passages`` from a file written by
        ``save``. The stored config is not applied to ``self.config``.
        Instead, a warning is logged for any stored value that disagrees with
        the current config.

        Args:
            path: file path to load from

        Raises:
            ValueError: if the stored embeddings are not a 2-D (d, N) tensor
                whose column count matches the number of stored passages.
                The bank is left unchanged in that case.
        """
        # NOTE(security): weights_only=False lets torch.load unpickle arbitrary
        # objects, which can execute code. Only load files from trusted
        # sources. The payload written by save() (a tensor or None, a list of
        # str, and a small dict of config values) may also load under
        # weights_only=True — TODO confirm and switch.
        data = torch.load(path, weights_only=False)
        embeddings = data["embeddings"]
        passages = data["passages"]
        # Validate before assigning so a bad file leaves the bank untouched.
        if embeddings is not None and (
            embeddings.ndim != 2 or embeddings.shape[1] != len(passages)
        ):
            raise ValueError(
                f"Loaded embeddings of shape {tuple(embeddings.shape)} do not "
                f"match number of passages ({len(passages)})"
            )
        # The saved config was previously ignored entirely; surface mismatches
        # (e.g. normalize flag) so they don't go unnoticed.
        saved_config = data.get("config") or {}
        for key, saved_value in saved_config.items():
            current_value = getattr(self.config, key, saved_value)
            if current_value != saved_value:
                logger.warning(
                    "Memory bank at %s was saved with %s=%r but current config has %r",
                    path,
                    key,
                    saved_value,
                    current_value,
                )
        self.embeddings = embeddings
        self.passages = passages
        logger.info("Loaded memory bank from %s (%d passages)", path, self.size)

    @property
    def size(self) -> int:
        """Number of passages in the memory bank (columns of ``embeddings``)."""
        return self.embeddings.shape[1] if self.embeddings is not None else 0

    @property
    def dim(self) -> int:
        """Embedding dimensionality (rows of ``embeddings``)."""
        return self.embeddings.shape[0] if self.embeddings is not None else 0