| | | |
|---|---|---|
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
| commit | 09d50e47860da0035e178a442dc936028808a0b3 (patch) | |
| tree | 9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/run_grid_search.py | |
| parent | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff) | |
- Add centering support to MemoryBank (center_query, apply_centering, mean
persistence in save/load) to remove the centroid attractor in Hopfield
dynamics (idea sketched after this message)
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids
with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap
comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and a root-cause
analysis of the centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; the best configs beat
the FAISS baseline by +3-4% F1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
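
The MemoryBank centering itself lands outside this diff (the diff below is limited to `scripts/run_grid_search.py`), so here is a rough sketch of the idea. Only the names `center_query`/`apply_centering` come from the message above; the function bodies, shapes, and dimensions below are illustrative assumptions, not the actual MemoryBank API:

```python
# Illustrative sketch only: the real center_query / apply_centering live in
# hag.memory_bank and are not part of this diff. Shapes follow the script's
# convention: memory is (d, N_passages), query is (d,).
from typing import Tuple

import torch


def apply_centering(memory: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """Subtract the memory centroid; return (centered_memory, mean)."""
    mean = memory.mean(dim=1, keepdim=True)  # (d, 1) centroid of stored patterns
    return memory - mean, mean


def hopfield_update(query: torch.Tensor, memory: torch.Tensor, beta: float) -> torch.Tensor:
    """One modern-Hopfield step: q' = M softmax(beta * M^T q)."""
    return memory @ torch.softmax(beta * (memory.T @ query), dim=-1)


# Uncentered, every stored pattern shares a large mean component, so softmax
# similarities are nearly uniform and iterates drift toward the centroid.
# Centering M and q by the same mean removes that global attractor.
memory, mean = apply_centering(torch.randn(768, 1000))
query = torch.randn(768) - mean.squeeze(1)  # the center_query idea: shift q too
out = hopfield_update(query, memory, beta=2.0)
```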
Diffstat (limited to 'scripts/run_grid_search.py')
| -rw-r--r-- | scripts/run_grid_search.py | 552 |
1 file changed, 552 insertions(+), 0 deletions(-)
diff --git a/scripts/run_grid_search.py b/scripts/run_grid_search.py
new file mode 100644
index 0000000..ddd5a8d
--- /dev/null
+++ b/scripts/run_grid_search.py
@@ -0,0 +1,552 @@
+"""Grid search over HAG hyperparameters (beta, max_iter) with dedup-based LLM caching.
+
+Key insight: many (beta, max_iter) combos retrieve the same top-k passages for a given
+question. By deduplicating on (question_idx, frozenset(top_k_indices)), we call the LLM
+only for unique passage sets, saving ~80-89% of generation calls.
+
+Usage:
+    python scripts/run_grid_search.py \
+        --config configs/hotpotqa.yaml \
+        --memory-bank data/processed/hotpotqa_memory_bank.pt \
+        --questions data/processed/hotpotqa_questions.jsonl \
+        --device cuda \
+        --max-samples 100
+"""
+
+import argparse
+import json
+import logging
+import time
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+import numpy as np
+import torch
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.energy import compute_attention_entropy, compute_energy_curve, compute_energy_gap
+from hag.generator import Generator
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class GridPoint:
+    """Results for a single (beta, max_iter) configuration."""
+
+    beta: float
+    max_iter: int
+    em: float
+    f1: float
+    avg_entropy: float
+    avg_energy_gap: float
+    avg_faiss_overlap: float
+    avg_steps: float
+
+
+def load_questions(path: str, max_samples: Optional[int] = None) -> Tuple[List[str], List[str]]:
+    """Load questions and gold answers from JSONL file.
+
+    Args:
+        path: path to JSONL file with 'question' and 'answer' fields
+        max_samples: if set, limit to first N samples
+
+    Returns:
+        Tuple of (questions, gold_answers).
+    """
+    questions = []
+    gold_answers = []
+    with open(path) as f:
+        for line in f:
+            record = json.loads(line)
+            questions.append(record["question"])
+            gold_answers.append(record["answer"])
+            if max_samples and len(questions) >= max_samples:
+                break
+    return questions, gold_answers
+
+
+def encode_questions_batched(
+    encoder: Encoder, questions: List[str], batch_size: int = 32
+) -> torch.Tensor:
+    """Encode all questions into embeddings, batched for efficiency.
+
+    Args:
+        encoder: the encoder instance
+        questions: list of question strings
+        batch_size: encoding batch size
+
+    Returns:
+        (N, d) tensor of query embeddings.
+    """
+    all_embeddings = []
+    for i in range(0, len(questions), batch_size):
+        batch = questions[i : i + batch_size]
+        embs = encoder.encode(batch)  # (batch_size, d)
+        all_embeddings.append(embs)
+    return torch.cat(all_embeddings, dim=0)  # (N, d)
+
+
+def run_faiss_baseline(
+    query_embeddings: torch.Tensor,
+    memory_bank: MemoryBank,
+    generator: Generator,
+    questions: List[str],
+    gold_answers: List[str],
+    top_k: int,
+) -> Tuple[Dict[str, float], Dict[int, Tuple[str, Tuple[int, ...]]]]:
+    """Run FAISS baseline and cache results.
+
+    Args:
+        query_embeddings: (N, d) tensor
+        memory_bank: the memory bank
+        generator: LLM generator
+        questions: list of question strings
+        gold_answers: list of gold answer strings
+        top_k: number of passages to retrieve
+
+    Returns:
+        Tuple of (metrics_dict, faiss_cache).
+        faiss_cache maps question_idx -> (answer, top_k_indices_tuple).
+    """
+    logger.info("Building FAISS index...")
+    embeddings_np = memory_bank.embeddings.T.cpu().numpy().astype(np.float32)  # (N_passages, d)
+    faiss_ret = FAISSRetriever(top_k=top_k)
+    faiss_ret.build_index(embeddings_np, memory_bank.passages)
+
+    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]] = {}
+    em_scores = []
+    f1_scores = []
+
+    logger.info("Running FAISS baseline on %d questions...", len(questions))
+    for i, question in enumerate(questions):
+        query_np = query_embeddings[i].cpu().numpy().astype(np.float32)  # (d,)
+        result = faiss_ret.retrieve(query_np)
+        answer = generator.generate(question, result.passages)
+        indices_tuple = tuple(sorted(result.indices.tolist()))
+
+        faiss_cache[i] = (answer, indices_tuple)
+        em_scores.append(exact_match(answer, gold_answers[i]))
+        f1_scores.append(f1_score(answer, gold_answers[i]))
+
+        if (i + 1) % 20 == 0:
+            logger.info(
+                "  FAISS baseline: %d/%d (EM=%.3f, F1=%.3f)",
+                i + 1,
+                len(questions),
+                sum(em_scores) / len(em_scores),
+                sum(f1_scores) / len(f1_scores),
+            )
+
+    metrics = {
+        "em": sum(em_scores) / len(em_scores),
+        "f1": sum(f1_scores) / len(f1_scores),
+    }
+    logger.info("FAISS baseline: EM=%.4f, F1=%.4f", metrics["em"], metrics["f1"])
+    return metrics, faiss_cache
+
+
+def run_hopfield_grid(
+    query_embeddings: torch.Tensor,
+    memory_bank: MemoryBank,
+    generator: Generator,
+    questions: List[str],
+    gold_answers: List[str],
+    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]],
+    betas: List[float],
+    max_iters: List[int],
+    top_k: int,
+    device: str,
+) -> Tuple[List[GridPoint], Dict]:
+    """Run grid search over (beta, max_iter) with dedup-based LLM caching.
+
+    Phase 2: Retrieve all configs (fast, batched).
+    Phase 3: Deduplicate and generate (LLM calls only for unique passage sets).
+    Phase 4: Evaluate and collect results.
+
+    Args:
+        query_embeddings: (N, d) tensor on device
+        memory_bank: memory bank (embeddings on device)
+        generator: LLM generator
+        questions: list of question strings
+        gold_answers: list of gold answer strings
+        faiss_cache: maps question_idx -> (answer, sorted_indices_tuple)
+        betas: list of beta values to sweep
+        max_iters: list of max_iter values to sweep
+        top_k: fixed top_k for retrieval
+        device: computation device
+
+    Returns:
+        Tuple of (grid_results, meta_dict).
+    """
+    n_questions = len(questions)
+    memory = memory_bank.embeddings  # (d, N_passages) on device
+
+    # =========================================================================
+    # Phase 2: Retrieve all configurations (batched, milliseconds each)
+    # =========================================================================
+    logger.info("Phase 2: Running %d retrieval configs...", len(betas) * len(max_iters))
+
+    # Structure: config_key -> per-question retrieval data
+    # retrieval_data[config_key][q_idx] = {indices_tuple, entropy, energy_gap, steps, faiss_overlap}
+    @dataclass
+    class RetrievalInfo:
+        indices_tuple: Tuple[int, ...]
+        entropy: float
+        energy_gap: float
+        steps: int
+        faiss_overlap: float
+
+    retrieval_data: Dict[Tuple[float, int], List[RetrievalInfo]] = {}
+
+    t_retrieve_start = time.time()
+    for beta in betas:
+        for max_iter in max_iters:
+            config = HopfieldConfig(beta=beta, max_iter=max_iter, top_k=top_k)
+            hopfield = HopfieldRetrieval(config)
+
+            # Batched retrieval: all questions at once
+            result = hopfield.retrieve(
+                query_embeddings, memory, return_energy=True
+            )  # attention_weights: (N_questions, N_passages)
+
+            alpha = result.attention_weights  # (N_questions, N_passages)
+            k = min(top_k, alpha.shape[-1])
+            scores, indices = torch.topk(alpha, k, dim=-1)  # (N_questions, k)
+
+            # Compute energy curve per-question (energy_curve contains batch tensors)
+            energy_curves_raw = result.energy_curve  # list of (N_questions,) tensors
+
+            infos = []
+            for q_idx in range(n_questions):
+                q_indices = sorted(indices[q_idx].tolist())
+                q_indices_tuple = tuple(q_indices)
+
+                # Per-question entropy
+                q_entropy = compute_attention_entropy(alpha[q_idx])
+
+                # Per-question energy gap
+                if energy_curves_raw is not None:
+                    q_energies = [e[q_idx].item() for e in energy_curves_raw]
+                    q_energy_gap = compute_energy_gap(q_energies)
+                else:
+                    q_energy_gap = 0.0
+
+                # FAISS overlap: fraction of top-k indices shared with FAISS
+                faiss_indices_set = set(faiss_cache[q_idx][1])
+                hopfield_indices_set = set(q_indices)
+                overlap = len(faiss_indices_set & hopfield_indices_set) / k
+
+                infos.append(RetrievalInfo(
+                    indices_tuple=q_indices_tuple,
+                    entropy=q_entropy,
+                    energy_gap=q_energy_gap,
+                    steps=result.num_steps,
+                    faiss_overlap=overlap,
+                ))
+
+            retrieval_data[(beta, max_iter)] = infos
+
+    t_retrieve_end = time.time()
+    logger.info("Phase 2 complete: %.2fs for all retrieval configs", t_retrieve_end - t_retrieve_start)
+
+    # =========================================================================
+    # Phase 3: Deduplicate and generate
+    # =========================================================================
+    logger.info("Phase 3: Deduplicating and generating...")
+
+    # Build set of unique (question_idx, passage_set) combos needing LLM calls
+    # Cache key: (question_idx, frozenset(top_k_indices))
+    llm_cache: Dict[Tuple[int, frozenset], str] = {}
+
+    # Seed cache with FAISS answers (same passage sets don't need re-generation)
+    for q_idx, (answer, indices_tuple) in faiss_cache.items():
+        cache_key = (q_idx, frozenset(indices_tuple))
+        llm_cache[cache_key] = answer
+
+    # Collect all unique keys we need
+    needed_keys: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
+    for (beta, max_iter), infos in retrieval_data.items():
+        for q_idx, info in enumerate(infos):
+            cache_key = (q_idx, frozenset(info.indices_tuple))
+            if cache_key not in llm_cache and cache_key not in needed_keys:
+                needed_keys[cache_key] = (q_idx, info.indices_tuple)
+
+    total_grid_calls = n_questions * len(betas) * len(max_iters)
+    already_cached = total_grid_calls - len(needed_keys)  # rough; some may still be unique
+    logger.info(
+        "Unique LLM calls needed: %d (out of %d grid points, %.1f%% saving)",
+        len(needed_keys),
+        total_grid_calls,
+        (1 - len(needed_keys) / total_grid_calls) * 100 if total_grid_calls > 0 else 0,
+    )
+
+    # Generate answers for unique passage sets
+    t_gen_start = time.time()
+    for call_idx, (cache_key, (q_idx, indices_tuple)) in enumerate(needed_keys.items()):
+        # Look up passages by sorted indices
+        indices_tensor = torch.tensor(list(indices_tuple), dtype=torch.long)
+        passages = memory_bank.get_passages_by_indices(indices_tensor)
+        answer = generator.generate(questions[q_idx], passages)
+        llm_cache[cache_key] = answer
+
+        if (call_idx + 1) % 20 == 0:
+            elapsed = time.time() - t_gen_start
+            rate = (call_idx + 1) / elapsed
+            remaining = (len(needed_keys) - call_idx - 1) / rate
+            logger.info(
+                "  Generated %d/%d (%.1f calls/s, ~%.0fs remaining)",
+                call_idx + 1,
+                len(needed_keys),
+                rate,
+                remaining,
+            )
+
+    t_gen_end = time.time()
+    logger.info("Phase 3 complete: %d LLM calls in %.1fs", len(needed_keys), t_gen_end - t_gen_start)
+
+    # =========================================================================
+    # Phase 4: Evaluate all grid points
+    # =========================================================================
+    logger.info("Phase 4: Evaluating all grid points...")
+
+    grid_results: List[GridPoint] = []
+    for beta in betas:
+        for max_iter in max_iters:
+            infos = retrieval_data[(beta, max_iter)]
+            em_scores = []
+            f1_scores = []
+            entropies = []
+            energy_gaps = []
+            faiss_overlaps = []
+            steps_list = []
+
+            for q_idx, info in enumerate(infos):
+                cache_key = (q_idx, frozenset(info.indices_tuple))
+                answer = llm_cache[cache_key]
+
+                em_scores.append(exact_match(answer, gold_answers[q_idx]))
+                f1_scores.append(f1_score(answer, gold_answers[q_idx]))
+                entropies.append(info.entropy)
+                energy_gaps.append(info.energy_gap)
+                faiss_overlaps.append(info.faiss_overlap)
+                steps_list.append(info.steps)
+
+            gp = GridPoint(
+                beta=beta,
+                max_iter=max_iter,
+                em=sum(em_scores) / len(em_scores),
+                f1=sum(f1_scores) / len(f1_scores),
+                avg_entropy=sum(entropies) / len(entropies),
+                avg_energy_gap=sum(energy_gaps) / len(energy_gaps),
+                avg_faiss_overlap=sum(faiss_overlaps) / len(faiss_overlaps),
+                avg_steps=sum(steps_list) / len(steps_list),
+            )
+            grid_results.append(gp)
+            logger.info(
+                "  beta=%.2f max_iter=%2d => EM=%.3f F1=%.3f entropy=%.3f energy_gap=%.3f faiss_overlap=%.3f",
+                beta,
+                max_iter,
+                gp.em,
+                gp.f1,
+                gp.avg_entropy,
+                gp.avg_energy_gap,
+                gp.avg_faiss_overlap,
+            )
+
+    total_llm_calls = len(faiss_cache) + len(needed_keys)
+    meta = {
+        "grid_size": len(betas) * len(max_iters),
+        "n_questions": n_questions,
+        "total_grid_evaluations": total_grid_calls,
+        "unique_llm_calls": len(needed_keys),
+        "faiss_llm_calls": len(faiss_cache),
+        "total_llm_calls": total_llm_calls,
+        "savings_pct": round(
+            (1 - total_llm_calls / (total_grid_calls + len(faiss_cache))) * 100, 1
+        )
+        if (total_grid_calls + len(faiss_cache)) > 0
+        else 0,
+        "retrieval_time_s": round(t_retrieve_end - t_retrieve_start, 2),
+        "generation_time_s": round(t_gen_end - t_gen_start, 2),
+    }
+
+    return grid_results, meta
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(
+        description="Grid search over HAG hyperparameters (beta, max_iter)"
+    )
+    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--questions", type=str, required=True)
+    parser.add_argument("--device", type=str, default="cpu")
+    parser.add_argument("--max-samples", type=int, default=100)
+    parser.add_argument(
+        "--output",
+        type=str,
+        default=None,
+        help="Output JSON path (default: data/processed/grid_search_results.json)",
+    )
+    parser.add_argument(
+        "--betas",
+        type=float,
+        nargs="+",
+        default=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0],
+    )
+    parser.add_argument(
+        "--max-iters",
+        type=int,
+        nargs="+",
+        default=[1, 2, 3, 5, 8, 15],
+    )
+    parser.add_argument("--top-k", type=int, default=5)
+    args = parser.parse_args()
+
+    import yaml
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    output_path = args.output or "data/processed/grid_search_results.json"
+
+    # =========================================================================
+    # Phase 1: Load everything once
+    # =========================================================================
+    logger.info("=" * 60)
+    logger.info("HAG Grid Search")
+    logger.info("  betas: %s", args.betas)
+    logger.info("  max_iters: %s", args.max_iters)
+    logger.info("  top_k: %d", args.top_k)
+    logger.info("  grid points: %d", len(args.betas) * len(args.max_iters))
+    logger.info("  max_samples: %d", args.max_samples)
+    logger.info("  device: %s", args.device)
+    logger.info("=" * 60)
+
+    t_start = time.time()
+
+    # Load questions
+    logger.info("Loading questions from %s...", args.questions)
+    questions, gold_answers = load_questions(args.questions, args.max_samples)
+    logger.info("Loaded %d questions", len(questions))
+
+    # Load memory bank
+    logger.info("Loading memory bank from %s...", args.memory_bank)
+    mb_config = MemoryBankConfig(**cfg.get("memory", {}))
+    memory_bank = MemoryBank(mb_config)
+    memory_bank.load(args.memory_bank, device=args.device)
+    logger.info("Memory bank: %d passages, dim=%d", memory_bank.size, memory_bank.dim)
+
+    # Load encoder
+    logger.info("Loading encoder...")
+    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+    encoder = Encoder(encoder_config, device=args.device)
+
+    # Load generator
+    logger.info("Loading generator...")
+    generator_config = GeneratorConfig(**cfg.get("generator", {}))
+    generator = Generator(generator_config, device=args.device)
+
+    # Encode all questions once
+    logger.info("Encoding %d questions...", len(questions))
+    t_enc_start = time.time()
+    query_embeddings = encode_questions_batched(
+        encoder, questions, batch_size=encoder_config.batch_size
+    )  # (N, d) on device
+    t_enc_end = time.time()
+    logger.info("Encoded in %.2fs, shape=%s", t_enc_end - t_enc_start, query_embeddings.shape)
+
+    # =========================================================================
+    # Run FAISS baseline
+    # =========================================================================
+    faiss_metrics, faiss_cache = run_faiss_baseline(
+        query_embeddings, memory_bank, generator, questions, gold_answers, args.top_k
+    )
+
+    # =========================================================================
+    # Run Hopfield grid search
+    # =========================================================================
+    grid_results, meta = run_hopfield_grid(
+        query_embeddings,
+        memory_bank,
+        generator,
+        questions,
+        gold_answers,
+        faiss_cache,
+        betas=args.betas,
+        max_iters=args.max_iters,
+        top_k=args.top_k,
+        device=args.device,
+    )
+
+    # =========================================================================
+    # Find best config and save results
+    # =========================================================================
+    best = max(grid_results, key=lambda gp: gp.f1)
+
+    t_total = time.time() - t_start
+    meta["total_time_s"] = round(t_total, 1)
+
+    output = {
+        "meta": meta,
+        "faiss_baseline": faiss_metrics,
+        "grid_results": [
+            {
+                "beta": gp.beta,
+                "max_iter": gp.max_iter,
+                "em": round(gp.em, 4),
+                "f1": round(gp.f1, 4),
+                "avg_entropy": round(gp.avg_entropy, 4),
+                "avg_energy_gap": round(gp.avg_energy_gap, 4),
+                "avg_faiss_overlap": round(gp.avg_faiss_overlap, 4),
+                "avg_steps": round(gp.avg_steps, 2),
+            }
+            for gp in grid_results
+        ],
+        "best_config": {
+            "beta": best.beta,
+            "max_iter": best.max_iter,
+            "em": round(best.em, 4),
+            "f1": round(best.f1, 4),
+            "avg_entropy": round(best.avg_entropy, 4),
+            "avg_energy_gap": round(best.avg_energy_gap, 4),
+            "avg_faiss_overlap": round(best.avg_faiss_overlap, 4),
+        },
+    }
+
+    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+    with open(output_path, "w") as f:
+        json.dump(output, f, indent=2)
+
+    logger.info("=" * 60)
+    logger.info("RESULTS SUMMARY")
+    logger.info("  FAISS baseline: EM=%.4f, F1=%.4f", faiss_metrics["em"], faiss_metrics["f1"])
+    logger.info(
+        "  Best HAG config: beta=%.2f, max_iter=%d => EM=%.4f, F1=%.4f",
+        best.beta,
+        best.max_iter,
+        best.em,
+        best.f1,
+    )
+    logger.info("  Total LLM calls: %d (saved %.1f%%)", meta["total_llm_calls"], meta["savings_pct"])
+    logger.info("  Total time: %.1fs", t_total)
+    logger.info("  Results saved to: %s", output_path)
+    logger.info("=" * 60)
+
+
+if __name__ == "__main__":
+    main()
