1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
|
"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""
import logging
from typing import List, Optional, Protocol, Union
import numpy as np
import torch
from hag.config import PipelineConfig
from hag.datatypes import PipelineResult, RetrievalResult
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever
logger = logging.getLogger(__name__)
class EncoderProtocol(Protocol):
    """Structural interface for text encoders.

    Any object with a compatible ``encode`` method satisfies this protocol;
    explicit inheritance is not required.
    """

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Encode a single string or a list of strings into a tensor.

        Args:
            texts: one text, or a list of texts, to encode.

        Returns:
            A ``torch.Tensor`` produced by the implementation.
        """
class GeneratorProtocol(Protocol):
    """Structural interface for answer generators.

    Any object with a compatible ``generate`` method satisfies this protocol;
    explicit inheritance is not required.
    """

    def generate(self, question: str, passages: List[str]) -> str:
        """Produce an answer string for a question given supporting passages.

        Args:
            question: the question text.
            passages: passage texts made available to the generator.

        Returns:
            The generated answer as a string.
        """
class RAGPipeline:
    """End-to-end pipeline: query -> encode -> retrieve -> generate.

    Supports both FAISS (baseline) and Hopfield (ours) retrieval. The backend
    is selected once, at construction time, from ``config.retriever_type``;
    only the retriever belonging to the selected backend is kept on the
    instance, the other attribute is set to ``None``.
    """

    def __init__(
        self,
        config: PipelineConfig,
        encoder: EncoderProtocol,
        generator: GeneratorProtocol,
        memory_bank: Optional[MemoryBank] = None,
        faiss_retriever: Optional[FAISSRetriever] = None,
    ) -> None:
        """Wire up the encoder, generator and the configured retriever.

        Args:
            config: pipeline configuration; ``config.retriever_type`` must be
                ``"faiss"`` or ``"hopfield"``. In hopfield mode
                ``config.hopfield`` (including ``config.hopfield.top_k``) is
                also read.
            encoder: object used to embed the question.
            generator: object used to produce the final answer.
            memory_bank: required in hopfield mode; ignored in faiss mode.
            faiss_retriever: required in faiss mode; ignored in hopfield mode.

        Raises:
            AssertionError: if the dependency required by the selected
                backend was not supplied.
            ValueError: if ``config.retriever_type`` is not recognised.
        """
        self.config = config
        self.encoder = encoder
        self.generator = generator

        if config.retriever_type == "faiss":
            # Raised explicitly instead of via ``assert`` so the check is not
            # stripped under ``python -O``. AssertionError is kept so existing
            # callers observe the same exception type and message.
            if faiss_retriever is None:
                raise AssertionError("FAISSRetriever required for faiss mode")
            self.retriever_type = "faiss"
            self.faiss_retriever = faiss_retriever
            self.hopfield_retriever: Optional[HopfieldRetriever] = None
        elif config.retriever_type == "hopfield":
            if memory_bank is None:
                raise AssertionError("MemoryBank required for hopfield mode")
            hopfield = HopfieldRetrieval(config.hopfield)
            self.retriever_type = "hopfield"
            self.hopfield_retriever = HopfieldRetriever(
                hopfield, memory_bank, top_k=config.hopfield.top_k
            )
            self.faiss_retriever = None
        else:
            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")

    def run(self, question: str) -> PipelineResult:
        """Run the full pipeline on a single question.

        1. Encode question -> query embedding
        2. Retrieve passages (FAISS or Hopfield)
        3. Generate answer with LLM

        Args:
            question: input question string

        Returns:
            PipelineResult with answer and retrieval metadata.
        """
        # Encode. The embedding is expected to be (1, d); this method does not
        # verify the shape itself.
        query_emb = self.encoder.encode(question)

        # Retrieve. The Hopfield retriever takes the tensor as-is, the FAISS
        # retriever takes a float32 NumPy array.
        if self.retriever_type == "hopfield":
            retrieval_result = self.hopfield_retriever.retrieve(query_emb)
        else:
            # ``.cpu()`` is required before ``.numpy()``: converting a tensor
            # that lives on an accelerator raises TypeError. It is a no-op for
            # tensors that are already on the CPU.
            query_np = query_emb.detach().cpu().numpy().astype(np.float32)
            retrieval_result = self.faiss_retriever.retrieve(query_np)

        # Generate
        answer = self.generator.generate(question, retrieval_result.passages)

        return PipelineResult(
            question=question,
            answer=answer,
            retrieved_passages=retrieval_result.passages,
            retrieval_result=retrieval_result,
        )

    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
        """Run pipeline on a batch of questions.

        Questions are processed one at a time, in order, by calling ``run``
        on each; an empty list yields an empty list.

        Args:
            questions: list of question strings

        Returns:
            List of PipelineResult, one per question.
        """
        return [self.run(q) for q in questions]
|