r"""Grid search over HAG hyperparameters (beta, max_iter) with dedup-based LLM caching.

Key insight: many (beta, max_iter) combos retrieve the same top-k passages for a
given question. By deduplicating on (question_idx, frozenset(top_k_indices)), we
call the LLM only for unique passage sets, saving ~80-89% of generation calls.

Usage:
    python scripts/run_grid_search.py \
        --config configs/hotpotqa.yaml \
        --memory-bank data/processed/hotpotqa_memory_bank.pt \
        --questions data/processed/hotpotqa_questions.jsonl \
        --device cuda \
        --max-samples 100
"""

import argparse
import json
import logging
import time
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch

from hag.config import (
    EncoderConfig,
    GeneratorConfig,
    HopfieldConfig,
    MemoryBankConfig,
    PipelineConfig,
)
from hag.encoder import Encoder
from hag.energy import compute_attention_entropy, compute_energy_curve, compute_energy_gap
from hag.generator import Generator
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.metrics import exact_match, f1_score
from hag.retriever_faiss import FAISSRetriever

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s")
logger = logging.getLogger(__name__)


@dataclass
class GridPoint:
    """Aggregated results for a single (beta, max_iter) configuration."""

    beta: float
    max_iter: int
    em: float  # mean exact-match over all questions
    f1: float  # mean token F1 over all questions
    avg_entropy: float  # mean attention entropy per question
    avg_energy_gap: float  # mean energy gap per question (0.0 when no energy curve)
    avg_faiss_overlap: float  # mean fraction of top-k indices shared with FAISS
    avg_steps: float  # mean Hopfield steps reported by the retriever


@dataclass
class _RetrievalInfo:
    """Per-question retrieval statistics for one (beta, max_iter) configuration."""

    indices_tuple: Tuple[int, ...]  # top-k passage indices, sorted ascending
    entropy: float
    energy_gap: float
    steps: int
    faiss_overlap: float


# LLM cache key: (question_idx, unordered set of retrieved passage indices).
_LLMCacheKey = Tuple[int, frozenset]


def _mean(values: List[float]) -> float:
    """Return the arithmetic mean of ``values``, or 0.0 for an empty list."""
    return sum(values) / len(values) if values else 0.0


def load_questions(path: str, max_samples: Optional[int] = None) -> Tuple[List[str], List[str]]:
    """Load questions and gold answers from a JSONL file.

    Args:
        path: path to a UTF-8 JSONL file; each non-blank line is a JSON object
            with 'question' and 'answer' fields.
        max_samples: if set to a positive int, stop after that many records.
            ``None`` or ``0`` loads every record.

    Returns:
        Tuple of (questions, gold_answers), aligned by index.
    """
    questions: List[str] = []
    gold_answers: List[str] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            # Tolerate blank lines (e.g. a trailing newline) instead of
            # crashing inside json.loads.
            if not line.strip():
                continue
            record = json.loads(line)
            questions.append(record["question"])
            gold_answers.append(record["answer"])
            if max_samples and len(questions) >= max_samples:
                break
    return questions, gold_answers


def encode_questions_batched(
    encoder: Encoder, questions: List[str], batch_size: int = 32
) -> torch.Tensor:
    """Encode all questions into embeddings, batched for efficiency.

    Args:
        encoder: the encoder instance
        questions: list of question strings (must be non-empty; torch.cat
            rejects an empty list)
        batch_size: encoding batch size

    Returns:
        (N, d) tensor of query embeddings.
    """
    all_embeddings = []
    for i in range(0, len(questions), batch_size):
        batch = questions[i : i + batch_size]
        embs = encoder.encode(batch)  # (batch_size, d)
        all_embeddings.append(embs)
    return torch.cat(all_embeddings, dim=0)  # (N, d)


