"""Memory bank construction and management for passage embeddings.""" import logging from typing import Dict, List, Optional import torch import torch.nn.functional as F from hag.config import MemoryBankConfig logger = logging.getLogger(__name__) class MemoryBank: """Stores passage embeddings and provides lookup from indices back to text. The memory bank is M in R^{d x N} where each column is a passage embedding. Also maintains a mapping from column index to passage text for final retrieval. """ def __init__(self, config: MemoryBankConfig) -> None: self.config = config self.embeddings: Optional[torch.Tensor] = None # (d, N) self.passages: List[str] = [] def build_from_embeddings( self, embeddings: torch.Tensor, passages: List[str] ) -> None: """Build memory bank from precomputed embeddings. Args: embeddings: (N, d) — passage embeddings (note: input is N x d) passages: list of N passage strings """ assert embeddings.shape[0] == len(passages), ( f"Number of embeddings ({embeddings.shape[0]}) must match " f"number of passages ({len(passages)})" ) if self.config.normalize: embeddings = F.normalize(embeddings, dim=-1) self.embeddings = embeddings.T # Store as (d, N) for efficient matmul self.passages = list(passages) logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim) def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]: """Given top-k indices, return corresponding passage texts. Args: indices: (k,) or (batch, k) tensor of integer indices Returns: List of passage strings. """ flat_indices = indices.flatten().tolist() return [self.passages[i] for i in flat_indices] def save(self, path: str) -> None: """Save memory bank to disk. Args: path: file path for saving (e.g., 'memory_bank.pt') """ data: Dict = { "embeddings": self.embeddings, "passages": self.passages, "config": { "embedding_dim": self.config.embedding_dim, "normalize": self.config.normalize, }, } torch.save(data, path) logger.info("Saved memory bank to %s", path) def load(self, path: str) -> None: """Load memory bank from disk. Args: path: file path to load from """ data = torch.load(path, weights_only=False) self.embeddings = data["embeddings"] self.passages = data["passages"] logger.info("Loaded memory bank from %s (%d passages)", path, self.size) @property def size(self) -> int: """Number of passages in the memory bank.""" return self.embeddings.shape[1] if self.embeddings is not None else 0 @property def dim(self) -> int: """Embedding dimensionality.""" return self.embeddings.shape[0] if self.embeddings is not None else 0