summaryrefslogtreecommitdiff
path: root/scripts/eval_residual_grid.py
diff options
context:
space:
mode:
Diffstat (limited to 'scripts/eval_residual_grid.py')
-rw-r--r--scripts/eval_residual_grid.py298
1 files changed, 298 insertions, 0 deletions
diff --git a/scripts/eval_residual_grid.py b/scripts/eval_residual_grid.py
new file mode 100644
index 0000000..f716c51
--- /dev/null
+++ b/scripts/eval_residual_grid.py
@@ -0,0 +1,298 @@
+"""Evaluate Residual Hopfield configs on 100 questions with dedup-based LLM caching.
+
+Residual update: q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t)
+
+Usage:
+ CUDA_VISIBLE_DEVICES=1 python scripts/eval_residual_grid.py \
+ --config configs/hotpotqa.yaml \
+ --memory-bank data/processed/hotpotqa_memory_bank.pt \
+ --questions data/processed/hotpotqa_questions.jsonl \
+ --device cuda --max-samples 100
+"""
+
+import argparse
+import json
+import logging
+import sys
+import time
+from pathlib import Path
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import yaml
+
+from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import compute_attention_entropy
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import exact_match, f1_score
+from hag.retriever_faiss import FAISSRetriever
+
# Root logging config: INFO-level messages, timestamped, sent to stdout
# (not the stderr default) so progress lines interleave with any prints.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
# Module-level logger used throughout this script.
logger = logging.getLogger(__name__)
+
+
def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]:
    """Load up to ``max_samples`` question/answer pairs from a JSONL file.

    Args:
        path: Path to a JSONL file; each line is a JSON object with at least
            ``"question"`` and ``"answer"`` keys.
        max_samples: Maximum number of records to read. Values <= 0 now
            return empty lists (previously the first record slipped through
            because the limit was checked only after appending).

    Returns:
        ``(questions, gold_answers)`` — two parallel lists of equal length.
    """
    questions: List[str] = []
    gold_answers: List[str] = []
    # Guard: without this, one record is appended before the break check.
    if max_samples <= 0:
        return questions, gold_answers
    with open(path) as f:
        for line in f:
            r = json.loads(line)
            questions.append(r["question"])
            gold_answers.append(r["answer"])
            if len(questions) >= max_samples:
                break
    return questions, gold_answers