def run_faiss_baseline(
    query_embeddings: torch.Tensor,
    memory_bank: MemoryBank,
    generator: Generator,
    questions: List[str],
    gold_answers: List[str],
    top_k: int,
) -> Tuple[Dict[str, float], Dict[int, Tuple[str, Tuple[int, ...]]]]:
    """Run the FAISS baseline and cache its answers and retrieved indices.

    Args:
        query_embeddings: (N, d) tensor
        memory_bank: the memory bank; its embeddings are stored as (d, N_passages)
        generator: LLM generator
        questions: list of question strings
        gold_answers: list of gold answer strings
        top_k: number of passages to retrieve

    Returns:
        Tuple of (metrics_dict, faiss_cache). metrics_dict has 'em' and 'f1'
        (0.0 when there are no questions). faiss_cache maps
        question_idx -> (answer, sorted top_k_indices tuple).
    """
    logger.info("Building FAISS index...")
    # memory_bank.embeddings is (d, N_passages); FAISS wants one row per passage.
    embeddings_np = memory_bank.embeddings.T.cpu().numpy().astype(np.float32)  # (N_passages, d)
    faiss_ret = FAISSRetriever(top_k=top_k)
    faiss_ret.build_index(embeddings_np, memory_bank.passages)

    # One device->host transfer for all queries instead of one per question.
    queries_np = query_embeddings.cpu().numpy().astype(np.float32)  # (N, d)

    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]] = {}
    em_scores: List[float] = []
    f1_scores: List[float] = []

    logger.info("Running FAISS baseline on %d questions...", len(questions))
    for i, question in enumerate(questions):
        result = faiss_ret.retrieve(queries_np[i])  # (d,) query
        answer = generator.generate(question, result.passages)
        # Sorted so the same passage set always yields the same cache key.
        indices_tuple = tuple(sorted(result.indices.tolist()))
        faiss_cache[i] = (answer, indices_tuple)

        em_scores.append(exact_match(answer, gold_answers[i]))
        f1_scores.append(f1_score(answer, gold_answers[i]))

        if (i + 1) % 20 == 0:
            logger.info(
                " FAISS baseline: %d/%d (EM=%.3f, F1=%.3f)",
                i + 1,
                len(questions),
                _mean(em_scores),
                _mean(f1_scores),
            )

    metrics = {"em": _mean(em_scores), "f1": _mean(f1_scores)}
    logger.info("FAISS baseline: EM=%.4f, F1=%.4f", metrics["em"], metrics["f1"])
    return metrics, faiss_cache


def _retrieve_all_configs(
    query_embeddings: torch.Tensor,
    memory: torch.Tensor,
    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]],
    betas: List[float],
    max_iters: List[int],
    top_k: int,
    n_questions: int,
) -> Dict[Tuple[float, int], List[_RetrievalInfo]]:
    """Phase 2: run batched Hopfield retrieval for every (beta, max_iter) pair.

    Returns:
        Mapping (beta, max_iter) -> per-question list of _RetrievalInfo,
        built in beta-major, max_iter-minor order.
    """
    retrieval_data: Dict[Tuple[float, int], List[_RetrievalInfo]] = {}
    for beta in betas:
        for max_iter in max_iters:
            config = HopfieldConfig(beta=beta, max_iter=max_iter, top_k=top_k)
            hopfield = HopfieldRetrieval(config)

            # Batched retrieval: all questions at once.
            result = hopfield.retrieve(query_embeddings, memory, return_energy=True)

            alpha = result.attention_weights  # (N_questions, N_passages)
            k = min(top_k, alpha.shape[-1])
            _, topk_indices = torch.topk(alpha, k, dim=-1)  # (N_questions, k)
            # Convert once: per-row .tolist() on a device tensor forces one
            # host sync per question.
            topk_rows = topk_indices.tolist()

            # Assumed to be a list of (N_questions,) tensors, one per step, or
            # None -- TODO confirm against HopfieldRetrieval.
            energy_curves_raw = result.energy_curve

            infos: List[_RetrievalInfo] = []
            for q_idx in range(n_questions):
                q_indices = sorted(topk_rows[q_idx])

                q_entropy = compute_attention_entropy(alpha[q_idx])

                if energy_curves_raw is not None:
                    q_energies = [e[q_idx].item() for e in energy_curves_raw]
                    q_energy_gap = compute_energy_gap(q_energies)
                else:
                    q_energy_gap = 0.0

                # Fraction of the top-k indices shared with the FAISS baseline.
                faiss_indices_set = set(faiss_cache[q_idx][1])
                overlap = len(faiss_indices_set & set(q_indices)) / k

                infos.append(
                    _RetrievalInfo(
                        indices_tuple=tuple(q_indices),
                        entropy=q_entropy,
                        energy_gap=q_energy_gap,
                        steps=result.num_steps,
                        faiss_overlap=overlap,
                    )
                )

            retrieval_data[(beta, max_iter)] = infos
    return retrieval_data


