summaryrefslogtreecommitdiff
path: root/hag/pipeline.py
diff options
context:
space:
mode:
Diffstat (limited to 'hag/pipeline.py')
-rw-r--r--hag/pipeline.py107
1 files changed, 107 insertions, 0 deletions
diff --git a/hag/pipeline.py b/hag/pipeline.py
new file mode 100644
index 0000000..1fefb84
--- /dev/null
+++ b/hag/pipeline.py
@@ -0,0 +1,107 @@
+"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""
+
+import logging
+from typing import List, Optional, Protocol, Union
+
+import numpy as np
+import torch
+
+from hag.config import PipelineConfig
+from hag.datatypes import PipelineResult, RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logger = logging.getLogger(__name__)
+
+
class EncoderProtocol(Protocol):
    """Structural (duck-typed) interface for text encoders.

    Any object exposing a matching ``encode`` method satisfies this
    protocol; no subclassing is required.
    """

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Encode one string or a batch of strings into an embedding tensor.

        Returns:
            torch.Tensor of embeddings. ``RAGPipeline.run`` treats the
            single-question result as shape (1, d) — confirm against the
            concrete encoder implementation.
        """
        ...
+
+
class GeneratorProtocol(Protocol):
    """Structural (duck-typed) interface for answer generators.

    Any object exposing a matching ``generate`` method satisfies this
    protocol; no subclassing is required.
    """

    def generate(self, question: str, passages: List[str]) -> str:
        """Produce an answer string for ``question`` grounded in ``passages``."""
        ...
+
+
class RAGPipeline:
    """End-to-end pipeline: query -> encode -> retrieve -> generate.

    Supports both FAISS (baseline) and Hopfield (ours) retrieval; the mode
    is selected by ``config.retriever_type``.
    """

    def __init__(
        self,
        config: PipelineConfig,
        encoder: EncoderProtocol,
        generator: GeneratorProtocol,
        memory_bank: Optional[MemoryBank] = None,
        faiss_retriever: Optional[FAISSRetriever] = None,
    ) -> None:
        """Build the pipeline for the retriever mode named in ``config``.

        Args:
            config: pipeline configuration; ``config.retriever_type`` must be
                ``"faiss"`` or ``"hopfield"``.
            encoder: object satisfying :class:`EncoderProtocol`.
            generator: object satisfying :class:`GeneratorProtocol`.
            memory_bank: required when ``retriever_type == "hopfield"``.
            faiss_retriever: required when ``retriever_type == "faiss"``.

        Raises:
            ValueError: if ``retriever_type`` is unknown, or the retriever /
                memory bank required for the chosen mode was not supplied.
        """
        self.config = config
        self.encoder = encoder
        self.generator = generator

        # Validate with explicit raises, not ``assert``: asserts are stripped
        # under ``python -O`` and would silently skip these required checks.
        if config.retriever_type == "faiss":
            if faiss_retriever is None:
                raise ValueError("FAISSRetriever required for faiss mode")
            self.retriever_type = "faiss"
            self.faiss_retriever: Optional[FAISSRetriever] = faiss_retriever
            self.hopfield_retriever: Optional[HopfieldRetriever] = None
        elif config.retriever_type == "hopfield":
            if memory_bank is None:
                raise ValueError("MemoryBank required for hopfield mode")
            hopfield = HopfieldRetrieval(config.hopfield)
            self.retriever_type = "hopfield"
            self.hopfield_retriever = HopfieldRetriever(
                hopfield, memory_bank, top_k=config.hopfield.top_k
            )
            self.faiss_retriever = None
        else:
            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")

    def run(self, question: str) -> PipelineResult:
        """Run the full pipeline on a single question.

        1. Encode question -> query embedding
        2. Retrieve passages (FAISS or Hopfield)
        3. Generate answer with LLM

        Args:
            question: input question string

        Returns:
            PipelineResult with answer and retrieval metadata.
        """
        # Encode question into a (1, d) query embedding (per encoder contract).
        query_emb = self.encoder.encode(question)

        # Retrieve supporting passages with the configured backend.
        if self.retriever_type == "hopfield":
            retrieval_result = self.hopfield_retriever.retrieve(query_emb)
        else:
            # FAISS expects a float32 numpy array. ``.cpu()`` is required
            # before ``.numpy()``: calling ``.numpy()`` on a CUDA tensor
            # raises TypeError, so GPU encoders would break without it.
            query_np = query_emb.detach().cpu().numpy().astype(np.float32)
            retrieval_result = self.faiss_retriever.retrieve(query_np)

        # Generate the final answer conditioned on the retrieved passages.
        answer = self.generator.generate(question, retrieval_result.passages)

        return PipelineResult(
            question=question,
            answer=answer,
            retrieved_passages=retrieval_result.passages,
            retrieval_result=retrieval_result,
        )

    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
        """Run pipeline on a batch of questions, one ``run()`` per item.

        Args:
            questions: list of question strings

        Returns:
            List of PipelineResult, one per question.
        """
        return [self.run(q) for q in questions]