From c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 Mon Sep 17 00:00:00 2001 From: YurenHao0426 Date: Sun, 15 Feb 2026 18:19:50 +0000 Subject: Initial implementation of HAG (Hopfield-Augmented Generation) Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 --- hag/retriever_hopfield.py | 77 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 hag/retriever_hopfield.py (limited to 'hag/retriever_hopfield.py') diff --git a/hag/retriever_hopfield.py b/hag/retriever_hopfield.py new file mode 100644 index 0000000..1cb6968 --- /dev/null +++ b/hag/retriever_hopfield.py @@ -0,0 +1,77 @@ +"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank.""" + +import logging +from typing import List + +import torch + +from hag.datatypes import RetrievalResult +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank + +logger = logging.getLogger(__name__) + + +class HopfieldRetriever: + """Wraps HopfieldRetrieval + MemoryBank into a retriever interface. + + The bridge between Hopfield's continuous retrieval and the discrete + passage selection needed for LLM prompting. + """ + + def __init__( + self, + hopfield: HopfieldRetrieval, + memory_bank: MemoryBank, + top_k: int = 5, + ) -> None: + self.hopfield = hopfield + self.memory_bank = memory_bank + self.top_k = top_k + + def retrieve( + self, + query_embedding: torch.Tensor, + return_analysis: bool = False, + ) -> RetrievalResult: + """Retrieve top-k passages using iterative Hopfield retrieval. + + 1. Run Hopfield iterative retrieval -> get attention weights alpha_T + 2. Take top_k indices from alpha_T + 3. Look up corresponding passage texts from memory bank + 4. Optionally return trajectory and energy for analysis + + Args: + query_embedding: (d,) or (batch, d) — query embedding + return_analysis: if True, include full HopfieldResult + + Returns: + RetrievalResult with passages, scores, indices, and optionally + the full hopfield_result. + """ + hopfield_result = self.hopfield.retrieve( + query_embedding, + self.memory_bank.embeddings, + return_trajectory=return_analysis, + return_energy=return_analysis, + ) + + alpha = hopfield_result.attention_weights # (batch, N) or (1, N) + + # Get top-k indices and scores + k = min(self.top_k, alpha.shape[-1]) + scores, indices = torch.topk(alpha, k, dim=-1) # (batch, k) + + # Flatten for single-query case + if scores.shape[0] == 1: + scores = scores.squeeze(0) # (k,) + indices = indices.squeeze(0) # (k,) + + passages = self.memory_bank.get_passages_by_indices(indices) + + return RetrievalResult( + passages=passages, + scores=scores, + indices=indices, + hopfield_result=hopfield_result if return_analysis else None, + ) -- cgit v1.2.3