author     YurenHao0426 <Blackhao0426@gmail.com>  2026-02-16 14:44:42 -0600
committer  YurenHao0426 <Blackhao0426@gmail.com>  2026-02-16 14:44:42 -0600
commit     09d50e47860da0035e178a442dc936028808a0b3 (patch)
tree       9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/run_grid_search.py
parent     c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff)
Add memory centering, grid search experiments, and energy visualizations (HEAD, master)

- Add centering support to MemoryBank (center_query, apply_centering, mean
  persistence in save/load) to remove the centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig and a device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids with
  dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, and PCA heatmap
  comparing centered vs. uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and a root-cause
  analysis of the centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; the best configs beat the
  FAISS baseline by +3-4% F1

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/run_grid_search.py')
-rw-r--r--  scripts/run_grid_search.py  552
1 file changed, 552 insertions(+), 0 deletions(-)
diff --git a/scripts/run_grid_search.py b/scripts/run_grid_search.py
new file mode 100644
index 0000000..ddd5a8d
--- /dev/null
+++ b/scripts/run_grid_search.py
@@ -0,0 +1,552 @@
+"""Grid search over HAG hyperparameters (beta, max_iter) with dedup-based LLM caching.
+
+Key insight: many (beta, max_iter) combos retrieve the same top-k passages for a given
+question. By deduplicating on (question_idx, frozenset(top_k_indices)), we call the LLM
+only for unique passage sets, saving ~80-89% of generation calls.
+
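+Illustrative sketch of the dedup key (hypothetical indices; the names mirror
+the code below):
+
+    key = (q_idx, frozenset(top_k_indices))  # e.g. (7, frozenset({3, 17, 42}))
+    if key not in llm_cache:                 # cache hit -> skip the LLM call
+        llm_cache[key] = generator.generate(question, passages)
+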
+Usage:
+ python scripts/run_grid_search.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda \
+ --max-samples 100
+"""
+
+import argparse
+import json
+import logging
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+)
+from hag.encoder import Encoder
+from hag.energy import compute_attention_entropy, compute_energy_gap
+from hag.generator import Generator
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GridPoint:
+ """Results for a single (beta, max_iter) configuration."""
+
+ beta: float
+ max_iter: int
+ em: float
+ f1: float
+ avg_entropy: float
+ avg_energy_gap: float
+ avg_faiss_overlap: float
+ avg_steps: float
+
+
+def load_questions(path: str, max_samples: Optional[int] = None) -> Tuple[List[str], List[str]]:
+ """Load questions and gold answers from JSONL file.
+
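+    Each line must be a standalone JSON object with at least the two fields
+    read below, e.g. {"question": "Who wrote Dune?", "answer": "Frank Herbert"}
+    (a made-up sample; any extra fields are ignored).
+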
+ Args:
+ path: path to JSONL file with 'question' and 'answer' fields
+ max_samples: if set, limit to first N samples
+
+ Returns:
+ Tuple of (questions, gold_answers).
+ """
+ questions = []
+ gold_answers = []
+ with open(path) as f:
+ for line in f:
+ record = json.loads(line)
+ questions.append(record["question"])
+ gold_answers.append(record["answer"])
+ if max_samples and len(questions) >= max_samples:
+ break
+ return questions, gold_answers
+
+
+def encode_questions_batched(
+ encoder: Encoder, questions: List[str], batch_size: int = 32
+) -> torch.Tensor:
+ """Encode all questions into embeddings, batched for efficiency.
+
+ Args:
+ encoder: the encoder instance
+ questions: list of question strings
+ batch_size: encoding batch size
+
+ Returns:
+ (N, d) tensor of query embeddings.
+ """
+ all_embeddings = []
+ for i in range(0, len(questions), batch_size):
+ batch = questions[i : i + batch_size]
+ embs = encoder.encode(batch) # (batch_size, d)
+ all_embeddings.append(embs)
+ return torch.cat(all_embeddings, dim=0) # (N, d)
+
+
+def run_faiss_baseline(
+ query_embeddings: torch.Tensor,
+ memory_bank: MemoryBank,
+ generator: Generator,
+ questions: List[str],
+ gold_answers: List[str],
+ top_k: int,
+) -> Tuple[Dict[str, float], Dict[int, Tuple[str, Tuple[int, ...]]]]:
+ """Run FAISS baseline and cache results.
+
+ Args:
+ query_embeddings: (N, d) tensor
+ memory_bank: the memory bank
+ generator: LLM generator
+ questions: list of question strings
+ gold_answers: list of gold answer strings
+ top_k: number of passages to retrieve
+
+ Returns:
+ Tuple of (metrics_dict, faiss_cache).
+ faiss_cache maps question_idx -> (answer, top_k_indices_tuple).
+ """
+ logger.info("Building FAISS index...")
+ embeddings_np = memory_bank.embeddings.T.cpu().numpy().astype(np.float32) # (N_passages, d)
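+    # MemoryBank stores patterns column-wise as (d, N_passages); FAISS expects
+    # row-major float32 vectors, hence the transpose and cast above.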
+ faiss_ret = FAISSRetriever(top_k=top_k)
+ faiss_ret.build_index(embeddings_np, memory_bank.passages)
+
+ faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]] = {}
+ em_scores = []
+ f1_scores = []
+
+ logger.info("Running FAISS baseline on %d questions...", len(questions))
+ for i, question in enumerate(questions):
+ query_np = query_embeddings[i].cpu().numpy().astype(np.float32) # (d,)
+ result = faiss_ret.retrieve(query_np)
+ answer = generator.generate(question, result.passages)
+ indices_tuple = tuple(sorted(result.indices.tolist()))
+
+ faiss_cache[i] = (answer, indices_tuple)
+ em_scores.append(exact_match(answer, gold_answers[i]))
+ f1_scores.append(f1_score(answer, gold_answers[i]))
+
+ if (i + 1) % 20 == 0:
+ logger.info(
+ " FAISS baseline: %d/%d (EM=%.3f, F1=%.3f)",
+ i + 1,
+ len(questions),
+ sum(em_scores) / len(em_scores),
+ sum(f1_scores) / len(f1_scores),
+ )
+
+ metrics = {
+ "em": sum(em_scores) / len(em_scores),
+ "f1": sum(f1_scores) / len(f1_scores),
+ }
+ logger.info("FAISS baseline: EM=%.4f, F1=%.4f", metrics["em"], metrics["f1"])
+ return metrics, faiss_cache
+
+
+def run_hopfield_grid(
+ query_embeddings: torch.Tensor,
+ memory_bank: MemoryBank,
+ generator: Generator,
+ questions: List[str],
+ gold_answers: List[str],
+ faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]],
+ betas: List[float],
+ max_iters: List[int],
+ top_k: int,
+ device: str,
+) -> Tuple[List[GridPoint], Dict]:
+ """Run grid search over (beta, max_iter) with dedup-based LLM caching.
+
+ Phase 2: Retrieve all configs (fast, batched).
+ Phase 3: Deduplicate and generate (LLM calls only for unique passage sets).
+ Phase 4: Evaluate and collect results.
+
+ Args:
+ query_embeddings: (N, d) tensor on device
+ memory_bank: memory bank (embeddings on device)
+ generator: LLM generator
+ questions: list of question strings
+ gold_answers: list of gold answer strings
+ faiss_cache: maps question_idx -> (answer, sorted_indices_tuple)
+ betas: list of beta values to sweep
+ max_iters: list of max_iter values to sweep
+ top_k: fixed top_k for retrieval
+ device: computation device
+
+ Returns:
+ Tuple of (grid_results, meta_dict).
+ """
+ n_questions = len(questions)
+ memory = memory_bank.embeddings # (d, N_passages) on device
+
+ # =========================================================================
+ # Phase 2: Retrieve all configurations (batched, milliseconds each)
+ # =========================================================================
+ logger.info("Phase 2: Running %d retrieval configs...", len(betas) * len(max_iters))
+
+    # Structure: (beta, max_iter) -> list of per-question RetrievalInfo records
+    # capturing indices_tuple, entropy, energy_gap, steps, and faiss_overlap.
+ @dataclass
+ class RetrievalInfo:
+ indices_tuple: Tuple[int, ...]
+ entropy: float
+ energy_gap: float
+ steps: int
+ faiss_overlap: float
+
+ retrieval_data: Dict[Tuple[float, int], List[RetrievalInfo]] = {}
+
+ t_retrieve_start = time.time()
+ for beta in betas:
+ for max_iter in max_iters:
+ config = HopfieldConfig(beta=beta, max_iter=max_iter, top_k=top_k)
+ hopfield = HopfieldRetrieval(config)
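+            # Presumed dynamics (modern Hopfield networks, Ramsauer et al. 2020):
+            # each step computes xi <- M @ softmax(beta * M.T @ xi), so larger
+            # beta sharpens attention toward one stored pattern and max_iter
+            # bounds the fixed-point iterations.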
+
+ # Batched retrieval: all questions at once
+ result = hopfield.retrieve(
+ query_embeddings, memory, return_energy=True
+ ) # attention_weights: (N_questions, N_passages)
+
+ alpha = result.attention_weights # (N_questions, N_passages)
+ k = min(top_k, alpha.shape[-1])
+            _, indices = torch.topk(alpha, k, dim=-1)  # (N_questions, k)
+
+ # Compute energy curve per-question (energy_curve contains batch tensors)
+ energy_curves_raw = result.energy_curve # list of (N_questions,) tensors
+
+ infos = []
+ for q_idx in range(n_questions):
+ q_indices = sorted(indices[q_idx].tolist())
+ q_indices_tuple = tuple(q_indices)
+
+ # Per-question entropy
+ q_entropy = compute_attention_entropy(alpha[q_idx])
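+                # (assumed reading: low entropy means the attention mass has
+                # collapsed onto a few passages, i.e. a sharp attractor)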
+
+ # Per-question energy gap
+ if energy_curves_raw is not None:
+ q_energies = [e[q_idx].item() for e in energy_curves_raw]
+ q_energy_gap = compute_energy_gap(q_energies)
+ else:
+ q_energy_gap = 0.0
+
+ # FAISS overlap: fraction of top-k indices shared with FAISS
+ faiss_indices_set = set(faiss_cache[q_idx][1])
+ hopfield_indices_set = set(q_indices)
+ overlap = len(faiss_indices_set & hopfield_indices_set) / k
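+                # e.g. k=5 with 4 shared indices gives overlap 0.8; 1.0 means
+                # Hopfield returned exactly the FAISS top-k for this question.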
+
+ infos.append(RetrievalInfo(
+ indices_tuple=q_indices_tuple,
+ entropy=q_entropy,
+ energy_gap=q_energy_gap,
+ steps=result.num_steps,
+ faiss_overlap=overlap,
+ ))
+
+ retrieval_data[(beta, max_iter)] = infos
+
+ t_retrieve_end = time.time()
+ logger.info("Phase 2 complete: %.2fs for all retrieval configs", t_retrieve_end - t_retrieve_start)
+
+ # =========================================================================
+ # Phase 3: Deduplicate and generate
+ # =========================================================================
+ logger.info("Phase 3: Deduplicating and generating...")
+
+ # Build set of unique (question_idx, passage_set) combos needing LLM calls
+ # Cache key: (question_idx, frozenset(top_k_indices))
+ llm_cache: Dict[Tuple[int, frozenset], str] = {}
+
+ # Seed cache with FAISS answers (same passage sets don't need re-generation)
+ for q_idx, (answer, indices_tuple) in faiss_cache.items():
+ cache_key = (q_idx, frozenset(indices_tuple))
+ llm_cache[cache_key] = answer
+
+ # Collect all unique keys we need
+ needed_keys: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
+ for (beta, max_iter), infos in retrieval_data.items():
+ for q_idx, info in enumerate(infos):
+ cache_key = (q_idx, frozenset(info.indices_tuple))
+ if cache_key not in llm_cache and cache_key not in needed_keys:
+ needed_keys[cache_key] = (q_idx, info.indices_tuple)
+
+    total_grid_calls = n_questions * len(betas) * len(max_iters)
+ logger.info(
+ "Unique LLM calls needed: %d (out of %d grid points, %.1f%% saving)",
+ len(needed_keys),
+ total_grid_calls,
+ (1 - len(needed_keys) / total_grid_calls) * 100 if total_grid_calls > 0 else 0,
+ )
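+
+    # Back-of-envelope (hypothetical counts): the default 7 betas x 6 max_iters
+    # on 100 questions is 4200 grid points; if roughly 450-850 passage sets are
+    # unique, 80-89% of generation calls are skipped, as estimated in the module
+    # docstring.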
+
+ # Generate answers for unique passage sets
+ t_gen_start = time.time()
+ for call_idx, (cache_key, (q_idx, indices_tuple)) in enumerate(needed_keys.items()):
+ # Look up passages by sorted indices
+ indices_tensor = torch.tensor(list(indices_tuple), dtype=torch.long)
+ passages = memory_bank.get_passages_by_indices(indices_tensor)
+ answer = generator.generate(questions[q_idx], passages)
+ llm_cache[cache_key] = answer
+
+ if (call_idx + 1) % 20 == 0:
+ elapsed = time.time() - t_gen_start
+ rate = (call_idx + 1) / elapsed
+ remaining = (len(needed_keys) - call_idx - 1) / rate
+ logger.info(
+ " Generated %d/%d (%.1f calls/s, ~%.0fs remaining)",
+ call_idx + 1,
+ len(needed_keys),
+ rate,
+ remaining,
+ )
+
+ t_gen_end = time.time()
+ logger.info("Phase 3 complete: %d LLM calls in %.1fs", len(needed_keys), t_gen_end - t_gen_start)
+
+ # =========================================================================
+ # Phase 4: Evaluate all grid points
+ # =========================================================================
+ logger.info("Phase 4: Evaluating all grid points...")
+
+ grid_results: List[GridPoint] = []
+ for beta in betas:
+ for max_iter in max_iters:
+ infos = retrieval_data[(beta, max_iter)]
+ em_scores = []
+ f1_scores = []
+ entropies = []
+ energy_gaps = []
+ faiss_overlaps = []
+ steps_list = []
+
+ for q_idx, info in enumerate(infos):
+ cache_key = (q_idx, frozenset(info.indices_tuple))
+ answer = llm_cache[cache_key]
+
+ em_scores.append(exact_match(answer, gold_answers[q_idx]))
+ f1_scores.append(f1_score(answer, gold_answers[q_idx]))
+ entropies.append(info.entropy)
+ energy_gaps.append(info.energy_gap)
+ faiss_overlaps.append(info.faiss_overlap)
+ steps_list.append(info.steps)
+
+ gp = GridPoint(
+ beta=beta,
+ max_iter=max_iter,
+ em=sum(em_scores) / len(em_scores),
+ f1=sum(f1_scores) / len(f1_scores),
+ avg_entropy=sum(entropies) / len(entropies),
+ avg_energy_gap=sum(energy_gaps) / len(energy_gaps),
+ avg_faiss_overlap=sum(faiss_overlaps) / len(faiss_overlaps),
+ avg_steps=sum(steps_list) / len(steps_list),
+ )
+ grid_results.append(gp)
+ logger.info(
+ " beta=%.2f max_iter=%2d => EM=%.3f F1=%.3f entropy=%.3f energy_gap=%.3f faiss_overlap=%.3f",
+ beta,
+ max_iter,
+ gp.em,
+ gp.f1,
+ gp.avg_entropy,
+ gp.avg_energy_gap,
+ gp.avg_faiss_overlap,
+ )
+
+ total_llm_calls = len(faiss_cache) + len(needed_keys)
+ meta = {
+ "grid_size": len(betas) * len(max_iters),
+ "n_questions": n_questions,
+ "total_grid_evaluations": total_grid_calls,
+ "unique_llm_calls": len(needed_keys),
+ "faiss_llm_calls": len(faiss_cache),
+ "total_llm_calls": total_llm_calls,
+ "savings_pct": round(
+ (1 - total_llm_calls / (total_grid_calls + len(faiss_cache))) * 100, 1
+ )
+ if (total_grid_calls + len(faiss_cache)) > 0
+ else 0,
+ "retrieval_time_s": round(t_retrieve_end - t_retrieve_start, 2),
+ "generation_time_s": round(t_gen_end - t_gen_start, 2),
+ }
+
+ return grid_results, meta
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(
+ description="Grid search over HAG hyperparameters (beta, max_iter)"
+ )
+ parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--device", type=str, default="cpu")
+ parser.add_argument("--max-samples", type=int, default=100)
+ parser.add_argument(
+ "--output",
+ type=str,
+ default=None,
+ help="Output JSON path (default: data/processed/grid_search_results.json)",
+ )
+ parser.add_argument(
+ "--betas",
+ type=float,
+ nargs="+",
+ default=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0],
+ )
+ parser.add_argument(
+ "--max-iters",
+ type=int,
+ nargs="+",
+ default=[1, 2, 3, 5, 8, 15],
+ )
+ parser.add_argument("--top-k", type=int, default=5)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ output_path = args.output or "data/processed/grid_search_results.json"
+
+ # =========================================================================
+ # Phase 1: Load everything once
+ # =========================================================================
+ logger.info("=" * 60)
+ logger.info("HAG Grid Search")
+ logger.info(" betas: %s", args.betas)
+ logger.info(" max_iters: %s", args.max_iters)
+ logger.info(" top_k: %d", args.top_k)
+ logger.info(" grid points: %d", len(args.betas) * len(args.max_iters))
+ logger.info(" max_samples: %d", args.max_samples)
+ logger.info(" device: %s", args.device)
+ logger.info("=" * 60)
+
+ t_start = time.time()
+
+ # Load questions
+ logger.info("Loading questions from %s...", args.questions)
+ questions, gold_answers = load_questions(args.questions, args.max_samples)
+ logger.info("Loaded %d questions", len(questions))
+
+ # Load memory bank
+ logger.info("Loading memory bank from %s...", args.memory_bank)
+ mb_config = MemoryBankConfig(**cfg.get("memory", {}))
+ memory_bank = MemoryBank(mb_config)
+ memory_bank.load(args.memory_bank, device=args.device)
+ logger.info("Memory bank: %d passages, dim=%d", memory_bank.size, memory_bank.dim)
+
+ # Load encoder
+ logger.info("Loading encoder...")
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+ encoder = Encoder(encoder_config, device=args.device)
+
+ # Load generator
+ logger.info("Loading generator...")
+ generator_config = GeneratorConfig(**cfg.get("generator", {}))
+ generator = Generator(generator_config, device=args.device)
+
+ # Encode all questions once
+ logger.info("Encoding %d questions...", len(questions))
+ t_enc_start = time.time()
+ query_embeddings = encode_questions_batched(
+ encoder, questions, batch_size=encoder_config.batch_size
+ ) # (N, d) on device
+ t_enc_end = time.time()
+ logger.info("Encoded in %.2fs, shape=%s", t_enc_end - t_enc_start, query_embeddings.shape)
+
+ # =========================================================================
+ # Run FAISS baseline
+ # =========================================================================
+ faiss_metrics, faiss_cache = run_faiss_baseline(
+ query_embeddings, memory_bank, generator, questions, gold_answers, args.top_k
+ )
+
+ # =========================================================================
+ # Run Hopfield grid search
+ # =========================================================================
+ grid_results, meta = run_hopfield_grid(
+ query_embeddings,
+ memory_bank,
+ generator,
+ questions,
+ gold_answers,
+ faiss_cache,
+ betas=args.betas,
+ max_iters=args.max_iters,
+ top_k=args.top_k,
+ device=args.device,
+ )
+
+ # =========================================================================
+ # Find best config and save results
+ # =========================================================================
+ best = max(grid_results, key=lambda gp: gp.f1)
+
+ t_total = time.time() - t_start
+ meta["total_time_s"] = round(t_total, 1)
+
+ output = {
+ "meta": meta,
+ "faiss_baseline": faiss_metrics,
+ "grid_results": [
+ {
+ "beta": gp.beta,
+ "max_iter": gp.max_iter,
+ "em": round(gp.em, 4),
+ "f1": round(gp.f1, 4),
+ "avg_entropy": round(gp.avg_entropy, 4),
+ "avg_energy_gap": round(gp.avg_energy_gap, 4),
+ "avg_faiss_overlap": round(gp.avg_faiss_overlap, 4),
+ "avg_steps": round(gp.avg_steps, 2),
+ }
+ for gp in grid_results
+ ],
+ "best_config": {
+ "beta": best.beta,
+ "max_iter": best.max_iter,
+ "em": round(best.em, 4),
+ "f1": round(best.f1, 4),
+ "avg_entropy": round(best.avg_entropy, 4),
+ "avg_energy_gap": round(best.avg_energy_gap, 4),
+ "avg_faiss_overlap": round(best.avg_faiss_overlap, 4),
+ },
+ }
+
+ Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+ with open(output_path, "w") as f:
+ json.dump(output, f, indent=2)
+
+ logger.info("=" * 60)
+ logger.info("RESULTS SUMMARY")
+ logger.info(" FAISS baseline: EM=%.4f, F1=%.4f", faiss_metrics["em"], faiss_metrics["f1"])
+ logger.info(
+ " Best HAG config: beta=%.2f, max_iter=%d => EM=%.4f, F1=%.4f",
+ best.beta,
+ best.max_iter,
+ best.em,
+ best.f1,
+ )
+ logger.info(" Total LLM calls: %d (saved %.1f%%)", meta["total_llm_calls"], meta["savings_pct"])
+ logger.info(" Total time: %.1fs", t_total)
+ logger.info(" Results saved to: %s", output_path)
+ logger.info("=" * 60)
+
+
+if __name__ == "__main__":
+ main()