"""Evaluate centered Hopfield on 100 questions. Memory bank is mean-centered (M̃ = M - μ), query is centered (q̃ = q - μ). β_critical = 37.6: below it origin is stable attractor, above it dynamics escape. Grid: β spanning both sides of β_critical, iter = [0, 1, 2, 3, 5, 8]. No residual — pure Hopfield update on centered space. Usage: CUDA_VISIBLE_DEVICES=1 nohup python -u scripts/eval_centered_grid.py \ --memory-bank data/processed/hotpotqa_memory_bank.pt \ --questions data/processed/hotpotqa_questions.jsonl \ --device cuda --max-samples 100 \ > data/processed/centered_grid.log 2>&1 & """ import argparse import json import logging import sys import time from pathlib import Path from typing import Dict, List, Tuple import numpy as np import torch import torch.nn.functional as F import yaml from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig from hag.energy import compute_attention_entropy from hag.encoder import Encoder from hag.generator import Generator from hag.memory_bank import MemoryBank from hag.metrics import exact_match, f1_score from hag.retriever_faiss import FAISSRetriever logging.basicConfig( level=logging.INFO, format="%(asctime)s %(levelname)s %(name)s: %(message)s", stream=sys.stdout, ) logger = logging.getLogger(__name__) def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]: questions, gold_answers = [], [] with open(path) as f: for line in f: r = json.loads(line) questions.append(r["question"]) gold_answers.append(r["answer"]) if len(questions) >= max_samples: break return questions, gold_answers @torch.no_grad() def hopfield_retrieve_centered( query: torch.Tensor, memory_centered: torch.Tensor, mean: torch.Tensor, beta: float, max_iter: int, top_k: int, ) -> Tuple[torch.Tensor, float]: """Pure Hopfield retrieval on centered memory bank. Args: query: (batch, d) raw query embeddings memory_centered: (d, N) centered memory bank (M̃ = M - μ) mean: (d,) memory bank mean beta, max_iter, top_k: Hopfield params Returns: (top_k_indices (batch, top_k), avg_entropy) """ # Center the query q = query - mean.unsqueeze(0) # (batch, d) for _ in range(max_iter): logits = beta * (q @ memory_centered) # (batch, N) alpha = torch.softmax(logits, dim=-1) # (batch, N) q = alpha @ memory_centered.T # (batch, d) # Final attention logits = beta * (q @ memory_centered) alpha = torch.softmax(logits, dim=-1) _, indices = torch.topk(alpha, top_k, dim=-1) entropy = compute_attention_entropy(alpha) return indices, entropy def main() -> None: parser = argparse.ArgumentParser(description="Centered Hopfield grid evaluation") parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml") parser.add_argument("--memory-bank", type=str, required=True) parser.add_argument("--questions", type=str, required=True) parser.add_argument("--device", type=str, default="cpu") parser.add_argument("--max-samples", type=int, default=100) parser.add_argument("--output", type=str, default="data/processed/centered_grid_results.json") parser.add_argument("--top-k", type=int, default=5) args = parser.parse_args() with open(args.config) as f: cfg = yaml.safe_load(f) # Grid: span β_critical ≈ 37.6 betas = [10.0, 20.0, 30.0, 38.0, 40.0, 45.0, 50.0, 60.0, 75.0, 100.0, 150.0, 200.0] max_iters_list = [0, 1, 2, 3, 5, 8] top_k = args.top_k total_configs = len(betas) * len(max_iters_list) logger.info("=" * 60) logger.info("Centered Hopfield Grid Search") logger.info(" β_critical ≈ 37.6") logger.info(" betas: %s", betas) logger.info(" max_iters: %s", max_iters_list) logger.info(" total configs: %d", total_configs) logger.info("=" * 60) t_start = time.time() questions, gold_answers = load_questions(args.questions, args.max_samples) n = len(questions) logger.info("Loaded %d questions", n) # Load memory bank (uncentered) mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) mb.load(args.memory_bank, device=args.device) M_raw = mb.embeddings # (d, N) d, N = M_raw.shape logger.info("Memory bank: %d passages, dim=%d", N, d) # Center the memory bank mu = M_raw.mean(dim=1) # (d,) M_cent = M_raw - mu.unsqueeze(1) # (d, N) logger.info("Centered memory bank. ‖μ‖=%.4f, ‖M̃·1/N‖=%.2e", mu.norm().item(), M_cent.mean(dim=1).norm().item()) # Compute β_critical S = torch.linalg.svdvals(M_cent) lambda_max_C = (S[0].item() ** 2) / N beta_crit = 1.0 / lambda_max_C logger.info("β_critical = %.2f (λ_max(C)=%.4f, σ_max=%.4f)", beta_crit, lambda_max_C, S[0].item()) encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device) generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device) logger.info("Encoding questions...") all_embs = [] batch_size = cfg.get("encoder", {}).get("batch_size", 64) for i in range(0, n, batch_size): all_embs.append(encoder.encode(questions[i : i + batch_size])) Q = torch.cat(all_embs, dim=0) # (n, d) logger.info("Encoded, shape=%s", Q.shape) # FAISS baseline (on raw embeddings) logger.info("Running FAISS baseline...") emb_np = M_raw.T.cpu().numpy().astype(np.float32) faiss_ret = FAISSRetriever(top_k=top_k) faiss_ret.build_index(emb_np, mb.passages) faiss_indices: Dict[int, Tuple[int, ...]] = {} llm_cache: Dict[Tuple[int, frozenset], str] = {} for i in range(n): q_np = Q[i].cpu().numpy().astype(np.float32) result = faiss_ret.retrieve(q_np) idx_tuple = tuple(sorted(result.indices.tolist())) faiss_indices[i] = idx_tuple cache_key = (i, frozenset(idx_tuple)) answer = generator.generate(questions[i], result.passages) llm_cache[cache_key] = answer if (i + 1) % 20 == 0: ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)] logger.info(" FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s)) faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)]) logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1) # Phase 2: Retrieve all configs (centered) logger.info("Phase 2: Retrieving all %d configs (centered)...", total_configs) t_ret = time.time() retrieval_data: Dict[str, List[Tuple[Tuple[int, ...], float]]] = {} for beta in betas: for max_iter in max_iters_list: config_key = f"β={beta}_iter={max_iter}" indices_batch, entropy = hopfield_retrieve_centered( Q, M_cent, mu, beta=beta, max_iter=max_iter, top_k=top_k, ) per_q = [] for i in range(n): idx_tuple = tuple(sorted(indices_batch[i].tolist())) per_q.append((idx_tuple, entropy)) retrieval_data[config_key] = per_q logger.info("Retrieval done in %.1fs, %d configs", time.time() - t_ret, len(retrieval_data)) # Phase 3: Dedup + generate needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {} for key, per_q in retrieval_data.items(): for i, (idx_tuple, _) in enumerate(per_q): cache_key = (i, frozenset(idx_tuple)) if cache_key not in llm_cache and cache_key not in needed: needed[cache_key] = (i, idx_tuple) logger.info("Unique LLM calls needed: %d (cache has %d)", len(needed), len(llm_cache)) t_gen = time.time() for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()): passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long)) answer = generator.generate(questions[q_idx], passages) llm_cache[cache_key] = answer if (call_idx + 1) % 50 == 0: elapsed = time.time() - t_gen rate = (call_idx + 1) / elapsed logger.info(" Generated %d/%d (%.1f/s, ~%.0fs left)", call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate) logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen) # Phase 4: Evaluate logger.info("Phase 4: Evaluating...") results = [] for config_key, per_q in retrieval_data.items(): ems, f1s, overlaps = [], [], [] for i, (idx_tuple, ent) in enumerate(per_q): cache_key = (i, frozenset(idx_tuple)) answer = llm_cache[cache_key] ems.append(exact_match(answer, gold_answers[i])) f1s.append(f1_score(answer, gold_answers[i])) overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k) em, f1 = np.mean(ems), np.mean(f1s) # Parse beta from config_key beta_val = float(config_key.split("_")[0].split("=")[1]) iter_val = int(config_key.split("_")[1].split("=")[1]) r = { "config": config_key, "beta": beta_val, "max_iter": iter_val, "em": round(em, 4), "f1": round(f1, 4), "avg_faiss_overlap": round(np.mean(overlaps), 4), "avg_entropy": round(per_q[0][1], 4), "above_beta_crit": beta_val > beta_crit, } results.append(r) results.sort(key=lambda x: x["f1"], reverse=True) # Count how many beat FAISS n_beat = sum(1 for r in results if r["f1"] > faiss_f1) logger.info("\n%d/%d configs beat FAISS F1=%.3f", n_beat, len(results), faiss_f1) # Log top 20 logger.info("\nTop 20 configs:") for i, r in enumerate(results[:20]): marker = " ***" if r["f1"] > faiss_f1 else "" crit = ">" if r["above_beta_crit"] else "<" logger.info(" %2d. %-25s EM=%.3f F1=%.3f overlap=%.3f H=%.2f β%sβ_c%s", i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"], r["avg_entropy"], crit, marker) # Summary by β: best iter for each β logger.info("\nBest iter per β:") for beta in betas: beta_results = [r for r in results if r["beta"] == beta] if beta_results: best = beta_results[0] crit = ">" if best["above_beta_crit"] else "<" logger.info(" β=%6.1f (β%sβ_c): best iter=%d EM=%.3f F1=%.3f overlap=%.3f", beta, crit, best["max_iter"], best["em"], best["f1"], best["avg_faiss_overlap"]) t_total = time.time() - t_start output = { "meta": { "n_questions": n, "total_configs": len(retrieval_data), "unique_llm_calls": len(needed), "total_time_s": round(t_total, 1), "beta_critical": round(beta_crit, 2), }, "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)}, "grid_results": results, "best_config": results[0], "top10": results[:10], } Path(args.output).parent.mkdir(parents=True, exist_ok=True) with open(args.output, "w") as f: json.dump(output, f, indent=2) logger.info("=" * 60) logger.info("RESULTS SUMMARY") logger.info(" FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1) logger.info(" β_critical = %.2f", beta_crit) logger.info(" Configs beating FAISS: %d/%d", n_beat, len(results)) logger.info(" Top 5:") for i, r in enumerate(results[:5]): logger.info(" %d. %-25s EM=%.3f F1=%.3f overlap=%.3f", i + 1, r["config"], r["em"], r["f1"], r["avg_faiss_overlap"]) logger.info(" Total time: %.1fs", t_total) logger.info(" Saved to: %s", args.output) logger.info("=" * 60) if __name__ == "__main__": main()