r"""Evaluate high-β Hopfield retrieval on a QA set against a FAISS baseline.

Grid: β ∈ {20, 50, 100, 200, 500} × update steps ∈ {0, 1, 2, 3, 5, 8} ×
update mode ∈ {standard, normalized, residual λ=0.9, residual λ=0.95}.

Tests whether high β (≥50) allows standard Hopfield to work without residual.
Also tests normalized update: q → normalize(M @ softmax(β * M^T @ q)).

Usage:
    CUDA_VISIBLE_DEVICES=0 python -u scripts/eval_highbeta_grid.py \
        --config configs/hotpotqa.yaml \
        --memory-bank data/processed/hotpotqa_memory_bank.pt \
        --questions data/processed/hotpotqa_questions.jsonl \
        --device cuda --max-samples 100
"""

import argparse
import json
import logging
import sys
import time
from pathlib import Path
from typing import Dict, List, Tuple

import numpy as np
import torch
import torch.nn.functional as F
import yaml

from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig
from hag.energy import compute_attention_entropy
from hag.encoder import Encoder
from hag.generator import Generator
from hag.memory_bank import MemoryBank
from hag.metrics import exact_match, f1_score
from hag.retriever_faiss import FAISSRetriever

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)


def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]:
    """Read up to ``max_samples`` (question, answer) pairs from a JSONL file.

    Each non-blank line must be a JSON object with ``"question"`` and
    ``"answer"`` keys.

    Returns:
        ``(questions, gold_answers)`` as two parallel lists.
    """
    questions: List[str] = []
    gold_answers: List[str] = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            if len(questions) >= max_samples:
                break
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue
            record = json.loads(line)
            questions.append(record["question"])
            gold_answers.append(record["answer"])
    return questions, gold_answers


@torch.no_grad()
def hopfield_retrieve(
    query: torch.Tensor,
    memory: torch.Tensor,
    beta: float,
    max_iter: int,
    top_k: int,
    mode: str = "standard",
    lam: float = 0.0,
) -> Tuple[torch.Tensor, float]:
    """Run iterative Hopfield retrieval and return the top-k memory indices.

    Each step computes ``alpha = softmax(beta * q @ memory)`` and
    ``q_hop = alpha @ memory.T``, then updates the query according to ``mode``:

    - ``"standard"``:   q <- q_hop
    - ``"normalized"``: q <- normalize(q_hop); the initial query is normalized too
    - ``"residual"``:   q <- lam * q + (1 - lam) * q_hop

    With ``max_iter=0`` no update is applied and the raw query is used.

    Args:
        query: (batch, d) query embeddings.
        memory: (d, N) memory matrix, one column per stored passage.
        beta: inverse temperature of the softmax.
        max_iter: number of update steps.
        top_k: number of memory indices returned per query.
        mode: ``"standard"``, ``"normalized"`` or ``"residual"``.
        lam: residual weight; only used when ``mode="residual"``.

    Returns:
        ``(indices, entropy)``. ``indices`` is (batch, top_k), ordered by
        decreasing final attention. ``entropy`` is the result of
        ``compute_attention_entropy`` on the final attention, as a Python float.

    Raises:
        ValueError: if ``mode`` is not a supported update mode.
    """
    if mode not in ("standard", "normalized", "residual"):
        raise ValueError(f"Unknown Hopfield update mode: {mode!r}")

    q = query.clone()
    if mode == "normalized":
        q = F.normalize(q, dim=-1)

    for _ in range(max_iter):
        alpha = torch.softmax(beta * (q @ memory), dim=-1)  # (batch, N)
        q_hop = alpha @ memory.T  # (batch, d)
        if mode == "standard":
            q = q_hop
        elif mode == "normalized":
            q = F.normalize(q_hop, dim=-1)
        else:  # residual
            q = lam * q + (1.0 - lam) * q_hop

    # Final attention
    logits = beta * (q @ memory)
    alpha = torch.softmax(logits, dim=-1)
    # Rank by logits rather than by alpha. Softmax is monotonic, so the ordering is
    # the same, but at large beta most float32 probabilities underflow to exactly 0.
    # topk over alpha would then break those ties arbitrarily.
    indices = torch.topk(logits, top_k, dim=-1).indices

    entropy = compute_attention_entropy(alpha)
    if isinstance(entropy, torch.Tensor):
        # NOTE(review): compute_attention_entropy's return type is not visible here.
        # Reduce a tensor result to a float so it can be rounded and written to JSON.
        entropy = entropy.float().mean().item()
    return indices, float(entropy)


