diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /scripts/run_baseline.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/run_baseline.py')
| -rw-r--r-- | scripts/run_baseline.py | 75 |
1 files changed, 75 insertions, 0 deletions
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py new file mode 100644 index 0000000..74c4710 --- /dev/null +++ b/scripts/run_baseline.py @@ -0,0 +1,75 @@ +"""Run vanilla RAG baseline with FAISS retrieval. + +Usage: + python scripts/run_baseline.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?" +""" + +import argparse +import logging + +import torch +import yaml + +from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, PipelineConfig +from hag.encoder import Encoder +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.retriever_faiss import FAISSRetriever +from hag.pipeline import RAGPipeline + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run vanilla RAG baseline") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--top-k", type=int, default=5) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # Override retriever type to faiss + pipeline_config = PipelineConfig( + hopfield=HopfieldConfig(**{**cfg.get("hopfield", {}), "top_k": args.top_k}), + encoder=EncoderConfig(**cfg.get("encoder", {})), + generator=GeneratorConfig(**cfg.get("generator", {})), + retriever_type="faiss", + ) + + # Load memory bank to get embeddings for FAISS + from hag.config import MemoryBankConfig + + mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) + mb.load(args.memory_bank) + + # Build FAISS index from memory bank embeddings + import numpy as np + + embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d) + faiss_ret = FAISSRetriever(top_k=args.top_k) + faiss_ret.build_index(embeddings_np, mb.passages) + + encoder = Encoder(pipeline_config.encoder) + generator = Generator(pipeline_config.generator) + + pipeline = RAGPipeline( + config=pipeline_config, + encoder=encoder, + generator=generator, + faiss_retriever=faiss_ret, + ) + + result = pipeline.run(args.question) + print(f"\nQuestion: {result.question}") + print(f"Answer: {result.answer}") + print(f"\nRetrieved passages:") + for i, p in enumerate(result.retrieved_passages): + print(f" [{i+1}] {p[:200]}...") + + +if __name__ == "__main__": + main() |
