"""Run vanilla RAG baseline with FAISS retrieval. Usage: python scripts/run_baseline.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?" """ import argparse import logging import torch import yaml from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, PipelineConfig from hag.encoder import Encoder from hag.generator import Generator from hag.memory_bank import MemoryBank from hag.retriever_faiss import FAISSRetriever from hag.pipeline import RAGPipeline logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) def main() -> None: parser = argparse.ArgumentParser(description="Run vanilla RAG baseline") parser.add_argument("--config", type=str, default="configs/default.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--question", type=str, required=True) parser.add_argument("--top-k", type=int, default=5) args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) # Override retriever type to faiss pipeline_config = PipelineConfig( hopfield=HopfieldConfig(**{**cfg.get("hopfield", {}), "top_k": args.top_k}), encoder=EncoderConfig(**cfg.get("encoder", {})), generator=GeneratorConfig(**cfg.get("generator", {})), retriever_type="faiss", ) # Load memory bank to get embeddings for FAISS from hag.config import MemoryBankConfig mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) mb.load(args.memory_bank) # Build FAISS index from memory bank embeddings import numpy as np embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d) faiss_ret = FAISSRetriever(top_k=args.top_k) faiss_ret.build_index(embeddings_np, mb.passages) encoder = Encoder(pipeline_config.encoder) generator = Generator(pipeline_config.generator) pipeline = RAGPipeline( config=pipeline_config, encoder=encoder, generator=generator, faiss_retriever=faiss_ret, ) result = pipeline.run(args.question) print(f"\nQuestion: {result.question}") print(f"Answer: {result.answer}") print(f"\nRetrieved passages:") for i, p in enumerate(result.retrieved_passages): print(f" [{i+1}] {p[:200]}...") if __name__ == "__main__": main()