path: root/hag/hopfield.py
"""Core Modern Continuous Hopfield Network retrieval module.

Implements the iterative retrieval dynamics from:
    Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)

Update rule: q_{t+1} = M @ softmax(beta * M^T @ q_t)
Energy:      E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
             (additive constants from the paper are dropped; they do not depend on q)
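
Typical usage (a sketch; it assumes HopfieldConfig, defined in hag.config,
accepts beta, max_iter, and conv_threshold as keyword arguments):

    config = HopfieldConfig(beta=8.0, max_iter=50, conv_threshold=1e-6)
    retrieval = HopfieldRetrieval(config)
    result = retrieval.retrieve(query, memory)  # query: (d,), memory: (d, N)
    alpha = result.attention_weights            # (1, N) softmax over memories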
"""

import logging

import torch

from hag.config import HopfieldConfig
from hag.datatypes import HopfieldResult

logger = logging.getLogger(__name__)


class HopfieldRetrieval:
    """Modern Continuous Hopfield Network for memory retrieval.

    Given memory bank M in R^{d x N} and query q in R^d:
    1. Compute attention: alpha = softmax(beta * M^T @ q)
    2. Update query: q_new = M @ alpha
    3. Repeat until convergence or max_iter

    The energy function is:
        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2

    Key property: E(q_{t+1}) <= E(q_t), i.e. the energy never increases
    across iterations.
    """

    def __init__(self, config: HopfieldConfig) -> None:
        self.config = config

    @torch.no_grad()
    def retrieve(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
        return_trajectory: bool = False,
        return_energy: bool = False,
    ) -> HopfieldResult:
        """Run iterative Hopfield retrieval.

        Args:
            query: (d,) or (batch, d) — query embedding(s)
            memory: (d, N) — memory bank of passage embeddings
            return_trajectory: if True, store q_t at each step
            return_energy: if True, store E(q_t) at each step

        Returns:
            HopfieldResult with attention_weights (batch, N), converged_query
            (batch, d), num_steps, and optionally trajectory and energy_curve.
            A 1-D query is treated as a batch of size 1.
        """
        # Ensure query is 2D: (batch, d)
        if query.dim() == 1:
            query = query.unsqueeze(0)  # (1, d)

        q = query.clone()  # (batch, d)

        trajectory = [q.clone()] if return_trajectory else None
        energies = [self.compute_energy(q, memory)] if return_energy else None

        num_steps = 0
        for t in range(self.config.max_iter):
            # Core Hopfield update
            logits = self.config.beta * (q @ memory)  # (batch, N)
            alpha = torch.softmax(logits, dim=-1)  # (batch, N)
            q_new = alpha @ memory.T  # (batch, d)

            # Check convergence
            delta = torch.norm(q_new - q, dim=-1).max()  # scalar
            q = q_new

            if return_trajectory:
                trajectory.append(q.clone())
            if return_energy:
                energies.append(self.compute_energy(q, memory))

            num_steps = t + 1

            if delta < self.config.conv_threshold:
                break

        # Final attention weights (recompute to ensure consistency)
        logits = self.config.beta * (q @ memory)  # (batch, N)
        alpha = torch.softmax(logits, dim=-1)  # (batch, N)

        return HopfieldResult(
            attention_weights=alpha,
            converged_query=q,
            num_steps=num_steps,
            trajectory=trajectory,
            energy_curve=energies,
        )

    def compute_energy(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        """Compute the Hopfield energy function.

        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2

        Args:
            query: (batch, d) or (d,) — query embedding(s)
            memory: (d, N) — memory bank

        Returns:
            Energy tensor of shape (batch,); a 1-D query is treated as a
            batch of size 1.
        """
        if query.dim() == 1:
            query = query.unsqueeze(0)  # (1, d)

        logits = self.config.beta * (query @ memory)  # (batch, N)
        lse = torch.logsumexp(logits, dim=-1)  # (batch,)
        norm_sq = 0.5 * (query**2).sum(dim=-1)  # (batch,)
        energy = -1.0 / self.config.beta * lse + norm_sq  # (batch,)

        return energy
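

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the library API): run retrieval on
    # random data and check that the energy curve is non-increasing, matching
    # the monotonic-decrease property stated in the class docstring.
    # Assumption: HopfieldConfig (defined in hag.config, not shown here) can be
    # constructed with beta / max_iter / conv_threshold keyword arguments.
    torch.manual_seed(0)

    d, n = 64, 128
    memory = torch.nn.functional.normalize(torch.randn(d, n), dim=0)  # (d, N)
    query = torch.randn(d)  # (d,)

    config = HopfieldConfig(beta=8.0, max_iter=50, conv_threshold=1e-6)
    retrieval = HopfieldRetrieval(config)
    result = retrieval.retrieve(query, memory, return_energy=True)

    # energy_curve holds E(q_0) plus one value per update -> num_steps + 1 entries.
    energies = torch.stack(result.energy_curve).squeeze(-1)
    print(f"converged in {result.num_steps} step(s)")
    print(f"max attention weight: {result.attention_weights.max().item():.4f}")

    # The energy should never increase between iterates (small float slack).
    assert torch.all(energies[1:] <= energies[:-1] + 1e-6), "energy increased"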