path: root/hag/memory_bank.py
author	YurenHao0426 <blackhao0426@gmail.com>	2026-02-15 18:19:50 +0000
committer	YurenHao0426 <blackhao0426@gmail.com>	2026-02-15 18:19:50 +0000
commit	c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree	43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/memory_bank.py
Initial implementation of HAG (Hopfield-Augmented Generation) (HEAD, master)
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/memory_bank.py')
-rw-r--r--	hag/memory_bank.py	93
1 file changed, 93 insertions, 0 deletions
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
new file mode 100644
index 0000000..42dcc73
--- /dev/null
+++ b/hag/memory_bank.py
@@ -0,0 +1,93 @@
+"""Memory bank construction and management for passage embeddings."""
+
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import MemoryBankConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryBank:
+ """Stores passage embeddings and provides lookup from indices back to text.
+
+ The memory bank is M in R^{d x N} where each column is a passage embedding.
+ Also maintains a mapping from column index to passage text for final retrieval.
+ """
+
+ def __init__(self, config: MemoryBankConfig) -> None:
+ self.config = config
+ self.embeddings: Optional[torch.Tensor] = None # (d, N)
+ self.passages: List[str] = []
+
+ def build_from_embeddings(
+ self, embeddings: torch.Tensor, passages: List[str]
+ ) -> None:
+ """Build memory bank from precomputed embeddings.
+
+ Args:
+            embeddings: (N, d) tensor of passage embeddings (transposed to (d, N) internally)
+ passages: list of N passage strings
+ """
+        if embeddings.shape[0] != len(passages):
+            raise ValueError(
+                f"Number of embeddings ({embeddings.shape[0]}) must match "
+                f"number of passages ({len(passages)})"
+            )
+ if self.config.normalize:
+ embeddings = F.normalize(embeddings, dim=-1)
+        self.embeddings = embeddings.T.contiguous()  # Store as (d, N), contiguous, for efficient matmul
+ self.passages = list(passages)
+ logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+
+ def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+ """Given top-k indices, return corresponding passage texts.
+
+ Args:
+ indices: (k,) or (batch, k) tensor of integer indices
+
+ Returns:
+ List of passage strings.
+ """
+ flat_indices = indices.flatten().tolist()
+ return [self.passages[i] for i in flat_indices]
+
+ def save(self, path: str) -> None:
+ """Save memory bank to disk.
+
+ Args:
+ path: file path for saving (e.g., 'memory_bank.pt')
+ """
+ data: Dict = {
+ "embeddings": self.embeddings,
+ "passages": self.passages,
+ "config": {
+ "embedding_dim": self.config.embedding_dim,
+ "normalize": self.config.normalize,
+ },
+ }
+ torch.save(data, path)
+ logger.info("Saved memory bank to %s", path)
+
+ def load(self, path: str) -> None:
+ """Load memory bank from disk.
+
+ Args:
+ path: file path to load from
+ """
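+        # weights_only=False is needed because the saved payload contains plain
+        # Python objects (the passage list and config dict), not just tensors.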
+ data = torch.load(path, weights_only=False)
+ self.embeddings = data["embeddings"]
+ self.passages = data["passages"]
+ logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+
+ @property
+ def size(self) -> int:
+ """Number of passages in the memory bank."""
+ return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+ @property
+ def dim(self) -> int:
+ """Embedding dimensionality."""
+ return self.embeddings.shape[0] if self.embeddings is not None else 0
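
A minimal usage sketch (not part of the diff above): it builds a MemoryBank from synthetic embeddings, performs one illustrative Hopfield-style softmax readout over the (d, N) memory, and round-trips the bank through save()/load(). The MemoryBankConfig constructor arguments (embedding_dim, normalize) are assumed from the fields serialized in save() and may not match the actual dataclass in hag/config.py; likewise, the real retrieval update lives in the repository's Hopfield module, not in this snippet.

import torch
import torch.nn.functional as F

from hag.config import MemoryBankConfig
from hag.memory_bank import MemoryBank

# Synthetic corpus: N = 4 passages, d = 8 embedding dims.
passages = ["alpha", "beta", "gamma", "delta"]
embeddings = torch.randn(len(passages), 8)  # (N, d), as build_from_embeddings expects

config = MemoryBankConfig(embedding_dim=8, normalize=True)  # assumed constructor signature
bank = MemoryBank(config)
bank.build_from_embeddings(embeddings, passages)
assert bank.size == 4 and bank.dim == 8

# One Hopfield-style readout step (illustrative only): score every column of M,
# attend with a softmax at inverse temperature beta, and read out a refined query.
query = F.normalize(torch.randn(8), dim=-1)    # (d,)
beta = 4.0
scores = bank.embeddings.T @ query             # (N,) similarity per passage
attn = torch.softmax(beta * scores, dim=-1)    # (N,) attention over passages
refined = bank.embeddings @ attn               # (d,) could seed another iteration

# Map the top-k columns back to passage text.
topk = torch.topk(scores, k=2).indices
print(bank.get_passages_by_indices(topk))

# Save/load round trip.
bank.save("memory_bank.pt")
restored = MemoryBank(config)
restored.load("memory_bank.pt")
assert restored.passages == passages and restored.size == bank.size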