"""Memory bank construction and management for passage embeddings.""" import logging from typing import Dict, List, Optional import torch import torch.nn.functional as F from hag.config import MemoryBankConfig logger = logging.getLogger(__name__) class MemoryBank: """Stores passage embeddings and provides lookup from indices back to text. The memory bank is M in R^{d x N} where each column is a passage embedding. Also maintains a mapping from column index to passage text for final retrieval. When config.center=True, embeddings are mean-centered to remove the centroid attractor in Hopfield dynamics. The mean is saved so queries can be centered with the same offset via center_query(). """ def __init__(self, config: MemoryBankConfig) -> None: self.config = config self.embeddings: Optional[torch.Tensor] = None # (d, N) self.passages: List[str] = [] self.mean: Optional[torch.Tensor] = None # (d,) — saved for query centering def build_from_embeddings( self, embeddings: torch.Tensor, passages: List[str] ) -> None: """Build memory bank from precomputed embeddings. Args: embeddings: (N, d) — passage embeddings (note: input is N x d) passages: list of N passage strings """ assert embeddings.shape[0] == len(passages), ( f"Number of embeddings ({embeddings.shape[0]}) must match " f"number of passages ({len(passages)})" ) if self.config.normalize: embeddings = F.normalize(embeddings, dim=-1) if self.config.center: self.mean = embeddings.mean(dim=0) # (d,) embeddings = embeddings - self.mean.unsqueeze(0) # (N, d) logger.info("Centered memory bank (removed mean)") self.embeddings = embeddings.T # Store as (d, N) for efficient matmul self.passages = list(passages) logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim) def center_query(self, query: torch.Tensor) -> torch.Tensor: """Center a query embedding using the saved memory mean. Must be called before Hopfield retrieval when config.center=True. Args: query: (d,) or (batch, d) — query embedding(s) Returns: Centered query tensor, same shape as input. """ if self.mean is None: return query return query - self.mean.to(query.device) def apply_centering(self) -> None: """Center an already-loaded (uncentered) memory bank in-place. Useful when loading a memory bank that was saved without centering. Computes and stores the mean, then subtracts it from embeddings. """ if self.embeddings is None: return # embeddings is (d, N), mean over columns self.mean = self.embeddings.mean(dim=1) # (d,) self.embeddings = self.embeddings - self.mean.unsqueeze(1) # (d, N) logger.info("Applied centering to loaded memory bank") def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]: """Given top-k indices, return corresponding passage texts. Args: indices: (k,) or (batch, k) tensor of integer indices Returns: List of passage strings. """ flat_indices = indices.flatten().tolist() return [self.passages[i] for i in flat_indices] def save(self, path: str) -> None: """Save memory bank to disk. Args: path: file path for saving (e.g., 'memory_bank.pt') """ data: Dict = { "embeddings": self.embeddings, "passages": self.passages, "config": { "embedding_dim": self.config.embedding_dim, "normalize": self.config.normalize, }, "mean": self.mean, } torch.save(data, path) logger.info("Saved memory bank to %s", path) def load(self, path: str, device: str = "cpu") -> None: """Load memory bank from disk. Args: path: file path to load from device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.) """ data = torch.load(path, weights_only=False, map_location=device) self.embeddings = data["embeddings"] self.passages = data["passages"] self.mean = data.get("mean", None) logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device) def to(self, device: str) -> "MemoryBank": """Move memory bank embeddings to the specified device. Args: device: target device ("cpu", "cuda", "cuda:0", etc.) Returns: self (for chaining). """ if self.embeddings is not None: self.embeddings = self.embeddings.to(device) if self.mean is not None: self.mean = self.mean.to(device) return self @property def size(self) -> int: """Number of passages in the memory bank.""" return self.embeddings.shape[1] if self.embeddings is not None else 0 @property def dim(self) -> int: """Embedding dimensionality.""" return self.embeddings.shape[0] if self.embeddings is not None else 0