summaryrefslogtreecommitdiff
path: root/hag/retriever_hopfield.py
blob: 1cb696834147ca1b96848c09e6eee8f8015b3e0e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""

import logging
from typing import List

import torch

from hag.datatypes import RetrievalResult
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank

logger = logging.getLogger(__name__)


class HopfieldRetriever:
    """Retriever interface built on top of HopfieldRetrieval + MemoryBank.

    Bridges Hopfield's continuous attention-based retrieval with the
    discrete passage selection an LLM prompt needs: run the iterative
    retrieval, rank memories by their final attention mass, and map the
    winning indices back to passage text.
    """

    def __init__(
        self,
        hopfield: HopfieldRetrieval,
        memory_bank: MemoryBank,
        top_k: int = 5,
    ) -> None:
        # Number of passages to surface per query.
        self.top_k = top_k
        # Iterative Hopfield retrieval module (operates on embeddings).
        self.hopfield = hopfield
        # Stores passage embeddings and the passage texts they index.
        self.memory_bank = memory_bank

    def retrieve(
        self,
        query_embedding: torch.Tensor,
        return_analysis: bool = False,
    ) -> RetrievalResult:
        """Select the top-k passages for a query via Hopfield retrieval.

        Pipeline:
          1. Run iterative Hopfield retrieval over the memory bank's
             embeddings to obtain final attention weights.
          2. Rank memories by attention and keep the top_k indices.
          3. Resolve those indices to passage texts.
          4. Optionally attach the full HopfieldResult (trajectory,
             energy) for analysis.

        Args:
            query_embedding: (d,) or (batch, d) query embedding.
            return_analysis: when True, request trajectory/energy from
                the Hopfield module and attach the raw result.

        Returns:
            RetrievalResult carrying passages, scores, indices, and —
            when return_analysis is set — the underlying hopfield_result.
        """
        result = self.hopfield.retrieve(
            query_embedding,
            self.memory_bank.embeddings,
            return_trajectory=return_analysis,
            return_energy=return_analysis,
        )

        # Final attention distribution over the N stored memories,
        # shaped (batch, N) (batch is 1 for a single query).
        attn = result.attention_weights

        # Never ask for more memories than exist.
        num_selected = min(self.top_k, attn.shape[-1])
        top_scores, top_indices = torch.topk(attn, num_selected, dim=-1)

        # Single-query convenience: drop the leading batch dimension so
        # callers get flat (k,) tensors.
        if top_scores.shape[0] == 1:
            top_scores, top_indices = top_scores[0], top_indices[0]

        return RetrievalResult(
            passages=self.memory_bank.get_passages_by_indices(top_indices),
            scores=top_scores,
            indices=top_indices,
            hopfield_result=result if return_analysis else None,
        )