diff options
Diffstat (limited to 'hag/memory_bank.py')
| -rw-r--r-- | hag/memory_bank.py | 61 |
1 files changed, 58 insertions, 3 deletions
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
index 42dcc73..0a0a87c 100644
--- a/hag/memory_bank.py
+++ b/hag/memory_bank.py
@@ -16,12 +16,17 @@ class MemoryBank:
 
     The memory bank is M in R^{d x N} where each column is a passage embedding.
     Also maintains a mapping from column index to passage text for final retrieval.
+
+    When config.center=True, embeddings are mean-centered to remove the centroid
+    attractor in Hopfield dynamics. The mean is saved so queries can be centered
+    with the same offset via center_query().
     """
 
     def __init__(self, config: MemoryBankConfig) -> None:
         self.config = config
         self.embeddings: Optional[torch.Tensor] = None  # (d, N)
         self.passages: List[str] = []
+        self.mean: Optional[torch.Tensor] = None  # (d,) — saved for query centering
 
     def build_from_embeddings(
         self, embeddings: torch.Tensor, passages: List[str]
@@ -38,10 +43,42 @@ class MemoryBank:
             )
         if self.config.normalize:
             embeddings = F.normalize(embeddings, dim=-1)
+        if self.config.center:
+            self.mean = embeddings.mean(dim=0)  # (d,)
+            embeddings = embeddings - self.mean.unsqueeze(0)  # (N, d)
+            logger.info("Centered memory bank (removed mean)")
         self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
         self.passages = list(passages)
         logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
 
+    def center_query(self, query: torch.Tensor) -> torch.Tensor:
+        """Center a query embedding using the saved memory mean.
+
+        Must be called before Hopfield retrieval when config.center=True.
+
+        Args:
+            query: (d,) or (batch, d) — query embedding(s)
+
+        Returns:
+            Centered query tensor, same shape as input.
+        """
+        if self.mean is None:
+            return query
+        return query - self.mean.to(query.device)
+
+    def apply_centering(self) -> None:
+        """Center an already-loaded (uncentered) memory bank in-place.
+
+        Useful when loading a memory bank that was saved without centering.
+        Computes and stores the mean, then subtracts it from embeddings.
+        """
+        if self.embeddings is None:
+            return
+        # embeddings is (d, N), mean over columns
+        self.mean = self.embeddings.mean(dim=1)  # (d,)
+        self.embeddings = self.embeddings - self.mean.unsqueeze(1)  # (d, N)
+        logger.info("Applied centering to loaded memory bank")
+
     def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
         """Given top-k indices, return corresponding passage texts.
 
@@ -67,20 +104,38 @@ class MemoryBank:
                 "embedding_dim": self.config.embedding_dim,
                 "normalize": self.config.normalize,
             },
+            "mean": self.mean,
         }
         torch.save(data, path)
         logger.info("Saved memory bank to %s", path)
 
-    def load(self, path: str) -> None:
+    def load(self, path: str, device: str = "cpu") -> None:
         """Load memory bank from disk.
 
         Args:
            path: file path to load from
+           device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.)
         """
-        data = torch.load(path, weights_only=False)
+        data = torch.load(path, weights_only=False, map_location=device)
         self.embeddings = data["embeddings"]
         self.passages = data["passages"]
-        logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+        self.mean = data.get("mean", None)
+        logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device)
+
+    def to(self, device: str) -> "MemoryBank":
+        """Move memory bank embeddings to the specified device.
+
+        Args:
+            device: target device ("cpu", "cuda", "cuda:0", etc.)
+
+        Returns:
+            self (for chaining).
+        """
+        if self.embeddings is not None:
+            self.embeddings = self.embeddings.to(device)
+        if self.mean is not None:
+            self.mean = self.mean.to(device)
+        return self
 
     @property
     def size(self) -> int:
