"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""

import logging
from typing import List, Optional, Protocol, Union

import numpy as np
import torch

from hag.config import PipelineConfig
from hag.datatypes import PipelineResult, RetrievalResult
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever

logger = logging.getLogger(__name__)


class EncoderProtocol(Protocol):
    """Protocol for encoder interface."""

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        ...


class GeneratorProtocol(Protocol):
    """Protocol for generator interface."""

    def generate(self, question: str, passages: List[str]) -> str:
        ...


class RAGPipeline:
    """End-to-end pipeline: query -> encode -> retrieve -> generate.

    Supports both FAISS (baseline) and Hopfield (ours) retrieval.
    """

    def __init__(
        self,
        config: PipelineConfig,
        encoder: EncoderProtocol,
        generator: GeneratorProtocol,
        memory_bank: Optional[MemoryBank] = None,
        faiss_retriever: Optional[FAISSRetriever] = None,
    ) -> None:
        """Build the pipeline for the retriever selected by ``config``.

        Args:
            config: pipeline configuration; ``config.retriever_type`` selects
                the backend ("faiss" or "hopfield").
            encoder: object with ``encode(texts) -> torch.Tensor``.
            generator: object with ``generate(question, passages) -> str``.
            memory_bank: required when ``retriever_type == "hopfield"``.
            faiss_retriever: required when ``retriever_type == "faiss"``.

        Raises:
            ValueError: if ``retriever_type`` is unknown, or the dependency
                required by the selected mode was not supplied.
        """
        self.config = config
        self.encoder = encoder
        self.generator = generator

        # NOTE: validation uses explicit raises rather than `assert`, because
        # asserts are stripped under `python -O` and would let a None retriever
        # slip through to a confusing failure later in run().
        if config.retriever_type == "faiss":
            if faiss_retriever is None:
                raise ValueError("FAISSRetriever required for faiss mode")
            self.retriever_type = "faiss"
            self.faiss_retriever: Optional[FAISSRetriever] = faiss_retriever
            self.hopfield_retriever: Optional[HopfieldRetriever] = None
        elif config.retriever_type == "hopfield":
            if memory_bank is None:
                raise ValueError("MemoryBank required for hopfield mode")
            hopfield = HopfieldRetrieval(config.hopfield)
            self.retriever_type = "hopfield"
            self.hopfield_retriever = HopfieldRetriever(
                hopfield, memory_bank, top_k=config.hopfield.top_k
            )
            self.faiss_retriever = None
        else:
            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")

    def run(self, question: str) -> PipelineResult:
        """Run the full pipeline on a single question.

        1. Encode question -> query embedding
        2. Retrieve passages (FAISS or Hopfield)
        3. Generate answer with LLM

        Args:
            question: input question string

        Returns:
            PipelineResult with answer and retrieval metadata.
        """
        # Encode. Presumably shape (1, d) — per the original inline comment.
        query_emb = self.encoder.encode(question)

        # Retrieve
        if self.retriever_type == "hopfield":
            retrieval_result = self.hopfield_retriever.retrieve(query_emb)
        else:
            # .cpu() is required before .numpy(): the encoder may return a
            # CUDA tensor, and Tensor.numpy() only works on CPU tensors.
            query_np = query_emb.detach().cpu().numpy().astype(np.float32)
            retrieval_result = self.faiss_retriever.retrieve(query_np)

        # Generate
        answer = self.generator.generate(question, retrieval_result.passages)

        return PipelineResult(
            question=question,
            answer=answer,
            retrieved_passages=retrieval_result.passages,
            retrieval_result=retrieval_result,
        )

    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
        """Run pipeline on a batch of questions.

        Questions are processed sequentially, one ``run()`` call each.

        Args:
            questions: list of question strings

        Returns:
            List of PipelineResult, one per question.
        """
        return [self.run(q) for q in questions]