def _collect_uncached_keys(
    retrieval_data: Dict[Tuple[float, int], List[_RetrievalInfo]],
    llm_cache: Dict[_LLMCacheKey, str],
) -> Dict[_LLMCacheKey, Tuple[int, Tuple[int, ...]]]:
    """Return the unique (question, passage-set) keys not already in ``llm_cache``.

    Values are (question_idx, sorted indices tuple), in first-seen order.
    """
    needed_keys: Dict[_LLMCacheKey, Tuple[int, Tuple[int, ...]]] = {}
    for infos in retrieval_data.values():
        for q_idx, info in enumerate(infos):
            cache_key = (q_idx, frozenset(info.indices_tuple))
            if cache_key not in llm_cache and cache_key not in needed_keys:
                needed_keys[cache_key] = (q_idx, info.indices_tuple)
    return needed_keys


def _generate_missing(
    needed_keys: Dict[_LLMCacheKey, Tuple[int, Tuple[int, ...]]],
    llm_cache: Dict[_LLMCacheKey, str],
    memory_bank: MemoryBank,
    generator: Generator,
    questions: List[str],
) -> None:
    """Phase 3: generate one answer per key in ``needed_keys`` into ``llm_cache``.

    Passages are passed to the generator in ascending index order. The cache
    key ignores order, so a cached FAISS answer (generated with FAISS rank
    order) can be reused for a Hopfield config that retrieved the same set.
    """
    n_needed = len(needed_keys)
    t_start = time.perf_counter()
    for call_idx, (cache_key, (q_idx, indices_tuple)) in enumerate(needed_keys.items()):
        indices_tensor = torch.tensor(list(indices_tuple), dtype=torch.long)
        passages = memory_bank.get_passages_by_indices(indices_tensor)
        llm_cache[cache_key] = generator.generate(questions[q_idx], passages)

        done = call_idx + 1
        if done % 20 == 0:
            elapsed = time.perf_counter() - t_start
            # Guard against a zero elapsed time on very fast generators.
            rate = done / elapsed if elapsed > 0 else float("inf")
            remaining = (n_needed - done) / rate
            logger.info(
                " Generated %d/%d (%.1f calls/s, ~%.0fs remaining)",
                done,
                n_needed,
                rate,
                remaining,
            )


