summaryrefslogtreecommitdiff
path: root/tests/test_pipeline.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
committerYurenHao0426 <blackhao0426@gmail.com>2026-02-15 18:19:50 +0000
commitc90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree43edac8013fec4e65a0b9cddec5314489b4aafc2 /tests/test_pipeline.py
Initial implementation of HAG (Hopfield-Augmented Generation)HEADmaster
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'tests/test_pipeline.py')
-rw-r--r--tests/test_pipeline.py132
1 files changed, 132 insertions, 0 deletions
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..3ac81ce
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,132 @@
+"""Integration tests for the RAG/HAG pipeline using mock encoder and generator."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from hag.config import (
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import FakeEncoder
+from hag.generator import FakeGenerator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+
def _make_memory_bank(n: int = 20, d: int = 64) -> MemoryBank:
    """Build a MemoryBank holding *n* random *d*-dimensional embeddings.

    Seeds torch so the bank contents are reproducible across tests.
    """
    torch.manual_seed(123)
    bank = MemoryBank(MemoryBankConfig(embedding_dim=d, normalize=True))
    texts = [f"passage {idx} content" for idx in range(n)]
    bank.build_from_embeddings(torch.randn(n, d), texts)
    return bank
+
+
def _make_faiss_retriever(n: int = 20, d: int = 64) -> FAISSRetriever:
    """Build a FAISSRetriever indexed over *n* random *d*-dim embeddings.

    Seeds numpy so the index contents are reproducible across tests.
    """
    np.random.seed(123)
    texts = [f"passage {idx} content" for idx in range(n)]
    vectors = np.random.randn(n, d).astype(np.float32)
    index = FAISSRetriever(top_k=3)
    index.build_index(vectors, texts)
    return index
+
+
class TestPipeline:
    """Integration tests exercising the pipeline with mock components."""

    @staticmethod
    def _hopfield_pipeline() -> RAGPipeline:
        """Assemble the Hopfield-retriever pipeline used by several tests.

        Centralizes the config / memory-bank / mock-component wiring that
        would otherwise be repeated in each Hopfield test.
        """
        cfg = PipelineConfig(
            hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3),
            memory=MemoryBankConfig(embedding_dim=64, normalize=True),
            retriever_type="hopfield",
        )
        return RAGPipeline(
            config=cfg,
            encoder=FakeEncoder(dim=64),
            generator=FakeGenerator(),
            memory_bank=_make_memory_bank(n=20, d=64),
        )

    def test_pipeline_runs_end_to_end_hopfield(self) -> None:
        """Mock encoder + mock LLM + Hopfield retriever -> produces an answer."""
        out = self._hopfield_pipeline().run("What is the capital of France?")

        assert isinstance(out.answer, str)
        assert len(out.answer) > 0
        # top_k=3 in HopfieldConfig, so exactly three passages come back.
        assert len(out.retrieved_passages) == 3

    def test_pipeline_runs_end_to_end_faiss(self) -> None:
        """Mock encoder + mock LLM + FAISS retriever -> produces an answer."""
        cfg = PipelineConfig(
            hopfield=HopfieldConfig(top_k=3),
            retriever_type="faiss",
        )
        pipeline = RAGPipeline(
            config=cfg,
            encoder=FakeEncoder(dim=64),
            generator=FakeGenerator(),
            faiss_retriever=_make_faiss_retriever(n=20, d=64),
        )
        out = pipeline.run("What is the capital of France?")

        assert isinstance(out.answer, str)
        assert len(out.answer) > 0
        # FAISSRetriever was built with top_k=3.
        assert len(out.retrieved_passages) == 3

    def test_pipeline_retrieval_result_included(self) -> None:
        """PipelineResult should carry retrieval metadata plus the question."""
        out = self._hopfield_pipeline().run("Test question?")

        retrieval = out.retrieval_result
        assert retrieval is not None
        assert retrieval.scores is not None
        assert retrieval.indices is not None
        # The original question is echoed back on the result object.
        assert out.question == "Test question?"

    def test_pipeline_batch(self) -> None:
        """run_batch should return one result per input question."""
        outputs = self._hopfield_pipeline().run_batch(["Q1?", "Q2?", "Q3?"])

        assert len(outputs) == 3
        # FakeGenerator answers every question with the same mock string.
        for item in outputs:
            assert item.answer == "mock answer"