"""Run side-by-side comparison of FAISS (baseline) vs Hopfield (HAG) retrieval. Usage: CUDA_VISIBLE_DEVICES=1 python scripts/run_comparison.py \ --config configs/hotpotqa.yaml \ --memory-bank data/processed/hotpotqa_memory_bank.pt \ --questions data/processed/hotpotqa_questions.jsonl \ --device cuda \ --max-samples 500 """ import argparse import json import logging import time import numpy as np import torch import yaml from hag.config import ( EncoderConfig, GeneratorConfig, HopfieldConfig, MemoryBankConfig, PipelineConfig, ) from hag.encoder import Encoder from hag.generator import Generator from hag.hopfield import HopfieldRetrieval from hag.memory_bank import MemoryBank from hag.metrics import evaluate_dataset, exact_match, f1_score from hag.pipeline import RAGPipeline from hag.retriever_faiss import FAISSRetriever from hag.retriever_hopfield import HopfieldRetriever logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s") logger = logging.getLogger(__name__) def main() -> None: parser = argparse.ArgumentParser(description="Compare FAISS vs Hopfield retrieval") parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--max-samples", type=int, default=None) parser.add_argument("--output", type=str, default="data/processed/comparison_results.json") args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) hopfield_config = HopfieldConfig(**cfg.get("hopfield", {})) memory_config = MemoryBankConfig(**cfg.get("memory", {})) encoder_config = EncoderConfig(**cfg.get("encoder", {})) generator_config = GeneratorConfig(**cfg.get("generator", {})) # Load questions with open(args.questions) as f: questions_data = [json.loads(line) for line in f] if args.max_samples and len(questions_data) > args.max_samples: questions_data = questions_data[: args.max_samples] questions = [q["question"] for q in questions_data] gold_answers = [q["answer"] for q in questions_data] logger.info("Loaded %d questions", len(questions)) # Load memory bank mb = MemoryBank(memory_config) mb.load(args.memory_bank, device=args.device) logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim) # Shared encoder and generator encoder = Encoder(encoder_config, device=args.device) generator = Generator(generator_config, device=args.device) # --- Build FAISS retriever --- embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32) # (N, d) faiss_retriever = FAISSRetriever(top_k=hopfield_config.top_k) faiss_retriever.build_index(embeddings_np, mb.passages) # --- Build Hopfield retriever --- hopfield = HopfieldRetrieval(hopfield_config) hopfield_retriever = HopfieldRetriever(hopfield, mb, top_k=hopfield_config.top_k) # --- Build pipelines --- faiss_pipeline_cfg = PipelineConfig( hopfield=hopfield_config, memory=memory_config, encoder=encoder_config, generator=generator_config, retriever_type="faiss", device=args.device, ) faiss_pipeline = RAGPipeline( config=faiss_pipeline_cfg, encoder=encoder, generator=generator, faiss_retriever=faiss_retriever, ) hopfield_pipeline_cfg = PipelineConfig( hopfield=hopfield_config, memory=memory_config, encoder=encoder_config, generator=generator_config, retriever_type="hopfield", device=args.device, ) hopfield_pipeline = RAGPipeline( config=hopfield_pipeline_cfg, encoder=encoder, generator=generator, memory_bank=mb, ) # --- Run FAISS baseline --- logger.info("=" * 60) logger.info("Running FAISS baseline (%d questions)...", len(questions)) t0 = time.time() faiss_results = faiss_pipeline.run_batch(questions) faiss_time = time.time() - t0 faiss_metrics = evaluate_dataset(faiss_results, gold_answers) logger.info("FAISS done in %.1fs | EM=%.4f | F1=%.4f", faiss_time, faiss_metrics["em"], faiss_metrics["f1"]) # --- Run HAG --- logger.info("=" * 60) logger.info("Running HAG (beta=%.1f, max_iter=%d, top_k=%d) (%d questions)...", hopfield_config.beta, hopfield_config.max_iter, hopfield_config.top_k, len(questions)) t0 = time.time() hag_results = hopfield_pipeline.run_batch(questions) hag_time = time.time() - t0 hag_metrics = evaluate_dataset(hag_results, gold_answers) logger.info("HAG done in %.1fs | EM=%.4f | F1=%.4f", hag_time, hag_metrics["em"], hag_metrics["f1"]) # --- Summary --- logger.info("=" * 60) logger.info("COMPARISON SUMMARY") logger.info("%-20s %10s %10s", "", "FAISS", "HAG") logger.info("%-20s %10.4f %10.4f", "Exact Match", faiss_metrics["em"], hag_metrics["em"]) logger.info("%-20s %10.4f %10.4f", "F1 Score", faiss_metrics["f1"], hag_metrics["f1"]) logger.info("%-20s %10.1fs %10.1fs", "Time", faiss_time, hag_time) em_delta = hag_metrics["em"] - faiss_metrics["em"] f1_delta = hag_metrics["f1"] - faiss_metrics["f1"] logger.info("%-20s %+10.4f %+10.4f", "Delta (HAG - FAISS)", em_delta, f1_delta) # --- Per-question details --- per_question = [] for i, (fq, hq, gold) in enumerate(zip(faiss_results, hag_results, gold_answers)): per_question.append({ "id": questions_data[i].get("id", i), "question": questions[i], "gold_answer": gold, "faiss_answer": fq.answer, "hag_answer": hq.answer, "faiss_em": exact_match(fq.answer, gold), "hag_em": exact_match(hq.answer, gold), "faiss_f1": f1_score(fq.answer, gold), "hag_f1": f1_score(hq.answer, gold), "faiss_passages": fq.retrieved_passages, "hag_passages": hq.retrieved_passages, }) output = { "config": { "hopfield_beta": hopfield_config.beta, "hopfield_max_iter": hopfield_config.max_iter, "top_k": hopfield_config.top_k, "encoder": encoder_config.model_name, "generator": generator_config.model_name, "num_questions": len(questions), "num_passages": mb.size, }, "faiss_metrics": {**faiss_metrics, "time_seconds": faiss_time}, "hag_metrics": {**hag_metrics, "time_seconds": hag_time}, "per_question": per_question, } with open(args.output, "w") as f: json.dump(output, f, indent=2, ensure_ascii=False) logger.info("Full results saved to %s", args.output) if __name__ == "__main__": main()