def run_hopfield_grid(
    query_embeddings: torch.Tensor,
    memory_bank: MemoryBank,
    generator: Generator,
    questions: List[str],
    gold_answers: List[str],
    faiss_cache: Dict[int, Tuple[str, Tuple[int, ...]]],
    betas: List[float],
    max_iters: List[int],
    top_k: int,
    device: str,
) -> Tuple[List[GridPoint], Dict]:
    """Run grid search over (beta, max_iter) with dedup-based LLM caching.

    Phase 2: Retrieve all configs (fast, batched).
    Phase 3: Deduplicate and generate (LLM calls only for unique passage sets).
    Phase 4: Evaluate and collect results.

    Args:
        query_embeddings: (N, d) tensor on device
        memory_bank: memory bank (embeddings on device)
        generator: LLM generator
        questions: list of question strings
        gold_answers: list of gold answer strings
        faiss_cache: maps question_idx -> (answer, sorted_indices_tuple)
        betas: list of beta values to sweep
        max_iters: list of max_iter values to sweep
        top_k: fixed top_k for retrieval
        device: computation device (not read here; tensors are expected to
            already live on it)

    Returns:
        Tuple of (grid_results, meta_dict). grid_results is ordered
        beta-major, max_iter-minor; meta_dict holds call counts and timings.
    """
    n_questions = len(questions)
    n_configs = len(betas) * len(max_iters)
    memory = memory_bank.embeddings  # (d, N_passages) on device

    # =========================================================================
    # Phase 2: Retrieve all configurations (batched, milliseconds each)
    # =========================================================================
    logger.info("Phase 2: Running %d retrieval configs...", n_configs)
    t_retrieve_start = time.perf_counter()
    retrieval_data = _retrieve_all_configs(
        query_embeddings, memory, faiss_cache, betas, max_iters, top_k, n_questions
    )
    t_retrieve_end = time.perf_counter()
    logger.info("Phase 2 complete: %.2fs for all retrieval configs", t_retrieve_end - t_retrieve_start)

    # =========================================================================
    # Phase 3: Deduplicate and generate
    # =========================================================================
    logger.info("Phase 3: Deduplicating and generating...")

    # Seed the cache with FAISS answers so identical passage sets are not
    # re-generated.
    llm_cache: Dict[_LLMCacheKey, str] = {
        (q_idx, frozenset(indices_tuple)): answer
        for q_idx, (answer, indices_tuple) in faiss_cache.items()
    }
    needed_keys = _collect_uncached_keys(retrieval_data, llm_cache)

    total_grid_calls = n_questions * n_configs
    logger.info(
        "Unique LLM calls needed: %d (out of %d grid points, %.1f%% saving)",
        len(needed_keys),
        total_grid_calls,
        (1 - len(needed_keys) / total_grid_calls) * 100 if total_grid_calls > 0 else 0,
    )

    t_gen_start = time.perf_counter()
    _generate_missing(needed_keys, llm_cache, memory_bank, generator, questions)
    t_gen_end = time.perf_counter()
    logger.info("Phase 3 complete: %d LLM calls in %.1fs", len(needed_keys), t_gen_end - t_gen_start)

    # =========================================================================
    # Phase 4: Evaluate all grid points
    # =========================================================================
    logger.info("Phase 4: Evaluating all grid points...")
    grid_results: List[GridPoint] = []
    for beta in betas:
        for max_iter in max_iters:
            infos = retrieval_data[(beta, max_iter)]
            answers = [
                llm_cache[(q_idx, frozenset(info.indices_tuple))]
                for q_idx, info in enumerate(infos)
            ]
            gp = GridPoint(
                beta=beta,
                max_iter=max_iter,
                em=_mean([exact_match(a, gold_answers[q]) for q, a in enumerate(answers)]),
                f1=_mean([f1_score(a, gold_answers[q]) for q, a in enumerate(answers)]),
                avg_entropy=_mean([info.entropy for info in infos]),
                avg_energy_gap=_mean([info.energy_gap for info in infos]),
                avg_faiss_overlap=_mean([info.faiss_overlap for info in infos]),
                avg_steps=_mean([info.steps for info in infos]),
            )
            grid_results.append(gp)
            logger.info(
                " beta=%.2f max_iter=%2d => EM=%.3f F1=%.3f entropy=%.3f energy_gap=%.3f faiss_overlap=%.3f",
                beta,
                max_iter,
                gp.em,
                gp.f1,
                gp.avg_entropy,
                gp.avg_energy_gap,
                gp.avg_faiss_overlap,
            )

    total_llm_calls = len(faiss_cache) + len(needed_keys)
    # Baseline for savings: one call per grid evaluation plus the FAISS run.
    naive_calls = total_grid_calls + len(faiss_cache)
    meta = {
        "grid_size": n_configs,
        "n_questions": n_questions,
        "total_grid_evaluations": total_grid_calls,
        "unique_llm_calls": len(needed_keys),
        "faiss_llm_calls": len(faiss_cache),
        "total_llm_calls": total_llm_calls,
        "savings_pct": round((1 - total_llm_calls / naive_calls) * 100, 1) if naive_calls > 0 else 0,
        "retrieval_time_s": round(t_retrieve_end - t_retrieve_start, 2),
        "generation_time_s": round(t_gen_end - t_gen_start, 2),
    }
    return grid_results, meta


