path: root/scripts/run_comparison.py
author    YurenHao0426 <Blackhao0426@gmail.com>    2026-02-16 14:44:42 -0600
committer YurenHao0426 <Blackhao0426@gmail.com>    2026-02-16 14:44:42 -0600
commit    09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree      9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/run_comparison.py
parent    c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizations (HEAD, master)
- Add centering support to MemoryBank (center_query, apply_centering, mean persistence in save/load) to remove the centroid attractor in Hopfield dynamics (the sketch after this message illustrates the idea)
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap comparing centered vs. uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and a root-cause analysis of the centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat the FAISS baseline by +3-4% F1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
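For readers unfamiliar with the centroid-attractor issue the first bullet refers to: in a modern-Hopfield update xi <- X softmax(beta * X^T xi), stored patterns that share a large common mean pull the iterate toward their centroid, so retrieval converges to an average of the bank rather than to an individual passage. Centering subtracts that mean from both the stored patterns and the query before the update. The sketch below is generic and illustrative only; hopfield_step and its arguments are not this repository's API (the commit exposes the idea through MemoryBank.center_query and apply_centering instead).

import torch

# Minimal sketch of centering in a modern-Hopfield retrieval step.
# Assumed shapes: memories is (d, N), query is (d,). Not the repo's MemoryBank code.
def hopfield_step(memories: torch.Tensor, query: torch.Tensor,
                  beta: float = 8.0, center: bool = True) -> torch.Tensor:
    if center:
        mean = memories.mean(dim=1, keepdim=True)  # centroid of stored patterns, (d, 1)
        memories = memories - mean                 # remove the shared component
        query = query - mean.squeeze(1)            # center the query the same way
    scores = beta * (memories.T @ query)           # (N,) similarity logits
    attn = torch.softmax(scores, dim=0)            # attention weights over memories
    return memories @ attn                         # updated state, (d,)

Without centering, the softmax mass spreads over many near-identical logits and the returned vector sits close to the mean of the bank; the β_critical ≈ 37.6 reported above for centered memory is presumably the inverse-temperature threshold at which iterates stop collapsing into such mixtures.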
Diffstat (limited to 'scripts/run_comparison.py')
-rw-r--r--    scripts/run_comparison.py    186
1 file changed, 186 insertions, 0 deletions
diff --git a/scripts/run_comparison.py b/scripts/run_comparison.py
new file mode 100644
index 0000000..29f23f8
--- /dev/null
+++ b/scripts/run_comparison.py
@@ -0,0 +1,186 @@
+"""Run side-by-side comparison of FAISS (baseline) vs Hopfield (HAG) retrieval.
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 python scripts/run_comparison.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda \
+ --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+import time
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset, exact_match, f1_score
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Compare FAISS vs Hopfield retrieval")
+ parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--device", type=str, default="cpu")
+ parser.add_argument("--max-samples", type=int, default=None)
+ parser.add_argument("--output", type=str, default="data/processed/comparison_results.json")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+ generator_config = GeneratorConfig(**cfg.get("generator", {}))
+
+ # Load questions
+ with open(args.questions) as f:
+ questions_data = [json.loads(line) for line in f]
+ if args.max_samples and len(questions_data) > args.max_samples:
+ questions_data = questions_data[: args.max_samples]
+
+ questions = [q["question"] for q in questions_data]
+ gold_answers = [q["answer"] for q in questions_data]
+ logger.info("Loaded %d questions", len(questions))
+
+ # Load memory bank
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank, device=args.device)
+ logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)
+
+ # Shared encoder and generator
+ encoder = Encoder(encoder_config, device=args.device)
+ generator = Generator(generator_config, device=args.device)
+
+ # --- Build FAISS retriever ---
+ embeddings_np = mb.embeddings.T.cpu().numpy().astype(np.float32) # (N, d)
+ faiss_retriever = FAISSRetriever(top_k=hopfield_config.top_k)
+ faiss_retriever.build_index(embeddings_np, mb.passages)
+
+ # --- Build Hopfield retriever ---
+ hopfield = HopfieldRetrieval(hopfield_config)
+ hopfield_retriever = HopfieldRetriever(hopfield, mb, top_k=hopfield_config.top_k)
+
+ # --- Build pipelines ---
+ faiss_pipeline_cfg = PipelineConfig(
+ hopfield=hopfield_config,
+ memory=memory_config,
+ encoder=encoder_config,
+ generator=generator_config,
+ retriever_type="faiss",
+ device=args.device,
+ )
+ faiss_pipeline = RAGPipeline(
+ config=faiss_pipeline_cfg,
+ encoder=encoder,
+ generator=generator,
+ faiss_retriever=faiss_retriever,
+ )
+
+ hopfield_pipeline_cfg = PipelineConfig(
+ hopfield=hopfield_config,
+ memory=memory_config,
+ encoder=encoder_config,
+ generator=generator_config,
+ retriever_type="hopfield",
+ device=args.device,
+ )
+ hopfield_pipeline = RAGPipeline(
+ config=hopfield_pipeline_cfg,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+
+ # --- Run FAISS baseline ---
+ logger.info("=" * 60)
+ logger.info("Running FAISS baseline (%d questions)...", len(questions))
+ t0 = time.time()
+ faiss_results = faiss_pipeline.run_batch(questions)
+ faiss_time = time.time() - t0
+ faiss_metrics = evaluate_dataset(faiss_results, gold_answers)
+ logger.info("FAISS done in %.1fs | EM=%.4f | F1=%.4f", faiss_time, faiss_metrics["em"], faiss_metrics["f1"])
+
+ # --- Run HAG ---
+ logger.info("=" * 60)
+ logger.info("Running HAG (beta=%.1f, max_iter=%d, top_k=%d) (%d questions)...",
+ hopfield_config.beta, hopfield_config.max_iter, hopfield_config.top_k, len(questions))
+ t0 = time.time()
+ hag_results = hopfield_pipeline.run_batch(questions)
+ hag_time = time.time() - t0
+ hag_metrics = evaluate_dataset(hag_results, gold_answers)
+ logger.info("HAG done in %.1fs | EM=%.4f | F1=%.4f", hag_time, hag_metrics["em"], hag_metrics["f1"])
+
+ # --- Summary ---
+ logger.info("=" * 60)
+ logger.info("COMPARISON SUMMARY")
+ logger.info("%-20s %10s %10s", "", "FAISS", "HAG")
+ logger.info("%-20s %10.4f %10.4f", "Exact Match", faiss_metrics["em"], hag_metrics["em"])
+ logger.info("%-20s %10.4f %10.4f", "F1 Score", faiss_metrics["f1"], hag_metrics["f1"])
+ logger.info("%-20s %10.1fs %10.1fs", "Time", faiss_time, hag_time)
+ em_delta = hag_metrics["em"] - faiss_metrics["em"]
+ f1_delta = hag_metrics["f1"] - faiss_metrics["f1"]
+ logger.info("%-20s %+10.4f %+10.4f", "Delta (HAG - FAISS)", em_delta, f1_delta)
+
+ # --- Per-question details ---
+ per_question = []
+ for i, (fq, hq, gold) in enumerate(zip(faiss_results, hag_results, gold_answers)):
+ per_question.append({
+ "id": questions_data[i].get("id", i),
+ "question": questions[i],
+ "gold_answer": gold,
+ "faiss_answer": fq.answer,
+ "hag_answer": hq.answer,
+ "faiss_em": exact_match(fq.answer, gold),
+ "hag_em": exact_match(hq.answer, gold),
+ "faiss_f1": f1_score(fq.answer, gold),
+ "hag_f1": f1_score(hq.answer, gold),
+ "faiss_passages": fq.retrieved_passages,
+ "hag_passages": hq.retrieved_passages,
+ })
+
+ output = {
+ "config": {
+ "hopfield_beta": hopfield_config.beta,
+ "hopfield_max_iter": hopfield_config.max_iter,
+ "top_k": hopfield_config.top_k,
+ "encoder": encoder_config.model_name,
+ "generator": generator_config.model_name,
+ "num_questions": len(questions),
+ "num_passages": mb.size,
+ },
+ "faiss_metrics": {**faiss_metrics, "time_seconds": faiss_time},
+ "hag_metrics": {**hag_metrics, "time_seconds": hag_time},
+ "per_question": per_question,
+ }
+
+ with open(args.output, "w") as f:
+ json.dump(output, f, indent=2, ensure_ascii=False)
+ logger.info("Full results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
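hag.metrics is not part of this diff, so as a reference for the EM/F1 numbers the script reports: exact_match and f1_score in QA comparisons like this one usually follow the SQuAD-style definitions sketched below. This is the conventional formulation, not this repository's verified implementation; if hag.metrics normalizes differently the absolute scores shift slightly, but both retrievers are scored identically, so the HAG-vs-FAISS deltas remain comparable.

import re
import string
from collections import Counter

def _normalize(text: str) -> str:
    """Lowercase, strip punctuation and articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in string.punctuation)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())

def exact_match(pred: str, gold: str) -> float:
    # 1.0 iff the normalized prediction equals the normalized gold answer
    return float(_normalize(pred) == _normalize(gold))

def f1_score(pred: str, gold: str) -> float:
    # Token-level F1 between normalized prediction and gold answer
    pred_toks, gold_toks = _normalize(pred).split(), _normalize(gold).split()
    overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred_toks), overlap / len(gold_toks)
    return 2 * precision * recall / (precision + recall)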