From c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 Mon Sep 17 00:00:00 2001
From: YurenHao0426
Date: Sun, 15 Feb 2026 18:19:50 +0000
Subject: Initial implementation of HAG (Hopfield-Augmented Generation)

Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.

Co-Authored-By: Claude Opus 4.6
---
 hag/pipeline.py | 107 +++++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 file changed, 107 insertions(+)
 create mode 100644 hag/pipeline.py

diff --git a/hag/pipeline.py b/hag/pipeline.py
new file mode 100644
index 0000000..1fefb84
--- /dev/null
+++ b/hag/pipeline.py
@@ -0,0 +1,107 @@
+"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""
+
+import logging
+from typing import List, Optional, Protocol, Union
+
+import numpy as np
+import torch
+
+from hag.config import PipelineConfig
+from hag.datatypes import PipelineResult, RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class EncoderProtocol(Protocol):
+    """Protocol for encoder interface."""
+
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor: ...
+
+
+class GeneratorProtocol(Protocol):
+    """Protocol for generator interface."""
+
+    def generate(self, question: str, passages: List[str]) -> str: ...
+
+
+class RAGPipeline:
+    """End-to-end pipeline: query -> encode -> retrieve -> generate.
+
+    Supports both FAISS (baseline) and Hopfield (ours) retrieval.
+    """
+
+    def __init__(
+        self,
+        config: PipelineConfig,
+        encoder: EncoderProtocol,
+        generator: GeneratorProtocol,
+        memory_bank: Optional[MemoryBank] = None,
+        faiss_retriever: Optional[FAISSRetriever] = None,
+    ) -> None:
+        self.config = config
+        self.encoder = encoder
+        self.generator = generator
+
+        if config.retriever_type == "faiss":
+            assert faiss_retriever is not None, "FAISSRetriever required for faiss mode"
+            self.retriever_type = "faiss"
+            self.faiss_retriever = faiss_retriever
+            self.hopfield_retriever: Optional[HopfieldRetriever] = None
+        elif config.retriever_type == "hopfield":
+            assert memory_bank is not None, "MemoryBank required for hopfield mode"
+            hopfield = HopfieldRetrieval(config.hopfield)
+            self.retriever_type = "hopfield"
+            self.hopfield_retriever = HopfieldRetriever(
+                hopfield, memory_bank, top_k=config.hopfield.top_k
+            )
+            self.faiss_retriever = None
+        else:
+            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")
+
+    def run(self, question: str) -> PipelineResult:
+        """Run the full pipeline on a single question.
+
+        1. Encode question -> query embedding
+        2. Retrieve passages (FAISS or Hopfield)
+        3. Generate answer with LLM
+
+        Args:
+            question: input question string
+
+        Returns:
+            PipelineResult with answer and retrieval metadata.
+ """ + # Encode + query_emb = self.encoder.encode(question) # (1, d) + + # Retrieve + if self.retriever_type == "hopfield": + retrieval_result = self.hopfield_retriever.retrieve(query_emb) + else: + query_np = query_emb.detach().numpy().astype(np.float32) + retrieval_result = self.faiss_retriever.retrieve(query_np) + + # Generate + answer = self.generator.generate(question, retrieval_result.passages) + + return PipelineResult( + question=question, + answer=answer, + retrieved_passages=retrieval_result.passages, + retrieval_result=retrieval_result, + ) + + def run_batch(self, questions: List[str]) -> List[PipelineResult]: + """Run pipeline on a batch of questions. + + Args: + questions: list of question strings + + Returns: + List of PipelineResult, one per question. + """ + return [self.run(q) for q in questions] -- cgit v1.2.3