diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/retriever_hopfield.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag/retriever_hopfield.py')
| -rw-r--r-- | hag/retriever_hopfield.py | 77 |
1 file changed, 77 insertions, 0 deletions
class HopfieldRetriever:
    """Discrete passage retriever built on continuous Hopfield dynamics.

    Couples a ``HopfieldRetrieval`` module with a ``MemoryBank`` so that
    iterative Hopfield updates act as a drop-in retriever: the final
    attention distribution over stored patterns is converted into a
    ranked top-k list of passages suitable for LLM prompting.
    """

    def __init__(
        self,
        hopfield: HopfieldRetrieval,
        memory_bank: MemoryBank,
        top_k: int = 5,
    ) -> None:
        """Store the Hopfield module, the memory bank, and the cutoff k.

        Args:
            hopfield: iterative Hopfield retrieval module.
            memory_bank: holds passage embeddings and their texts.
            top_k: number of passages to return per query.
        """
        self.hopfield = hopfield
        self.memory_bank = memory_bank
        self.top_k = top_k

    def retrieve(
        self,
        query_embedding: torch.Tensor,
        return_analysis: bool = False,
    ) -> RetrievalResult:
        """Run Hopfield dynamics over the bank and select top-k passages.

        Pipeline: (1) iterate Hopfield retrieval to obtain the final
        attention weights, (2) take the k largest weights, (3) map their
        indices to passage texts, (4) optionally attach the raw result.

        Args:
            query_embedding: ``(d,)`` or ``(batch, d)`` query vector(s).
            return_analysis: when True, also compute the trajectory and
                energy and attach the full ``HopfieldResult``.

        Returns:
            RetrievalResult with passages, attention scores, memory-bank
            indices, and optionally the underlying hopfield_result.
        """
        result = self.hopfield.retrieve(
            query_embedding,
            self.memory_bank.embeddings,
            return_trajectory=return_analysis,
            return_energy=return_analysis,
        )

        # Final attention distribution over stored patterns: (batch, N).
        weights = result.attention_weights

        # Never ask for more candidates than the bank actually holds.
        num_selected = min(self.top_k, weights.shape[-1])
        topk_scores, topk_indices = torch.topk(weights, num_selected, dim=-1)

        # Single-query convenience: collapse the leading batch axis so
        # callers get flat (k,) tensors instead of (1, k).
        if topk_scores.shape[0] == 1:
            topk_scores = topk_scores.squeeze(0)
            topk_indices = topk_indices.squeeze(0)

        return RetrievalResult(
            passages=self.memory_bank.get_passages_by_indices(topk_indices),
            scores=topk_scores,
            indices=topk_indices,
            hopfield_result=result if return_analysis else None,
        )
