"""Memory bank construction and management for passage embeddings."""
import logging
from typing import Dict, List, Optional
import torch
import torch.nn.functional as F
from hag.config import MemoryBankConfig
logger = logging.getLogger(__name__)
class MemoryBank:
"""Stores passage embeddings and provides lookup from indices back to text.
The memory bank is M in R^{d x N} where each column is a passage embedding.
Also maintains a mapping from column index to passage text for final retrieval.
"""
def __init__(self, config: MemoryBankConfig) -> None:
self.config = config
self.embeddings: Optional[torch.Tensor] = None # (d, N)
self.passages: List[str] = []
def build_from_embeddings(
self, embeddings: torch.Tensor, passages: List[str]
) -> None:
"""Build memory bank from precomputed embeddings.
Args:
embeddings: (N, d) — passage embeddings (note: input is N x d)
passages: list of N passage strings
"""
assert embeddings.shape[0] == len(passages), (
f"Number of embeddings ({embeddings.shape[0]}) must match "
f"number of passages ({len(passages)})"
)
if self.config.normalize:
embeddings = F.normalize(embeddings, dim=-1)
self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
self.passages = list(passages)
logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
"""Given top-k indices, return corresponding passage texts.
Args:
indices: (k,) or (batch, k) tensor of integer indices
Returns:
List of passage strings.
"""
flat_indices = indices.flatten().tolist()
return [self.passages[i] for i in flat_indices]
def save(self, path: str) -> None:
"""Save memory bank to disk.
Args:
path: file path for saving (e.g., 'memory_bank.pt')
"""
data: Dict = {
"embeddings": self.embeddings,
"passages": self.passages,
"config": {
"embedding_dim": self.config.embedding_dim,
"normalize": self.config.normalize,
},
}
torch.save(data, path)
logger.info("Saved memory bank to %s", path)
def load(self, path: str) -> None:
"""Load memory bank from disk.
Args:
path: file path to load from
"""
data = torch.load(path, weights_only=False)
self.embeddings = data["embeddings"]
self.passages = data["passages"]
logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
@property
def size(self) -> int:
"""Number of passages in the memory bank."""
return self.embeddings.shape[1] if self.embeddings is not None else 0
@property
def dim(self) -> int:
"""Embedding dimensionality."""
return self.embeddings.shape[0] if self.embeddings is not None else 0
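

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the library API). It
# assumes MemoryBankConfig is a simple config object that can be constructed
# with the `embedding_dim` and `normalize` fields referenced above; adjust to
# the actual constructor if its signature differs. Random embeddings stand in
# for real encoder output.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)

    passages = [
        "The Eiffel Tower is located in Paris.",
        "Mount Everest is the highest mountain on Earth.",
        "The Great Barrier Reef lies off the coast of Australia.",
    ]
    config = MemoryBankConfig(embedding_dim=8, normalize=True)  # assumed signature
    bank = MemoryBank(config)

    # (N, d) embeddings; in practice these would come from a sentence encoder.
    embeddings = torch.randn(len(passages), config.embedding_dim)
    bank.build_from_embeddings(embeddings, passages)

    # Score a query against the bank: q @ M yields one similarity per passage,
    # since M is stored as (d, N) with one column per passage.
    query = F.normalize(torch.randn(1, config.embedding_dim), dim=-1)
    scores = query @ bank.embeddings  # (1, N)
    top_indices = scores.topk(k=2, dim=-1).indices
    print(bank.get_passages_by_indices(top_indices))

    # Round-trip through disk.
    bank.save("memory_bank.pt")
    bank.load("memory_bank.pt")
    assert bank.size == len(passages)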