"""Memory bank construction and management for passage embeddings."""
import logging
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from hag.config import MemoryBankConfig
logger = logging.getLogger(__name__)
class MemoryBank:
    """Stores passage embeddings and provides lookup from indices back to text.

    The memory bank is M in R^{d x N} where each column is a passage embedding.
    Also maintains a mapping from column index to passage text for final retrieval.

    When config.center=True, embeddings are mean-centered to remove the centroid
    attractor in Hopfield dynamics. The mean is saved so queries can be centered
    with the same offset via center_query().
    """

    def __init__(self, config: "MemoryBankConfig") -> None:
        self.config = config
        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
        self.passages: List[str] = []
        self.mean: Optional[torch.Tensor] = None  # (d,) — saved for query centering

    def build_from_embeddings(
        self, embeddings: torch.Tensor, passages: List[str]
    ) -> None:
        """Build memory bank from precomputed embeddings.

        Args:
            embeddings: (N, d) — passage embeddings (note: input is N x d)
            passages: list of N passage strings

        Raises:
            ValueError: if the number of embeddings differs from the number
                of passages.
        """
        # Real exception instead of ``assert``: asserts are stripped under -O,
        # which would let a mismatched bank be built silently.
        if embeddings.shape[0] != len(passages):
            raise ValueError(
                f"Number of embeddings ({embeddings.shape[0]}) must match "
                f"number of passages ({len(passages)})"
            )
        if self.config.normalize:
            embeddings = F.normalize(embeddings, dim=-1)
        if self.config.center:
            self.mean = embeddings.mean(dim=0)  # (d,)
            embeddings = embeddings - self.mean.unsqueeze(0)  # (N, d)
            logger.info("Centered memory bank (removed mean)")
        self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
        self.passages = list(passages)
        logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)

    def center_query(self, query: torch.Tensor) -> torch.Tensor:
        """Center a query embedding using the saved memory mean.

        Must be called before Hopfield retrieval when config.center=True.
        No-op (returns the query unchanged) when no mean is stored.

        Args:
            query: (d,) or (batch, d) — query embedding(s)

        Returns:
            Centered query tensor, same shape as input.
        """
        if self.mean is None:
            return query
        return query - self.mean.to(query.device)

    def apply_centering(self) -> None:
        """Center an already-loaded (uncentered) memory bank in-place.

        Useful when loading a memory bank that was saved without centering.
        Computes and stores the mean, then subtracts it from embeddings.
        No-op when the bank is empty or already centered.
        """
        if self.embeddings is None:
            return
        if self.mean is not None:
            # Re-centering would overwrite the saved mean with a ~0 vector,
            # breaking center_query() for every subsequent query.
            logger.warning("Memory bank already centered; skipping")
            return
        # embeddings is (d, N), mean over columns
        self.mean = self.embeddings.mean(dim=1)  # (d,)
        self.embeddings = self.embeddings - self.mean.unsqueeze(1)  # (d, N)
        logger.info("Applied centering to loaded memory bank")

    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
        """Given top-k indices, return corresponding passage texts.

        Args:
            indices: (k,) or (batch, k) tensor of integer indices

        Returns:
            List of passage strings (flattened in row-major index order).
        """
        flat_indices = indices.flatten().tolist()
        return [self.passages[i] for i in flat_indices]

    def save(self, path: str) -> None:
        """Save memory bank to disk.

        Args:
            path: file path for saving (e.g., 'memory_bank.pt')
        """
        data: Dict = {
            "embeddings": self.embeddings,
            "passages": self.passages,
            "config": {
                "embedding_dim": self.config.embedding_dim,
                "normalize": self.config.normalize,
                # Persist centering flag too, so a loaded bank can be
                # interpreted without access to the original config.
                "center": self.config.center,
            },
            "mean": self.mean,
        }
        torch.save(data, path)
        logger.info("Saved memory bank to %s", path)

    def load(self, path: str, device: str = "cpu") -> None:
        """Load memory bank from disk.

        Args:
            path: file path to load from
            device: device to load tensors onto ("cpu", "cuda", "cuda:0", etc.)
        """
        # SECURITY: weights_only=False unpickles arbitrary objects; only load
        # memory banks from trusted sources.
        data = torch.load(path, weights_only=False, map_location=device)
        self.embeddings = data["embeddings"]
        self.passages = data["passages"]
        # Banks saved before centering was introduced have no "mean" key.
        self.mean = data.get("mean", None)
        logger.info("Loaded memory bank from %s (%d passages, device=%s)", path, self.size, device)

    def to(self, device: str) -> "MemoryBank":
        """Move memory bank embeddings to the specified device.

        Args:
            device: target device ("cpu", "cuda", "cuda:0", etc.)

        Returns:
            self (for chaining).
        """
        if self.embeddings is not None:
            self.embeddings = self.embeddings.to(device)
        if self.mean is not None:
            self.mean = self.mean.to(device)
        return self

    @property
    def size(self) -> int:
        """Number of passages in the memory bank (0 when unbuilt)."""
        return self.embeddings.shape[1] if self.embeddings is not None else 0

    @property
    def dim(self) -> int:
        """Embedding dimensionality (0 when unbuilt)."""
        return self.embeddings.shape[0] if self.embeddings is not None else 0