summaryrefslogtreecommitdiff
path: root/hag/retriever_hopfield.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/retriever_hopfield.py')
-rw-r--r--hag/retriever_hopfield.py77
1 file changed, 77 insertions, 0 deletions
diff --git a/hag/retriever_hopfield.py b/hag/retriever_hopfield.py
new file mode 100644
index 0000000..1cb6968
--- /dev/null
+++ b/hag/retriever_hopfield.py
@@ -0,0 +1,77 @@
+"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""
+
+import logging
+from typing import List
+
+import torch
+
+from hag.datatypes import RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logger = logging.getLogger(__name__)
+
+
class HopfieldRetriever:
    """Retriever interface built on HopfieldRetrieval + MemoryBank.

    Bridges the continuous Hopfield attention dynamics to the discrete
    top-k passage selection needed when prompting an LLM.
    """

    def __init__(
        self,
        hopfield: HopfieldRetrieval,
        memory_bank: MemoryBank,
        top_k: int = 5,
    ) -> None:
        self.hopfield = hopfield
        self.memory_bank = memory_bank
        self.top_k = top_k

    def retrieve(
        self,
        query_embedding: torch.Tensor,
        return_analysis: bool = False,
    ) -> RetrievalResult:
        """Select the top-k passages for *query_embedding*.

        Runs the iterative Hopfield update against the memory bank's
        embedding matrix, ranks memories by the final attention weights,
        and maps the winning indices back to passage texts.

        Args:
            query_embedding: (d,) or (batch, d) query embedding.
            return_analysis: when True, also request the trajectory and
                energy trace and attach the full HopfieldResult.

        Returns:
            RetrievalResult with passages, scores, and indices; the
            underlying hopfield_result is attached only when
            return_analysis is True.
        """
        result = self.hopfield.retrieve(
            query_embedding,
            self.memory_bank.embeddings,
            return_trajectory=return_analysis,
            return_energy=return_analysis,
        )

        # Final attention over memories: (batch, N) — batch is 1 for a
        # single query.
        attn = result.attention_weights

        # Never request more memories than the bank holds.
        num = min(self.top_k, attn.shape[-1])
        top_scores, top_idx = torch.topk(attn, num, dim=-1)

        # Single-query case: drop the leading batch dim so callers get (k,).
        if top_scores.shape[0] == 1:
            top_scores = top_scores.squeeze(0)
            top_idx = top_idx.squeeze(0)

        return RetrievalResult(
            passages=self.memory_bank.get_passages_by_indices(top_idx),
            scores=top_scores,
            indices=top_idx,
            hopfield_result=result if return_analysis else None,
        )