def main() -> None:
    """CLI entry point: run the FAISS baseline and the HAG grid search, save JSON."""
    parser = argparse.ArgumentParser(
        description="Grid search over HAG hyperparameters (beta, max_iter)"
    )
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument(
        "--output",
        type=str,
        default=None,
        help="Output JSON path (default: data/processed/grid_search_results.json)",
    )
    parser.add_argument(
        "--betas",
        type=float,
        nargs="+",
        default=[0.25, 0.5, 1.0, 2.0, 3.0, 5.0, 8.0],
    )
    parser.add_argument(
        "--max-iters",
        type=int,
        nargs="+",
        default=[1, 2, 3, 5, 8, 15],
    )
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    import yaml  # only the CLI entry point needs YAML

    with open(args.config, encoding="utf-8") as f:
        # safe_load returns None for an empty file; treat that as "no overrides".
        cfg = yaml.safe_load(f) or {}

    output_path = args.output or "data/processed/grid_search_results.json"

    # =========================================================================
    # Phase 1: Load everything once
    # =========================================================================
    logger.info("=" * 60)
    logger.info("HAG Grid Search")
    logger.info(" betas: %s", args.betas)
    logger.info(" max_iters: %s", args.max_iters)
    logger.info(" top_k: %d", args.top_k)
    logger.info(" grid points: %d", len(args.betas) * len(args.max_iters))
    logger.info(" max_samples: %d", args.max_samples)
    logger.info(" device: %s", args.device)
    logger.info("=" * 60)

    t_start = time.perf_counter()

    # Load questions
    logger.info("Loading questions from %s...", args.questions)
    questions, gold_answers = load_questions(args.questions, args.max_samples)
    logger.info("Loaded %d questions", len(questions))
    if not questions:
        # Every downstream average would otherwise divide by zero.
        logger.error("No questions loaded from %s; nothing to evaluate.", args.questions)
        raise SystemExit(1)

    # Load memory bank
    logger.info("Loading memory bank from %s...", args.memory_bank)
    mb_config = MemoryBankConfig(**cfg.get("memory", {}))
    memory_bank = MemoryBank(mb_config)
    memory_bank.load(args.memory_bank, device=args.device)
    logger.info("Memory bank: %d passages, dim=%d", memory_bank.size, memory_bank.dim)

    # Load encoder
    logger.info("Loading encoder...")
    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
    encoder = Encoder(encoder_config, device=args.device)

    # Load generator
    logger.info("Loading generator...")
    generator_config = GeneratorConfig(**cfg.get("generator", {}))
    generator = Generator(generator_config, device=args.device)

    # Encode all questions once
    logger.info("Encoding %d questions...", len(questions))
    t_enc_start = time.perf_counter()
    query_embeddings = encode_questions_batched(
        encoder, questions, batch_size=encoder_config.batch_size
    )  # (N, d) on device
    t_enc_end = time.perf_counter()
    logger.info("Encoded in %.2fs, shape=%s", t_enc_end - t_enc_start, query_embeddings.shape)

    # =========================================================================
    # Run FAISS baseline
    # =========================================================================
    faiss_metrics, faiss_cache = run_faiss_baseline(
        query_embeddings, memory_bank, generator, questions, gold_answers, args.top_k
    )

    # =========================================================================
    # Run Hopfield grid search
    # =========================================================================
    grid_results, meta = run_hopfield_grid(
        query_embeddings,
        memory_bank,
        generator,
        questions,
        gold_answers,
        faiss_cache,
        betas=args.betas,
        max_iters=args.max_iters,
        top_k=args.top_k,
        device=args.device,
    )

    # =========================================================================
    # Find best config and save results
    # =========================================================================
    best = max(grid_results, key=lambda gp: gp.f1)

    t_total = time.perf_counter() - t_start
    meta["total_time_s"] = round(t_total, 1)

    output = {
        "meta": meta,
        "faiss_baseline": faiss_metrics,
        "grid_results": [
            {
                "beta": gp.beta,
                "max_iter": gp.max_iter,
                "em": round(gp.em, 4),
                "f1": round(gp.f1, 4),
                "avg_entropy": round(gp.avg_entropy, 4),
                "avg_energy_gap": round(gp.avg_energy_gap, 4),
                "avg_faiss_overlap": round(gp.avg_faiss_overlap, 4),
                "avg_steps": round(gp.avg_steps, 2),
            }
            for gp in grid_results
        ],
        "best_config": {
            "beta": best.beta,
            "max_iter": best.max_iter,
            "em": round(best.em, 4),
            "f1": round(best.f1, 4),
            "avg_entropy": round(best.avg_entropy, 4),
            "avg_energy_gap": round(best.avg_energy_gap, 4),
            "avg_faiss_overlap": round(best.avg_faiss_overlap, 4),
            "avg_steps": round(best.avg_steps, 2),
        },
    }

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(output, f, indent=2)

    logger.info("=" * 60)
    logger.info("RESULTS SUMMARY")
    logger.info(" FAISS baseline: EM=%.4f, F1=%.4f", faiss_metrics["em"], faiss_metrics["f1"])
    logger.info(
        " Best HAG config: beta=%.2f, max_iter=%d => EM=%.4f, F1=%.4f",
        best.beta,
        best.max_iter,
        best.em,
        best.f1,
    )
    logger.info(" Total LLM calls: %d (saved %.1f%%)", meta["total_llm_calls"], meta["savings_pct"])
    logger.info(" Total time: %.1fs", t_total)
    logger.info(" Results saved to: %s", output_path)
    logger.info("=" * 60)


if __name__ == "__main__":
    main()