1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
|
"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""
import logging
from typing import List
import torch
from hag.datatypes import RetrievalResult
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
logger = logging.getLogger(__name__)
class HopfieldRetriever:
    """Wraps HopfieldRetrieval + MemoryBank into a retriever interface.

    The bridge between Hopfield's continuous retrieval and the discrete
    passage selection needed for LLM prompting.
    """

    def __init__(
        self,
        hopfield: HopfieldRetrieval,
        memory_bank: MemoryBank,
        top_k: int = 5,
    ) -> None:
        """Store the retrieval components.

        Args:
            hopfield: Hopfield retrieval module whose ``retrieve`` produces
                attention weights over the memory bank's embeddings.
            memory_bank: provides ``embeddings`` (the stored patterns) and
                ``get_passages_by_indices`` (index -> passage text lookup).
            top_k: maximum number of passages returned per query; must be >= 1.

        Raises:
            ValueError: if ``top_k`` is less than 1.
        """
        # Fail fast: a negative k would only surface later as an opaque
        # torch.topk error, and k == 0 would silently retrieve nothing.
        if top_k < 1:
            raise ValueError(f"top_k must be >= 1, got {top_k}")
        self.hopfield = hopfield
        self.memory_bank = memory_bank
        self.top_k = top_k

    def retrieve(
        self,
        query_embedding: torch.Tensor,
        return_analysis: bool = False,
    ) -> RetrievalResult:
        """Retrieve top-k passages using iterative Hopfield retrieval.

        1. Run Hopfield iterative retrieval -> get attention weights alpha_T
        2. Take top_k indices from alpha_T
        3. Look up corresponding passage texts from memory bank
        4. Optionally return trajectory and energy for analysis

        Args:
            query_embedding: (d,) or (batch, d) — query embedding
            return_analysis: if True, request trajectory and energy from the
                Hopfield module and include the full HopfieldResult.

        Returns:
            RetrievalResult with passages, scores, indices, and optionally
            the full hopfield_result. For a single query, scores/indices
            have shape (k,); otherwise (batch, k).
        """
        hopfield_result = self.hopfield.retrieve(
            query_embedding,
            self.memory_bank.embeddings,
            return_trajectory=return_analysis,
            return_energy=return_analysis,
        )
        alpha = hopfield_result.attention_weights  # expected (batch, N) or (1, N)

        # Never request more entries than there are stored memories.
        k = min(self.top_k, alpha.shape[-1])
        scores, indices = torch.topk(alpha, k, dim=-1)  # (..., k)

        # Collapse a single-query batch dimension down to (k,). The ndim guard
        # keeps a 1-D alpha with k == 1 from being squeezed into a 0-d scalar.
        if scores.dim() > 1 and scores.shape[0] == 1:
            scores = scores.squeeze(0)  # (k,)
            indices = indices.squeeze(0)  # (k,)

        # NOTE(review): for batch > 1, indices stays (batch, k); confirm that
        # MemoryBank.get_passages_by_indices supports 2-D input.
        passages = self.memory_bank.get_passages_by_indices(indices)

        return RetrievalResult(
            passages=passages,
            scores=scores,
            indices=indices,
            hopfield_result=hopfield_result if return_analysis else None,
        )
|