summaryrefslogtreecommitdiff
path: root/hag/pipeline.py
blob: 1fefb84a4af5921b74067c07aafbe958706243b2 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""

import logging
from typing import List, Optional, Protocol, Union

import numpy as np
import torch

from hag.config import PipelineConfig
from hag.datatypes import PipelineResult, RetrievalResult
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_faiss import FAISSRetriever
from hag.retriever_hopfield import HopfieldRetriever

# Module-level logger, named after this module per the stdlib convention.
logger = logging.getLogger(__name__)


class EncoderProtocol(Protocol):
    """Structural (duck-typed) interface that any text encoder must satisfy."""

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Encode one text or a list of texts into an embedding tensor."""
        ...


class GeneratorProtocol(Protocol):
    """Structural (duck-typed) interface that any answer generator must satisfy."""

    def generate(self, question: str, passages: List[str]) -> str:
        """Produce an answer string for *question* grounded in *passages*."""
        ...


class RAGPipeline:
    """End-to-end pipeline: query -> encode -> retrieve -> generate.

    Supports both FAISS (baseline) and Hopfield (ours) retrieval.
    """

    def __init__(
        self,
        config: PipelineConfig,
        encoder: EncoderProtocol,
        generator: GeneratorProtocol,
        memory_bank: Optional[MemoryBank] = None,
        faiss_retriever: Optional[FAISSRetriever] = None,
    ) -> None:
        """Wire up the retrieval backend selected by ``config.retriever_type``.

        Args:
            config: pipeline configuration; ``config.retriever_type`` must be
                ``"faiss"`` or ``"hopfield"``.
            encoder: text encoder producing query embeddings.
            generator: LLM wrapper producing an answer from passages.
            memory_bank: required when ``retriever_type == "hopfield"``.
            faiss_retriever: required when ``retriever_type == "faiss"``.

        Raises:
            ValueError: if the dependency required by the selected mode is
                missing, or if ``config.retriever_type`` is unknown.
        """
        self.config = config
        self.encoder = encoder
        self.generator = generator

        # Validate with explicit raises, not `assert`: asserts are stripped
        # under `python -O`, which would silently defer the failure to a
        # confusing AttributeError at query time.
        if config.retriever_type == "faiss":
            if faiss_retriever is None:
                raise ValueError("FAISSRetriever required for faiss mode")
            self.retriever_type = "faiss"
            self.faiss_retriever: Optional[FAISSRetriever] = faiss_retriever
            self.hopfield_retriever: Optional[HopfieldRetriever] = None
        elif config.retriever_type == "hopfield":
            if memory_bank is None:
                raise ValueError("MemoryBank required for hopfield mode")
            hopfield = HopfieldRetrieval(config.hopfield)
            self.retriever_type = "hopfield"
            self.hopfield_retriever = HopfieldRetriever(
                hopfield, memory_bank, top_k=config.hopfield.top_k
            )
            self.faiss_retriever = None
        else:
            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")

    def run(self, question: str) -> PipelineResult:
        """Run the full pipeline on a single question.

        1. Encode question -> query embedding
        2. Retrieve passages (FAISS or Hopfield)
        3. Generate answer with LLM

        Args:
            question: input question string

        Returns:
            PipelineResult with answer and retrieval metadata.
        """
        # Encode; expected shape (1, d) — the retrievers consume a batch of 1.
        query_emb = self.encoder.encode(question)

        # Retrieve
        if self.retriever_type == "hopfield":
            retrieval_result = self.hopfield_retriever.retrieve(query_emb)
        else:
            # .cpu() before .numpy(): torch raises TypeError when converting
            # a CUDA tensor directly, so this is required for GPU encoders
            # and a no-op for CPU tensors.
            query_np = query_emb.detach().cpu().numpy().astype(np.float32)
            retrieval_result = self.faiss_retriever.retrieve(query_np)

        # Generate
        answer = self.generator.generate(question, retrieval_result.passages)

        return PipelineResult(
            question=question,
            answer=answer,
            retrieved_passages=retrieval_result.passages,
            retrieval_result=retrieval_result,
        )

    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
        """Run pipeline on a batch of questions, sequentially.

        Args:
            questions: list of question strings

        Returns:
            List of PipelineResult, one per question (empty list for empty input).
        """
        return [self.run(q) for q in questions]