summaryrefslogtreecommitdiff
path: root/hag/memory_bank.py
blob: 0a0a87c19489644155e8cb6b3285459fd1b901ef (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
"""Memory bank construction and management for passage embeddings."""

import logging
from typing import Dict, List, Optional

import torch
import torch.nn.functional as F

from hag.config import MemoryBankConfig

logger = logging.getLogger(__name__)


class MemoryBank:
    """Stores passage embeddings and provides lookup from indices back to text.

    The memory bank is M in R^{d x N} where each column is a passage embedding.
    Also maintains a mapping from column index to passage text for final retrieval.

    When config.center=True, embeddings are mean-centered to remove the centroid
    attractor in Hopfield dynamics. The mean is saved so queries can be centered
    with the same offset via center_query().

    Attributes:
        config: configuration read for ``normalize`` / ``center`` / ``embedding_dim``.
        embeddings: (d, N) tensor of passage embeddings, or None before build/load.
        passages: passage texts; ``passages[i]`` corresponds to column ``i``.
        mean: (d,) total offset that has been subtracted from ``embeddings``, or
            None when the bank is not centered. ``center_query()`` subtracts it
            from queries so they live in the same shifted space.
    """

    def __init__(self, config: MemoryBankConfig) -> None:
        self.config = config
        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
        self.passages: List[str] = []
        self.mean: Optional[torch.Tensor] = None  # (d,) — saved for query centering

    def build_from_embeddings(
        self, embeddings: torch.Tensor, passages: List[str]
    ) -> None:
        """Build memory bank from precomputed embeddings.

        Replaces any previously stored embeddings, passages and mean.

        Args:
            embeddings: (N, d) — passage embeddings (note: input is N x d)
            passages: list of N passage strings

        Raises:
            AssertionError: if the number of embeddings and passages differ.
        """
        assert embeddings.shape[0] == len(passages), (
            f"Number of embeddings ({embeddings.shape[0]}) must match "
            f"number of passages ({len(passages)})"
        )
        if self.config.normalize:
            embeddings = F.normalize(embeddings, dim=-1)
        if self.config.center:
            self.mean = embeddings.mean(dim=0)  # (d,)
            embeddings = embeddings - self.mean.unsqueeze(0)  # (N, d)
            logger.info("Centered memory bank (removed mean)")
        else:
            # Drop any mean left over from a previous build / apply_centering():
            # these embeddings are not centered, so center_query() must not
            # shift queries by a stale offset.
            self.mean = None
        self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
        self.passages = list(passages)
        logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)

    def center_query(self, query: torch.Tensor) -> torch.Tensor:
        """Center a query embedding using the saved memory mean.

        Must be called before Hopfield retrieval when config.center=True.
        Returns the query unchanged when no mean is stored.

        Args:
            query: (d,) or (batch, d) — query embedding(s)

        Returns:
            Centered query tensor, same shape as input.
        """
        if self.mean is None:
            return query
        # Broadcasting a (d,) mean works for both (d,) and (batch, d) queries.
        return query - self.mean.to(query.device)

    def apply_centering(self) -> None:
        """Center an already-loaded memory bank in-place.

        Useful when loading a memory bank that was saved without centering.
        Computes the mean over passages, subtracts it from the embeddings and
        folds it into ``self.mean``. Calling this on an already-centered bank is
        safe: the existing mean is kept and the (near-zero) residual is added to
        it, so ``center_query()`` still removes the full offset from queries.

        Does nothing if no embeddings are loaded.
        """
        if self.embeddings is None:
            return
        # embeddings is (d, N): average over the N passage columns.
        offset = self.embeddings.mean(dim=1)  # (d,)
        self.embeddings = self.embeddings - offset.unsqueeze(1)  # (d, N)
        # Accumulate instead of overwriting: replacing an existing mean with the
        # ~zero residual of already-centered data would make center_query()
        # silently stop centering queries.
        if self.mean is None:
            self.mean = offset
        else:
            self.mean = self.mean.to(offset.device) + offset
        logger.info("Applied centering to loaded memory bank")

    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
        """Given top-k indices, return corresponding passage texts.

        Args:
            indices: (k,) or (batch, k) tensor of integer indices

        Returns:
            List of passage strings, in row-major (flattened) order of ``indices``.
        """
        flat_indices = indices.flatten().tolist()
        return [self.passages[i] for i in flat_indices]

    def save(self, path: str) -> None:
        """Save memory bank to disk.

        Persists embeddings, passages, the centering mean and the config flags
        that shaped the stored embeddings.

        Args:
            path: file path for saving (e.g., 'memory_bank.pt')
        """
        data: Dict = {
            "embeddings": self.embeddings,
            "passages": self.passages,
            "config": {
                "embedding_dim": self.config.embedding_dim,
                "normalize": self.config.normalize,
                # Record centering too, so a saved file fully describes how
                # its embeddings were preprocessed.
                "center": self.config.center,
            },
            "mean": self.mean,
        }
        torch.save(data, path)
        logger.info("Saved memory bank to %s", path)

    def load(self, path: str, device: str = "cpu") -> None:
        """Load memory bank from disk.

        Overwrites the current embeddings, passages and mean. Files saved
        without a ``"mean"`` entry load as uncentered (``mean = None``).

        Args:
            path: file path to load from
            device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.)
        """
        # NOTE(security): weights_only=False unpickles arbitrary objects; only
        # load memory-bank files from trusted sources.
        data = torch.load(path, weights_only=False, map_location=device)
        self.embeddings = data["embeddings"]
        self.passages = data["passages"]
        self.mean = data.get("mean", None)
        logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device)

    def to(self, device: str) -> "MemoryBank":
        """Move memory bank embeddings to the specified device.

        Args:
            device: target device ("cpu", "cuda", "cuda:0", etc.)

        Returns:
            self (for chaining).
        """
        if self.embeddings is not None:
            self.embeddings = self.embeddings.to(device)
        if self.mean is not None:
            self.mean = self.mean.to(device)
        return self

    @property
    def size(self) -> int:
        """Number of passages in the memory bank (0 if empty)."""
        return self.embeddings.shape[1] if self.embeddings is not None else 0

    @property
    def dim(self) -> int:
        """Embedding dimensionality (0 if empty)."""
        return self.embeddings.shape[0] if self.embeddings is not None else 0