summary | refs | log | tree | commit | diff
path: root/hag/memory_bank.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/memory_bank.py')
-rw-r--r-- hag/memory_bank.py 61
1 files changed, 58 insertions, 3 deletions
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
index 42dcc73..0a0a87c 100644
--- a/hag/memory_bank.py
+++ b/hag/memory_bank.py
@@ -16,12 +16,17 @@ class MemoryBank:
The memory bank is M in R^{d x N} where each column is a passage embedding.
Also maintains a mapping from column index to passage text for final retrieval.
+
+ When config.center=True, embeddings are mean-centered to remove the centroid
+ attractor in Hopfield dynamics. The mean is saved so queries can be centered
+ with the same offset via center_query().
"""
def __init__(self, config: MemoryBankConfig) -> None:
self.config = config
self.embeddings: Optional[torch.Tensor] = None # (d, N)
self.passages: List[str] = []
+ self.mean: Optional[torch.Tensor] = None # (d,) — saved for query centering
def build_from_embeddings(
self, embeddings: torch.Tensor, passages: List[str]
@@ -38,10 +43,42 @@ class MemoryBank:
)
if self.config.normalize:
embeddings = F.normalize(embeddings, dim=-1)
+ if self.config.center:
+ self.mean = embeddings.mean(dim=0) # (d,)
+ embeddings = embeddings - self.mean.unsqueeze(0) # (N, d)
+ logger.info("Centered memory bank (removed mean)")
self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
self.passages = list(passages)
logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+ def center_query(self, query: torch.Tensor) -> torch.Tensor:
+ """Center a query embedding using the saved memory mean.
+
+ Must be called before Hopfield retrieval when config.center=True.
+
+ Args:
+ query: (d,) or (batch, d) — query embedding(s)
+
+ Returns:
+ Centered query tensor, same shape as input.
+ """
+ if self.mean is None:
+ return query
+ return query - self.mean.to(query.device)
+
+ def apply_centering(self) -> None:
+ """Center an already-loaded (uncentered) memory bank in-place.
+
+ Useful when loading a memory bank that was saved without centering.
+ Computes and stores the mean, then subtracts it from embeddings.
+ """
+ if self.embeddings is None:
+ return
+ # embeddings is (d, N); average across the N columns (dim=1) to get a (d,) mean
+ self.mean = self.embeddings.mean(dim=1) # (d,)
+ self.embeddings = self.embeddings - self.mean.unsqueeze(1) # (d, N)
+ logger.info("Applied centering to loaded memory bank")
+
def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
"""Given top-k indices, return corresponding passage texts.
@@ -67,20 +104,38 @@ class MemoryBank:
"embedding_dim": self.config.embedding_dim,
"normalize": self.config.normalize,
},
+ "mean": self.mean,
}
torch.save(data, path)
logger.info("Saved memory bank to %s", path)
- def load(self, path: str) -> None:
+ def load(self, path: str, device: str = "cpu") -> None:
"""Load memory bank from disk.
Args:
path: file path to load from
+ device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.)
"""
- data = torch.load(path, weights_only=False)
+ data = torch.load(path, weights_only=False, map_location=device)
self.embeddings = data["embeddings"]
self.passages = data["passages"]
- logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+ self.mean = data.get("mean", None)
+ logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device)
+
+ def to(self, device: str) -> "MemoryBank":
+ """Move memory bank embeddings and the saved mean to the specified device.
+
+ Args:
+ device: target device ("cpu", "cuda", "cuda:0", etc.)
+
+ Returns:
+ self (for chaining).
+ """
+ if self.embeddings is not None:
+ self.embeddings = self.embeddings.to(device)
+ if self.mean is not None:
+ self.mean = self.mean.to(device)
+ return self
@property
def size(self) -> int: