Diffstat (limited to 'scripts/run_comparison.py')
| -rw-r--r-- | scripts/run_comparison.py | 186 |
1 file changed, 186 insertions, 0 deletions
diff --git a/scripts/run_comparison.py b/scripts/run_comparison.py
new file mode 100644
index 0000000..29f23f8
--- /dev/null
+++ b/scripts/run_comparison.py
@@ -0,0 +1,186 @@
+"""Run side-by-side comparison of FAISS (baseline) vs Hopfield (HAG) retrieval.
+
+Usage:
+    CUDA_VISIBLE_DEVICES=1 python scripts/run_comparison.py \
+        --config configs/hotpotqa.yaml \
+        --memory-bank data/processed/hotpotqa_memory_bank.pt \
+        --questions data/processed/hotpotqa_questions.jsonl \
+        --device cuda \
+        --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+import time
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset, exact_match, f1_score
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Compare FAISS vs Hopfield retrieval")
+    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--questions", type=str, required=True)
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--max-samples", type=int, default=None)
+    parser.add_argument("--output", type=str, default="data/processed/comparison_results.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+    generator_config = GeneratorConfig(**cfg.get("generator", {}))
+
+    # Load questions
+    with open(args.questions) as f:
+        questions_data = [json.loads(line) for line in f]
+    if args.max_samples and len(questions_data) > args.max_samples:
+        questions_data = questions_data[: args.max_samples]
+
+    questions = [q["question"] for q in questions_data]
+    gold_answers = [q["answer"] for q in questions_data]
+    logger.info("Loaded %d questions", len(questions))
+
+    # Load memory bank
+    mb = MemoryBank(memory_config)
+    mb.load(args.memory_bank, device=args.device)
+    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)
+
+    # Shared encoder and generator
+    encoder = Encoder(encoder_config, device=args.device)
+    generator = Generator(generator_config, device=args.device)
+
+    # --- Build FAISS retriever ---
+    embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32)  # (N, d)
+    faiss_retriever = FAISSRetriever(top_k=hopfield_config.top_k)
+    faiss_retriever.build_index(embeddings_np, mb.passages)
+
+    # --- Build Hopfield retriever ---
+    hopfield = HopfieldRetrieval(hopfield_config)
+    hopfield_retriever = HopfieldRetriever(hopfield, mb, top_k=hopfield_config.top_k)
+
+    # --- Build pipelines ---
+    faiss_pipeline_cfg = PipelineConfig(
+        hopfield=hopfield_config,
+        memory=memory_config,
+        encoder=encoder_config,
+        generator=generator_config,
+        retriever_type="faiss",
+        device=args.device,
+    )
+    faiss_pipeline = RAGPipeline(
+        config=faiss_pipeline_cfg,
+        encoder=encoder,
+        generator=generator,
+        faiss_retriever=faiss_retriever,
+    )
+
+    hopfield_pipeline_cfg = PipelineConfig(
+        hopfield=hopfield_config,
+        memory=memory_config,
+        encoder=encoder_config,
+        generator=generator_config,
+        retriever_type="hopfield",
+        device=args.device,
+    )
+    hopfield_pipeline = RAGPipeline(
+        config=hopfield_pipeline_cfg,
+        encoder=encoder,
+        generator=generator,
+        memory_bank=mb,
+    )
+
+    # --- Run FAISS baseline ---
+    logger.info("=" * 60)
+    logger.info("Running FAISS baseline (%d questions)...", len(questions))
+    t0 = time.time()
+    faiss_results = faiss_pipeline.run_batch(questions)
+    faiss_time = time.time() - t0
+    faiss_metrics = evaluate_dataset(faiss_results, gold_answers)
+    logger.info("FAISS done in %.1fs | EM=%.4f | F1=%.4f", faiss_time, faiss_metrics["em"], faiss_metrics["f1"])
+
+    # --- Run HAG ---
+    logger.info("=" * 60)
+    logger.info("Running HAG (beta=%.1f, max_iter=%d, top_k=%d) (%d questions)...",
+                hopfield_config.beta, hopfield_config.max_iter, hopfield_config.top_k, len(questions))
+    t0 = time.time()
+    hag_results = hopfield_pipeline.run_batch(questions)
+    hag_time = time.time() - t0
+    hag_metrics = evaluate_dataset(hag_results, gold_answers)
+    logger.info("HAG done in %.1fs | EM=%.4f | F1=%.4f", hag_time, hag_metrics["em"], hag_metrics["f1"])
+
+    # --- Summary ---
+    logger.info("=" * 60)
+    logger.info("COMPARISON SUMMARY")
+    logger.info("%-20s %10s %10s", "", "FAISS", "HAG")
+    logger.info("%-20s %10.4f %10.4f", "Exact Match", faiss_metrics["em"], hag_metrics["em"])
+    logger.info("%-20s %10.4f %10.4f", "F1 Score", faiss_metrics["f1"], hag_metrics["f1"])
+    logger.info("%-20s %10.1fs %10.1fs", "Time", faiss_time, hag_time)
+    em_delta = hag_metrics["em"] - faiss_metrics["em"]
+    f1_delta = hag_metrics["f1"] - faiss_metrics["f1"]
+    logger.info("%-20s %+10.4f %+10.4f", "Delta (HAG - FAISS)", em_delta, f1_delta)
+
+    # --- Per-question details ---
+    per_question = []
+    for i, (fq, hq, gold) in enumerate(zip(faiss_results, hag_results, gold_answers)):
+        per_question.append({
+            "id": questions_data[i].get("id", i),
+            "question": questions[i],
+            "gold_answer": gold,
+            "faiss_answer": fq.answer,
+            "hag_answer": hq.answer,
+            "faiss_em": exact_match(fq.answer, gold),
+            "hag_em": exact_match(hq.answer, gold),
+            "faiss_f1": f1_score(fq.answer, gold),
+            "hag_f1": f1_score(hq.answer, gold),
+            "faiss_passages": fq.retrieved_passages,
+            "hag_passages": hq.retrieved_passages,
+        })
+
+    output = {
+        "config": {
+            "hopfield_beta": hopfield_config.beta,
+            "hopfield_max_iter": hopfield_config.max_iter,
+            "top_k": hopfield_config.top_k,
+            "encoder": encoder_config.model_name,
+            "generator": generator_config.model_name,
+            "num_questions": len(questions),
+            "num_passages": mb.size,
+        },
+        "faiss_metrics": {**faiss_metrics, "time_seconds": faiss_time},
+        "hag_metrics": {**hag_metrics, "time_seconds": hag_time},
+        "per_question": per_question,
+    }
+
+    with open(args.output, "w") as f:
+        json.dump(output, f, indent=2, ensure_ascii=False)
+    logger.info("Full results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()
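
Two usage notes follow; neither is part of the commit above.

First, the --questions file is read as JSONL, one object per line. From the keys the script accesses ("question", "answer", and an optional "id" that is echoed into the per-question output), an input line would look like the following; the values shown are invented placeholders:

{"id": "q-0001", "question": "Which team won the 2010 final?", "answer": "Spain"}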

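Second, a minimal sketch of consuming the results file written to --output (default data/processed/comparison_results.json). It assumes the script has already been run and relies only on keys the script itself writes ("faiss_metrics", "hag_metrics", "per_question"):

import json

# Load the results file produced by scripts/run_comparison.py.
with open("data/processed/comparison_results.json") as f:
    results = json.load(f)

# Aggregate metrics live under "faiss_metrics" and "hag_metrics".
for name in ("faiss_metrics", "hag_metrics"):
    m = results[name]
    print(f'{name}: EM={m["em"]:.4f}  F1={m["f1"]:.4f}  time={m["time_seconds"]:.1f}s')

# Each per-question record carries both answers and both scores, so the
# questions where Hopfield retrieval changed the outcome are easy to isolate.
wins = [q for q in results["per_question"] if q["hag_f1"] > q["faiss_f1"]]
print(f'HAG improves F1 on {len(wins)} of {len(results["per_question"])} questions')
for q in wins[:5]:
    print(q["id"], "|", q["question"])
    print("  gold:", q["gold_answer"], "| faiss:", q["faiss_answer"], "| hag:", q["hag_answer"])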