diff options
Diffstat (limited to 'tests/test_pipeline.py')
| -rw-r--r-- | tests/test_pipeline.py | 132 |
1 files changed, 132 insertions, 0 deletions
# Recovered from the diff of tests/test_pipeline.py
# (new file, mode 100644, index 0000000..3ac81ce).
"""Integration tests for the RAG/HAG pipeline using mock encoder and generator."""

import numpy as np
import torch
import torch.nn.functional as F

from hag.config import (
    HopfieldConfig,
    MemoryBankConfig,
    PipelineConfig,
)
from hag.encoder import FakeEncoder
from hag.generator import FakeGenerator
from hag.memory_bank import MemoryBank
from hag.pipeline import RAGPipeline
from hag.retriever_faiss import FAISSRetriever


def _make_memory_bank(n: int = 20, d: int = 64) -> MemoryBank:
    """Return a memory bank filled with ``n`` random ``d``-dimensional rows.

    Args:
        n: Number of embeddings (and matching passages) to generate.
        d: Width of each embedding; also passed as ``embedding_dim``.

    Returns:
        A ``MemoryBank`` built from the random embeddings and the passages
        ``"passage 0 content"`` .. ``"passage {n-1} content"``.
    """
    # Seeds torch's global RNG so the random embeddings repeat across runs.
    torch.manual_seed(123)
    bank = MemoryBank(MemoryBankConfig(embedding_dim=d, normalize=True))
    vectors = torch.randn(n, d)
    texts = [f"passage {idx} content" for idx in range(n)]
    bank.build_from_embeddings(vectors, texts)
    return bank


def _make_faiss_retriever(n: int = 20, d: int = 64) -> FAISSRetriever:
    """Return a FAISS retriever indexed over ``n`` random ``d``-dimensional rows.

    Args:
        n: Number of embeddings (and matching passages) to generate.
        d: Width of each embedding.

    Returns:
        A ``FAISSRetriever`` created with ``top_k=3`` whose index was built
        from float32 random embeddings and the generated passages.
    """
    # Seeds NumPy's global RNG so the random embeddings repeat across runs.
    np.random.seed(123)
    vectors = np.random.randn(n, d).astype(np.float32)
    texts = [f"passage {idx} content" for idx in range(n)]
    index = FAISSRetriever(top_k=3)
    index.build_index(vectors, texts)
    return index


def _build_hopfield_pipeline() -> RAGPipeline:
    """Assemble the Hopfield-backed pipeline shared by several tests.

    Uses a 20 x 64 memory bank, a 64-dimensional ``FakeEncoder`` and a
    ``FakeGenerator``; the Hopfield settings are beta=1.0, max_iter=3, top_k=3.
    """
    settings = PipelineConfig(
        hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3),
        memory=MemoryBankConfig(embedding_dim=64, normalize=True),
        retriever_type="hopfield",
    )
    bank = _make_memory_bank(n=20, d=64)
    fake_encoder = FakeEncoder(dim=64)
    fake_generator = FakeGenerator()
    return RAGPipeline(
        config=settings,
        encoder=fake_encoder,
        generator=fake_generator,
        memory_bank=bank,
    )


def _assert_answer_with_three_passages(outcome) -> None:
    """Check for a non-empty string answer and exactly three passages."""
    assert isinstance(outcome.answer, str)
    assert len(outcome.answer) > 0
    assert len(outcome.retrieved_passages) == 3


class TestPipeline:
    """End-to-end pipeline checks driven by the fake encoder and generator."""

    def test_pipeline_runs_end_to_end_hopfield(self) -> None:
        """Fake encoder + fake generator + Hopfield retriever yield an answer."""
        outcome = _build_hopfield_pipeline().run("What is the capital of France?")

        _assert_answer_with_three_passages(outcome)

    def test_pipeline_runs_end_to_end_faiss(self) -> None:
        """Fake encoder + fake generator + FAISS retriever yield an answer."""
        settings = PipelineConfig(
            hopfield=HopfieldConfig(top_k=3),
            retriever_type="faiss",
        )
        retriever = _make_faiss_retriever(n=20, d=64)
        fake_encoder = FakeEncoder(dim=64)
        fake_generator = FakeGenerator()

        rag = RAGPipeline(
            config=settings,
            encoder=fake_encoder,
            generator=fake_generator,
            faiss_retriever=retriever,
        )
        outcome = rag.run("What is the capital of France?")

        _assert_answer_with_three_passages(outcome)

    def test_pipeline_retrieval_result_included(self) -> None:
        """The result carries retrieval metadata and echoes the question."""
        outcome = _build_hopfield_pipeline().run("Test question?")

        retrieval = outcome.retrieval_result
        assert retrieval is not None
        assert retrieval.scores is not None
        assert retrieval.indices is not None
        assert outcome.question == "Test question?"

    def test_pipeline_batch(self) -> None:
        """``run_batch`` returns one result per question."""
        rag = _build_hopfield_pipeline()
        batch = ["Q1?", "Q2?", "Q3?"]

        outcomes = rag.run_batch(batch)

        assert len(outcomes) == 3
        # Every answer is expected to equal the fixed string below.
        assert all(item.answer == "mock answer" for item in outcomes)