def main() -> None:
    """Run the FAISS baseline and the high-β Hopfield grid, then write a JSON report.

    Phases:
      1. Retrieve with FAISS and generate baseline answers (seeding the LLM cache).
      2. Retrieve top-k passages for every grid configuration.
      3. Generate answers only for (question, passage-set) pairs not yet cached.
      4. Score every configuration (EM, F1, overlap with FAISS) and save results.
    """
    parser = argparse.ArgumentParser(description="High-β Hopfield grid evaluation")
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--output", type=str, default="data/processed/highbeta_grid_results.json")
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    with open(args.config) as f:
        # An empty YAML file loads as None; treat it as an empty config.
        cfg = yaml.safe_load(f) or {}

    # Grid: high β focus
    betas = [20.0, 50.0, 100.0, 200.0, 500.0]
    max_iters_list = [0, 1, 2, 3, 5, 8]
    # Modes: (grid label, residual weight λ). λ is only used by the residual modes.
    modes = [
        ("standard", 0.0),
        ("normalized", 0.0),
        ("residual_0.9", 0.9),
        ("residual_0.95", 0.95),
    ]
    top_k = args.top_k

    # iter=0 applies no update, so every mode yields the same top-k. Keep only the
    # "standard" entry for it instead of evaluating duplicates.
    grid = [
        (beta, max_iter, mode_name, lam)
        for beta in betas
        for max_iter in max_iters_list
        for mode_name, lam in modes
        if max_iter != 0 or mode_name == "standard"
    ]
    total_configs = len(grid)

    logger.info("=" * 60)
    logger.info("High-β Hopfield Grid Search")
    logger.info(" betas: %s", betas)
    logger.info(" max_iters: %s", max_iters_list)
    logger.info(" modes: %s", [m[0] for m in modes])
    logger.info(" total configs: %d", total_configs)
    logger.info("=" * 60)

    t_start = time.time()

    questions, gold_answers = load_questions(args.questions, args.max_samples)
    n = len(questions)
    logger.info("Loaded %d questions", n)

    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
    mb.load(args.memory_bank, device=args.device)
    M = mb.embeddings  # (d, N), the layout hopfield_retrieve expects
    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)

    encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device)
    generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device)

    logger.info("Encoding questions...")
    batch_size = cfg.get("encoder", {}).get("batch_size", 64)
    all_embs = [encoder.encode(questions[i : i + batch_size]) for i in range(0, n, batch_size)]
    Q = torch.cat(all_embs, dim=0)
    logger.info("Encoded, shape=%s", Q.shape)

    # Phase 1: FAISS baseline
    logger.info("Running FAISS baseline...")
    emb_np = M.T.cpu().numpy().astype(np.float32)
    faiss_ret = FAISSRetriever(top_k=top_k)
    faiss_ret.build_index(emb_np, mb.passages)

    faiss_indices: Dict[int, Tuple[int, ...]] = {}
    # Answers keyed by (question index, unordered passage-index set).
    llm_cache: Dict[Tuple[int, frozenset], str] = {}
    faiss_ems: List[float] = []
    faiss_f1s: List[float] = []
    for i in range(n):
        q_np = Q[i].cpu().numpy().astype(np.float32)
        result = faiss_ret.retrieve(q_np)
        idx_tuple = tuple(sorted(result.indices.tolist()))
        faiss_indices[i] = idx_tuple
        answer = generator.generate(questions[i], result.passages)
        # NOTE(review): the cache key ignores passage order. Any Hopfield config that
        # retrieves this same set reuses this answer, which was generated from
        # result.passages in the retriever's order. Answers generated in Phase 3
        # instead pass passages looked up by sorted index. Confirm that this
        # ordering difference is acceptable for the comparison.
        llm_cache[(i, frozenset(idx_tuple))] = answer
        faiss_ems.append(exact_match(answer, gold_answers[i]))
        faiss_f1s.append(f1_score(answer, gold_answers[i]))
        if (i + 1) % 20 == 0:
            logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(faiss_ems), np.mean(faiss_f1s))

    faiss_em = float(np.mean(faiss_ems))
    faiss_f1 = float(np.mean(faiss_f1s))
    logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1)

    # Phase 2: Retrieve all configs
    logger.info("Phase 2: Retrieving all %d configs...", total_configs)
    t_ret = time.time()
    retrieval_data: Dict[str, List[Tuple[Tuple[int, ...], float]]] = {}
    for beta, max_iter, mode_name, lam in grid:
        actual_mode = "residual" if mode_name.startswith("residual") else mode_name
        indices_batch, entropy = hopfield_retrieve(
            Q, M, beta=beta, max_iter=max_iter, top_k=top_k, mode=actual_mode, lam=lam,
        )
        config_key = f"β={beta}_iter={max_iter}_{mode_name}"
        # Entropy is a single value per config; it is stored with every question for convenience.
        retrieval_data[config_key] = [
            (tuple(sorted(row)), entropy) for row in indices_batch.tolist()
        ]

    logger.info("Retrieval done in %.1fs, %d configs", time.time() - t_ret, len(retrieval_data))

    # Phase 3: Dedup + generate
    needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
    for per_q in retrieval_data.values():
        for i, (idx_tuple, _) in enumerate(per_q):
            cache_key = (i, frozenset(idx_tuple))
            if cache_key not in llm_cache and cache_key not in needed:
                needed[cache_key] = (i, idx_tuple)

    logger.info("Unique LLM calls needed: %d (cache has %d)", len(needed), len(llm_cache))

    t_gen = time.time()
    for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()):
        passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long))
        llm_cache[cache_key] = generator.generate(questions[q_idx], passages)
        if (call_idx + 1) % 50 == 0:
            elapsed = time.time() - t_gen
            rate = (call_idx + 1) / max(elapsed, 1e-9)  # guard against a zero-length interval
            logger.info(" Generated %d/%d (%.1f/s, ~%.0fs left)", call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate)

    logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen)

    # Phase 4: Evaluate
    logger.info("Phase 4: Evaluating...")
    results = []
    for config_key, per_q in retrieval_data.items():
        ems, f1s, overlaps = [], [], []
        for i, (idx_tuple, _) in enumerate(per_q):
            answer = llm_cache[(i, frozenset(idx_tuple))]
            ems.append(exact_match(answer, gold_answers[i]))
            f1s.append(f1_score(answer, gold_answers[i]))
            overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k)

        results.append({
            "config": config_key,
            "em": round(float(np.mean(ems)), 4),
            "f1": round(float(np.mean(f1s)), 4),
            "avg_faiss_overlap": round(float(np.mean(overlaps)), 4),
            "avg_entropy": round(per_q[0][1], 4),
        })

    results.sort(key=lambda x: x["f1"], reverse=True)

    # Log every config whose F1 is within `tolerance` of FAISS; "***" marks strict wins.
    tolerance = 0.005
    logger.info("\nConfigs matching or beating FAISS (F1≥%.3f = FAISS F1 %.3f - %.3f tolerance):", faiss_f1 - tolerance, faiss_f1, tolerance)
    for r in results:
        if r["f1"] >= faiss_f1 - tolerance:
            marker = " ***" if r["f1"] > faiss_f1 else ""
            logger.info(" %s: EM=%.3f F1=%.3f overlap=%.3f%s", r["config"], r["em"], r["f1"], r["avg_faiss_overlap"], marker)

    t_total = time.time() - t_start

    output = {
        "meta": {
            "n_questions": n,
            "total_configs": len(retrieval_data),
            "unique_llm_calls": len(needed),
            "total_time_s": round(t_total, 1),
        },
        "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)},
        "grid_results": results,
        "best_config": results[0],
        "top10": results[:10],
    }

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2)

    logger.info("=" * 60)
    logger.info("RESULTS")
    logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
    logger.info(" Top 10:")
    for rank, r in enumerate(results[:10], start=1):
        logger.info(" %2d. %-40s EM=%.3f F1=%.3f overlap=%.3f", rank, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"])
    logger.info(" Total time: %.1fs", t_total)
    logger.info(" Saved to: %s", args.output)
    logger.info("=" * 60)


if __name__ == "__main__":
    main()