+
+
@torch.no_grad()
def residual_hopfield_retrieve(
    query: torch.Tensor,
    memory: torch.Tensor,
    beta: float,
    lam: float,
    max_iter: int,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Run residual Hopfield iterations over the full memory bank.

    Update rule (applied ``max_iter`` times)::

        q_{t+1} = lam * q_t + (1 - lam) * M @ softmax(beta * M^T @ q_t)

    Args:
        query: Query embeddings, shape ``(batch, d)``.
        memory: Memory matrix ``M``, shape ``(d, N)``.
        beta: Inverse-temperature scaling of the attention logits.
        lam: Residual mixing weight; 0 is a pure Hopfield step, 1 leaves
            the query unchanged.
        max_iter: Number of residual update iterations.
        top_k: How many memory slots to return per query.

    Returns:
        Tuple ``(indices, scores, entropy)`` where ``indices`` and
        ``scores`` have shape ``(batch, top_k)`` and ``entropy`` is the
        attention entropy of the final distribution for the whole batch.
    """
    state = query.clone()

    step = 0
    while step < max_iter:
        # One residual Hopfield step: attend over memory, then mix.
        attn = torch.softmax(beta * (state @ memory), dim=-1)  # (batch, N)
        readout = attn @ memory.T                              # (batch, d)
        state = lam * state + (1.0 - lam) * readout
        step += 1

    # Score once more with the converged state and keep the top-k slots.
    final_attn = torch.softmax(beta * (state @ memory), dim=-1)
    top_scores, top_idx = torch.topk(final_attn, top_k, dim=-1)

    return top_idx, top_scores, compute_attention_entropy(final_attn)
+
def main() -> None:
    """Grid-search (beta, lambda, max_iter) for residual Hopfield retrieval.

    Pipeline: load questions + memory bank, run a FAISS top-k baseline with
    LLM answer generation, then retrieve for every grid config, dedup the
    (question, retrieved-set) pairs so each unique pair costs one LLM call,
    score every config with EM/F1, and write a ranked JSON report.
    """
    parser = argparse.ArgumentParser(description="Residual Hopfield grid evaluation")
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--output", type=str, default="data/processed/residual_grid_results.json")
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    # Hard-coded hyperparameter grid: 5 x 5 x 4 = 100 configurations.
    betas = [5.0, 10.0, 20.0, 50.0, 100.0]
    lambdas = [0.5, 0.7, 0.8, 0.9, 0.95]
    max_iters_list = [1, 3, 5, 8]
    top_k = args.top_k

    total_configs = len(betas) * len(lambdas) * len(max_iters_list)
    logger.info("=" * 60)
    logger.info("Residual Hopfield Grid Search")
    logger.info("  betas: %s", betas)
    logger.info("  lambdas: %s", lambdas)
    logger.info("  max_iters: %s", max_iters_list)
    logger.info("  total configs: %d", total_configs)
    logger.info("=" * 60)

    t_start = time.time()

    # Phase 0: load questions, memory bank, and models.
    questions, gold_answers = load_questions(args.questions, args.max_samples)
    n = len(questions)
    logger.info("Loaded %d questions", n)

    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
    mb.load(args.memory_bank, device=args.device)
    M = mb.embeddings  # (d, N) — columns are passage embeddings
    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)

    encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device)
    generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device)

    # Encode all questions in batches, then concatenate into one matrix.
    logger.info("Encoding questions...")
    all_embs = []
    batch_size = cfg.get("encoder", {}).get("batch_size", 64)
    for i in range(0, n, batch_size):
        all_embs.append(encoder.encode(questions[i : i + batch_size]))
    Q = torch.cat(all_embs, dim=0)  # (n, d)
    logger.info("Encoded, shape=%s", Q.shape)

    # Phase 1: FAISS top-k baseline with answer generation per question.
    logger.info("Running FAISS baseline...")
    emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32)
    faiss_ret = FAISSRetriever(top_k=top_k)
    faiss_ret.build_index(emb_np, mb.passages)

    # faiss_indices: question index -> sorted tuple of retrieved passage ids.
    faiss_indices: Dict[int, Tuple[int, ...]] = {}
    # llm_cache: (question index, frozenset of passage ids) -> generated
    # answer. Sorting/frozenset makes the key order-insensitive, so any grid
    # config that retrieves the same passage set reuses the same LLM answer.
    llm_cache: Dict[Tuple[int, frozenset], str] = {}

    for i in range(n):
        q_np = Q[i].cpu().numpy().astype(np.float32)
        result = faiss_ret.retrieve(q_np)
        idx_tuple = tuple(sorted(result.indices.tolist()))
        faiss_indices[i] = idx_tuple
        cache_key = (i, frozenset(idx_tuple))
        answer = generator.generate(questions[i], result.passages)
        llm_cache[cache_key] = answer
        # Running EM/F1 over all questions answered so far, every 20.
        if (i + 1) % 20 == 0:
            ems = [exact_match(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
            f1s = [f1_score(llm_cache[(j, frozenset(faiss_indices[j]))], gold_answers[j]) for j in range(i + 1)]
            logger.info("  FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(ems), np.mean(f1s))

    faiss_em = np.mean([exact_match(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
    faiss_f1 = np.mean([f1_score(llm_cache[(i, frozenset(faiss_indices[i]))], gold_answers[i]) for i in range(n)])
    logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1)

    # Phase 2: run retrieval for every grid config. Each call is a single
    # batched pass over all n questions, so this phase is cheap.
    logger.info("Phase 2: Retrieving all %d configs...", total_configs)
    t_ret = time.time()

    # (beta, lambda, max_iter) -> list over questions of
    # (sorted retrieved-index tuple, entropy).
    retrieval_data: Dict[Tuple[float, float, int], List[Tuple[Tuple[int, ...], float]]] = {}

    for beta in betas:
        for lam in lambdas:
            for max_iter in max_iters_list:
                indices, scores, entropy = residual_hopfield_retrieve(
                    Q, M, beta=beta, lam=lam, max_iter=max_iter, top_k=top_k
                )
                per_q = []
                for i in range(n):
                    idx_tuple = tuple(sorted(indices[i].tolist()))
                    # NOTE(review): `entropy` is computed once for the whole
                    # batch by residual_hopfield_retrieve, so every question
                    # in this config stores the same batch-level value — this
                    # is NOT a true per-question entropy.
                    per_q.append((idx_tuple, entropy))
                retrieval_data[(beta, lam, max_iter)] = per_q

    logger.info("Retrieval done in %.1fs", time.time() - t_ret)

    # Phase 3: dedup (question, passage-set) pairs across the grid, then
    # generate answers only for pairs not already covered by the FAISS
    # baseline or an earlier config.
    needed: Dict[Tuple[int, frozenset], Tuple[int, Tuple[int, ...]]] = {}
    for key, per_q in retrieval_data.items():
        for i, (idx_tuple, _) in enumerate(per_q):
            cache_key = (i, frozenset(idx_tuple))
            if cache_key not in llm_cache and cache_key not in needed:
                needed[cache_key] = (i, idx_tuple)

    total_grid_evals = total_configs * n
    logger.info(
        "Unique LLM calls needed: %d / %d grid evals (%.1f%% saving)",
        len(needed), total_grid_evals,
        (1 - len(needed) / total_grid_evals) * 100,
    )

    t_gen = time.time()
    for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()):
        passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long))
        answer = generator.generate(questions[q_idx], passages)
        llm_cache[cache_key] = answer
        # Progress + naive ETA from the average generation rate so far.
        if (call_idx + 1) % 50 == 0:
            elapsed = time.time() - t_gen
            rate = (call_idx + 1) / elapsed
            logger.info(
                "  Generated %d/%d (%.1f/s, ~%.0fs left)",
                call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate,
            )
    logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen)

    # Phase 4: score every config from the cache (no further LLM calls).
    logger.info("Phase 4: Evaluating...")
    results = []
    for beta in betas:
        for lam in lambdas:
            for max_iter in max_iters_list:
                per_q = retrieval_data[(beta, lam, max_iter)]
                ems, f1s, overlaps, entropies = [], [], [], []
                for i, (idx_tuple, ent) in enumerate(per_q):
                    cache_key = (i, frozenset(idx_tuple))
                    answer = llm_cache[cache_key]
                    ems.append(exact_match(answer, gold_answers[i]))
                    f1s.append(f1_score(answer, gold_answers[i]))
                    # Fraction of this config's top-k shared with FAISS.
                    overlap = len(set(idx_tuple) & set(faiss_indices[i])) / top_k
                    overlaps.append(overlap)
                    entropies.append(ent)

                em, f1 = np.mean(ems), np.mean(f1s)
                r = {
                    "beta": beta, "lambda": lam, "max_iter": max_iter,
                    "em": round(em, 4), "f1": round(f1, 4),
                    "avg_faiss_overlap": round(np.mean(overlaps), 4),
                    "avg_entropy": round(np.mean(entropies), 4),
                }
                results.append(r)
                # Only log configs within 0.01 F1 of the baseline; "***"
                # marks configs that strictly beat it.
                if f1 >= faiss_f1 - 0.01:
                    marker = " ***" if f1 > faiss_f1 else ""
                    logger.info(
                        "  β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f%s",
                        beta, lam, max_iter, em, f1, np.mean(overlaps), marker,
                    )

    # Rank configs by F1, best first.
    results.sort(key=lambda x: x["f1"], reverse=True)
    best = results[0]

    t_total = time.time() - t_start

    output = {
        "meta": {
            "n_questions": n,
            "total_configs": total_configs,
            "unique_llm_calls": len(needed),
            "faiss_llm_calls": n,
            "total_time_s": round(t_total, 1),
        },
        "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)},
        "grid_results": results,
        "best_config": best,
        "top10": results[:10],
    }

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2)

    # Final summary to the log: baseline, best config, and the top 5.
    logger.info("=" * 60)
    logger.info("RESULTS")
    logger.info("  FAISS:  EM=%.4f F1=%.4f", faiss_em, faiss_f1)
    logger.info(
        "  Best:   β=%.1f λ=%.2f iter=%d => EM=%.4f F1=%.4f",
        best["beta"], best["lambda"], best["max_iter"], best["em"], best["f1"],
    )
    logger.info("  Top 5:")
    for i, r in enumerate(results[:5]):
        logger.info(
            "    %d. β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f",
            i + 1, r["beta"], r["lambda"], r["max_iter"], r["em"], r["f1"], r["avg_faiss_overlap"],
        )
    logger.info("  Total time: %.1fs", t_total)
    logger.info("  Saved to: %s", args.output)
    logger.info("=" * 60)
+
+
+if __name__ == "__main__":
+ main()