| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/memory_bank.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/memory_bank.py')
| Mode | Path | Lines |
|---|---|---|
| -rw-r--r-- | hag/memory_bank.py | 93 |
1 file changed, 93 insertions, 0 deletions
```diff
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
new file mode 100644
index 0000000..42dcc73
--- /dev/null
+++ b/hag/memory_bank.py
@@ -0,0 +1,93 @@
+"""Memory bank construction and management for passage embeddings."""
+
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import MemoryBankConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryBank:
+    """Stores passage embeddings and provides lookup from indices back to text.
+
+    The memory bank is M in R^{d x N} where each column is a passage embedding.
+    Also maintains a mapping from column index to passage text for final retrieval.
+    """
+
+    def __init__(self, config: MemoryBankConfig) -> None:
+        self.config = config
+        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
+        self.passages: List[str] = []
+
+    def build_from_embeddings(
+        self, embeddings: torch.Tensor, passages: List[str]
+    ) -> None:
+        """Build memory bank from precomputed embeddings.
+
+        Args:
+            embeddings: (N, d) passage embeddings (note: input is N x d)
+            passages: list of N passage strings
+        """
+        assert embeddings.shape[0] == len(passages), (
+            f"Number of embeddings ({embeddings.shape[0]}) must match "
+            f"number of passages ({len(passages)})"
+        )
+        if self.config.normalize:
+            embeddings = F.normalize(embeddings, dim=-1)
+        self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
+        self.passages = list(passages)
+        logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+
+    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+        """Given top-k indices, return corresponding passage texts.
+
+        Args:
+            indices: (k,) or (batch, k) tensor of integer indices
+
+        Returns:
+            List of passage strings.
+        """
+        flat_indices = indices.flatten().tolist()
+        return [self.passages[i] for i in flat_indices]
+
+    def save(self, path: str) -> None:
+        """Save memory bank to disk.
+
+        Args:
+            path: file path for saving (e.g., 'memory_bank.pt')
+        """
+        data: Dict = {
+            "embeddings": self.embeddings,
+            "passages": self.passages,
+            "config": {
+                "embedding_dim": self.config.embedding_dim,
+                "normalize": self.config.normalize,
+            },
+        }
+        torch.save(data, path)
+        logger.info("Saved memory bank to %s", path)
+
+    def load(self, path: str) -> None:
+        """Load memory bank from disk.
+
+        Args:
+            path: file path to load from
+        """
+        data = torch.load(path, weights_only=False)
+        self.embeddings = data["embeddings"]
+        self.passages = data["passages"]
+        logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+
+    @property
+    def size(self) -> int:
+        """Number of passages in the memory bank."""
+        return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+    @property
+    def dim(self) -> int:
+        """Embedding dimensionality."""
+        return self.embeddings.shape[0] if self.embeddings is not None else 0
```
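A minimal usage sketch of the `MemoryBank` class added above. `MemoryBankConfig` lives in `hag/config.py`, which is not part of this diff, so its constructor call below is an assumption (a small config object exposing `embedding_dim` and `normalize`); everything else uses only the methods defined in `memory_bank.py`:

```python
import torch
import torch.nn.functional as F

from hag.config import MemoryBankConfig
from hag.memory_bank import MemoryBank

# Assumed config fields: embedding_dim and normalize (see hag/config.py, not in this diff).
config = MemoryBankConfig(embedding_dim=384, normalize=True)
bank = MemoryBank(config)

# Build from (N, d) embeddings; the bank transposes and stores them as (d, N).
passages = ["Paris is the capital of France.", "The mitochondrion produces ATP."]
embeddings = torch.randn(len(passages), config.embedding_dim)
bank.build_from_embeddings(embeddings, passages)

# Score a query against every column of M in R^{d x N} with a single matmul.
query = F.normalize(torch.randn(config.embedding_dim), dim=-1)
scores = query @ bank.embeddings          # (N,) cosine scores when normalize=True
top_indices = scores.topk(k=1).indices    # (k,) indices into the memory bank
print(bank.get_passages_by_indices(top_indices))

# Round-trip to disk.
bank.save("memory_bank.pt")
bank.load("memory_bank.pt")
assert bank.size == len(passages) and bank.dim == config.embedding_dim
```

Storing the bank as (d, N) means a query-against-all-passages score is one matrix product, and the resulting column indices map straight back to passage text via `get_passages_by_indices`.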
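The commit message also mentions a Hopfield retrieval module with energy-based convergence guarantees; that module is not in this diff, but the (d, N) layout kept by `MemoryBank` is the shape a modern-Hopfield-style readout consumes. As a hedged illustration only (the actual update rule and hyperparameters in `hag/` may differ), one step of the standard modern Hopfield update looks like this:

```python
import torch
import torch.nn.functional as F

def hopfield_step(memory: torch.Tensor, state: torch.Tensor, beta: float = 8.0) -> torch.Tensor:
    """One modern-Hopfield update: state <- M softmax(beta * M^T state).

    memory: (d, N) bank of stored patterns (e.g., MemoryBank.embeddings).
    state:  (d,) query state being iterated toward a stored pattern.
    """
    attention = F.softmax(beta * (memory.T @ state), dim=-1)  # (N,) weights over passages
    return memory @ attention                                  # (d,) updated state
```

Iterating this step typically drives the state toward one stored column of M; the index of the dominant attention weight can then be passed to `MemoryBank.get_passages_by_indices` to recover the passage text.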
