"""Evaluate Residual Hopfield configs on 100 questions with dedup-based LLM caching. Residual update: q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t) Usage: CUDA_VISIBLE_DEVICES=1 python scripts/eval_residual_grid.py \ --config configs/hotpotqa.yaml \ --memory-bank data/processed/hotpotqa_memory_bank.pt \ --questions data/processed/hotpotqa_questions.jsonl \ --device cuda --max-samples 100 """ import argparse import json import logging import sys import time from pathlib import Path from typing import Dict, List, Tuple import numpy as np import torch import yaml from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, MemoryBankConfig from hag.encoder import Encoder from hag.energy import compute_attention_entropy from hag.generator import Generator from hag.memory_bank import MemoryBank from hag.metrics import exact_match, f1_score from hag.retriever_faiss import FAISSRetriever logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", stream=sys.stdout, ) logger = logging.getLogger(__name__) def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]: questions, gold_answers = [], [] with open(path) as f: for line in f: r = json.loads(line) questions.append(r["question"]) gold_answers.append(r["answer"]) if len(questions) >= max_samples: break return questions, gold_answers @torch.no_grad() def residual_hopfield_retrieve( query: torch.Tensor, memory: torch.Tensor, beta: float, lam: float, max_iter: int, top_k: int, ) -> Tuple[torch.Tensor, torch.Tensor, float]: """Residual Hopfield retrieval on full memory bank. q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t) Args: query: (batch, d) memory: (d, N) beta: inverse temperature lam: residual weight (0=pure Hopfield, 1=no update) max_iter: number of iterations top_k: number of passages to return Returns: (top_k_indices, top_k_scores, avg_entropy) for the batch. indices: (batch, top_k), scores: (batch, top_k), entropy: float """ q = query.clone() for _ in range(max_iter): logits = beta * (q @ memory) # (batch, N) alpha = torch.softmax(logits, dim=-1) # (batch, N) q_hop = alpha @ memory.T # (batch, d) q = lam * q + (1.0 - lam) * q_hop # (batch, d) # Final attention logits = beta * (q @ memory) # (batch, N) alpha = torch.softmax(logits, dim=-1) # (batch, N) scores, indices = torch.topk(alpha, top_k, dim=-1) # (batch, top_k) entropy = compute_attention_entropy(alpha) return indices, scores, entropy def main() -> None: parser = argparse.ArgumentParser(description="Residual Hopfield grid evaluation") parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--max-samples", type=int, default=100) parser.add_argument("--output", type=str, default="data/processed/residual_grid_results.json") parser.add_argument("--top-k", type=int, default=5) args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) # Grid betas = [5.0, 10.0, 20.0, 50.0, 100.0] lambdas = [0.5, 0.7, 0.8, 0.9, 0.95] max_iters_list = [1, 3, 5, 8] top_k = args.top_k total_configs = len(betas) * len(lambdas) * len(max_iters_list) logger.info("=" * 60) logger.info("Residual Hopfield Grid Search") logger.info(" betas: %s", betas) logger.info(" lambdas: %s", lambdas) logger.info(" max_iters: %s", max_iters_list) logger.info(" total configs: %d", total_configs) logger.info("=" * 60) t_start = time.time() # Load questions, gold_answers = load_questions(args.questions, args.max_samples) n = len(questions) logger.info("Loaded %d questions", n) mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) mb.load(args.memory_bank, device=args.device) M = mb.embeddings # (d, N) logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim) encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device) generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device) logger.info("Encoding questions...") all_embs = [] batch_size = cfg.get("encoder", {}).get("batch_size", 64) for i in range(0, n, batch_size): all_embs.append(encoder.encode(questions[i : i + batch_size])) Q = torch.cat(all_embs, dim=0) # (n, d) logger.info("Encoded, shape=%s", Q.shape) # FAISS baseline logger.info("Running FAISS baseline...") emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32) faiss_ret = FAISSRetriever(top_k=top_k) faiss_ret.build_index(emb_np, mb.passages) faiss_indices: Dict[int, Tuple[int, ...]] = {} llm_cache: Dict[Tuple[int, frozenset], str] = {} for i in range(n): q_np = Q[i].cpu().numpy().astype(np.float32) result = faiss_ret.retrieve(q_np) idx_tuple = tuple(sorted(result.indices.tolist())) faiss_indices[i] = idx_tuple cache_key = (i, frozenset(idx_tuple)) answer = generator.generate(questions[i], result.passages) llm_cache[cache_key] = answer if (i + 1) % 20 == 0: ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s)) faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1) # Phase 2: Retrieve all configs (batched, fast) logger.info("Phase 2: Retrieving all %d configs...", total_configs) t_ret = time.time() # config_key -> list of (sorted_indices_tuple, entropy) per question retrieval_data: Dict[Tuple[float, float, int], List[Tuple[Tuple[int, ...], float]]] = {} for beta in betas: for lam in lambdas: for max_iter in max_iters_list: indices, scores, entropy = residual_hopfield_retrieve( Q, M, beta=beta, lam=lam, max_iter=max_iter, top_k=top_k ) per_q = [] for i in range(n): idx_tuple = tuple(sorted(indices[i].tolist())) # per-question entropy per_q.append((idx_tuple, entropy)) retrieval_data[(beta, lam, max_iter)] = per_q logger.info("Retrieval done in %.1fs", time.time() - t_ret) # Phase 3: Dedup + generate needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {} for key, per_q in retrieval_data.items(): for i, (idx_tuple, _) in enumerate(per_q): cache_key = (i, frozenset(idx_tuple)) if cache_key not in llm_cache and cache_key not in needed: needed[cache_key] = (i, idx_tuple) total_grid_evals = total_configs * n logger.info( "Unique LLM calls needed: %d / %d grid evals (%.1f%% saving)", len(needed), total_grid_evals, (1 - len(needed) / total_grid_evals) * 100, ) t_gen = time.time() for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()): passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long)) answer = generator.generate(questions[q_idx], passages) llm_cache[cache_key] = answer if (call_idx + 1) % 50 == 0: elapsed = time.time() - t_gen rate = (call_idx + 1) / elapsed logger.info( " Generated %d/%d (%.1f/s, ~%.0fs left)", call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate, ) logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen) # Phase 4: Evaluate logger.info("Phase 4: Evaluating...") results = [] for beta in betas: for lam in lambdas: for max_iter in max_iters_list: per_q = retrieval_data[(beta, lam, max_iter)] ems, f1s, overlaps, entropies = [], [], [], [] for i, (idx_tuple, ent) in enumerate(per_q): cache_key = (i, frozenset(idx_tuple)) answer = llm_cache[cache_key] ems.append(exact_match(answer, gold_answers[i])) f1s.append(f1_score(answer, gold_answers[i])) overlap = len(set(idx_tuple) & set(faiss_indices[i])) / top_k overlaps.append(overlap) entropies.append(ent) em, f1 = np.mean(ems), np.mean(f1s) r = { "beta": beta, "lambda": lam, "max_iter": max_iter, "em": round(em, 4), "f1": round(f1, 4), "avg_faiss_overlap": round(np.mean(overlaps), 4), "avg_entropy": round(np.mean(entropies), 4), } results.append(r) if f1 >= faiss_f1 - 0.01: marker = " ***" if f1 > faiss_f1 else "" logger.info( " β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f%s", beta, lam, max_iter, em, f1, np.mean(overlaps), marker, ) # Sort by F1 results.sort(key=lambda x: x["f1"], reverse=True) best = results[0] t_total = time.time() - t_start output = { "meta": { "n_questions": n, "total_configs": total_configs, "unique_llm_calls": len(needed), "faiss_llm_calls": n, "total_time_s": round(t_total, 1), }, "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)}, "grid_results": results, "best_config": best, "top10": results[:10], } Path(args.output).parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as f: json.dump(output, f, indent=2) logger.info("=" * 60) logger.info("RESULTS") logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1) logger.info( " Best: β=%.1f λ=%.2f iter=%d => EM=%.4f F1=%.4f", best["beta"], best["lambda"], best["max_iter"], best["em"], best["f1"], ) logger.info(" Top 5:") for i, r in enumerate(results[:5]): logger.info( " %d. β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f", i + 1, r["beta"], r["lambda"], r["max_iter"], r["em"], r["f1"], r["avg_faiss_overlap"], ) logger.info(" Total time: %.1fs", t_total) logger.info(" Saved to: %s", args.output) logger.info("=" * 60) if __name__ == "__main__": main()