"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""

import logging
from typing import List

import torch

from hag.datatypes import RetrievalResult
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank

logger = logging.getLogger(__name__)


class HopfieldRetriever:
    """Retriever interface built on HopfieldRetrieval + MemoryBank.

    Bridges Hopfield's continuous attention-based retrieval and the
    discrete passage selection an LLM prompt needs.
    """

    def __init__(
        self,
        hopfield: HopfieldRetrieval,
        memory_bank: MemoryBank,
        top_k: int = 5,
    ) -> None:
        self.hopfield = hopfield
        self.memory_bank = memory_bank
        self.top_k = top_k

    def retrieve(
        self,
        query_embedding: torch.Tensor,
        return_analysis: bool = False,
    ) -> RetrievalResult:
        """Select the top-k passages via iterative Hopfield retrieval.

        The final attention weights produced by the Hopfield update are
        treated as relevance scores: the k largest entries pick which
        passages from the memory bank are returned.

        Args:
            query_embedding: (d,) or (batch, d) query embedding.
            return_analysis: when True, the Hopfield trajectory/energy are
                computed and the full HopfieldResult is attached to the
                returned RetrievalResult.

        Returns:
            RetrievalResult holding passages, scores, indices, and —
            only when ``return_analysis`` is set — the HopfieldResult.
        """
        hf_result = self.hopfield.retrieve(
            query_embedding,
            self.memory_bank.embeddings,
            return_trajectory=return_analysis,
            return_energy=return_analysis,
        )
        # Final-iteration attention weights; presumably (batch, N) or (1, N).
        weights = hf_result.attention_weights

        # Never ask topk for more entries than the bank holds.
        effective_k = min(self.top_k, weights.shape[-1])
        top_scores, top_indices = torch.topk(weights, effective_k, dim=-1)

        # Single-query convenience: drop the leading batch axis so callers
        # get flat (k,) tensors instead of (1, k).
        if top_scores.shape[0] == 1:
            top_scores = top_scores.squeeze(0)
            top_indices = top_indices.squeeze(0)

        return RetrievalResult(
            passages=self.memory_bank.get_passages_by_indices(top_indices),
            scores=top_scores,
            indices=top_indices,
            hopfield_result=hf_result if return_analysis else None,
        )