diff options
| author | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
|---|---|---|
| committer | YurenHao0426 <Blackhao0426@gmail.com> | 2026-02-16 14:44:42 -0600 |
| commit | 09d50e47860da0035e178a442dc936028808a0b3 (patch) | |
| tree | 9d651b0c7d289a9a0405953f2da989a3c431f147 /scripts/eval_residual_grid.py | |
| parent | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (diff) | |
- Add centering support to MemoryBank (center_query, apply_centering, mean
persistence in save/load) to remove centroid attractor in Hopfield dynamics
- Add center flag to MemoryBankConfig, device field to PipelineConfig
- Grid search scripts: initial (β≤8), residual, high-β, and centered grids
with dedup-based LLM caching (89-91% call savings)
- Energy landscape visualization: 2D contour, 1D profile, UMAP, PCA heatmap
comparing centered vs uncentered dynamics
- Experiment log (note.md) documenting 4 rounds of results and root cause
analysis of centroid attractor problem
- Key finding: β_critical ≈ 37.6 for centered memory; best configs beat
FAISS baseline by +3-4% F1
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts/eval_residual_grid.py')
| -rw-r--r-- | scripts/eval_residual_grid.py | 298 |
1 file changed, 298 insertions, 0 deletions
"""Evaluate Residual Hopfield configs on 100 questions with dedup-based LLM caching.

Residual update: q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t)

Usage:
    CUDA_VISIBLE_DEVICES=1 python scripts/eval_residual_grid.py \
        --config configs/hotpotqa.yaml \
        --memory-bank data/processed/hotpotqa_memory_bank.pt \
        --questions data/processed/hotpotqa_questions.jsonl \
        --device cuda --max-samples 100
"""

import argparse
import json
import logging
import sys
import time
from itertools import product
from pathlib import Path
from typing import Dict, FrozenSet, List, Tuple

import numpy as np
import torch
import yaml

from hag.config import EncoderConfig, GeneratorConfig, MemoryBankConfig
from hag.encoder import Encoder
from hag.energy import compute_attention_entropy
from hag.generator import Generator
from hag.memory_bank import MemoryBank
from hag.metrics import exact_match, f1_score
from hag.retriever_faiss import FAISSRetriever

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s %(name)s: %(message)s",
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)


def load_questions(path: str, max_samples: int) -> Tuple[List[str], List[str]]:
    """Read up to ``max_samples`` (question, gold answer) pairs from a JSONL file.

    Each line must be a JSON object with at least "question" and "answer" keys.
    """
    questions, gold_answers = [], []
    with open(path) as f:
        for line in f:
            r = json.loads(line)
            questions.append(r["question"])
            gold_answers.append(r["answer"])
            if len(questions) >= max_samples:
                break
    return questions, gold_answers


@torch.no_grad()
def residual_hopfield_retrieve(
    query: torch.Tensor,
    memory: torch.Tensor,
    beta: float,
    lam: float,
    max_iter: int,
    top_k: int,
) -> Tuple[torch.Tensor, torch.Tensor, float]:
    """Residual Hopfield retrieval on the full memory bank.

    q_{t+1} = λ * q_t + (1-λ) * M @ softmax(β * M^T @ q_t)

    Args:
        query: (batch, d) query embeddings.
        memory: (d, N) memory bank, one column per passage.
        beta: inverse temperature.
        lam: residual weight (0 = pure Hopfield, 1 = no update).
        max_iter: number of update iterations.
        top_k: number of passages to return.

    Returns:
        (indices, scores, entropy) where indices is (batch, top_k),
        scores is (batch, top_k), and entropy is a single float: the
        attention entropy averaged over the whole batch (NOT per-question).
    """
    q = query.clone()
    for _ in range(max_iter):
        logits = beta * (q @ memory)           # (batch, N)
        alpha = torch.softmax(logits, dim=-1)  # (batch, N)
        q_hop = alpha @ memory.T               # (batch, d)
        q = lam * q + (1.0 - lam) * q_hop      # residual blend, (batch, d)

    # Final attention over the (possibly moved) query.
    logits = beta * (q @ memory)               # (batch, N)
    alpha = torch.softmax(logits, dim=-1)      # (batch, N)
    scores, indices = torch.topk(alpha, top_k, dim=-1)  # (batch, top_k)

    entropy = compute_attention_entropy(alpha)
    return indices, scores, entropy


def main() -> None:
    """Grid-search residual Hopfield retrieval and compare to a FAISS baseline.

    Four phases: (1) FAISS baseline that also seeds the LLM answer cache;
    (2) batched retrieval for every (β, λ, max_iter) config; (3) dedup of
    retrieval sets so each unique (question, passage-set) pair triggers at
    most one LLM call; (4) EM/F1 scoring and JSON report.
    """
    parser = argparse.ArgumentParser(description="Residual Hopfield grid evaluation")
    parser.add_argument("--config", type=str, default="configs/hotpotqa.yaml")
    parser.add_argument("--memory-bank", type=str, required=True)
    parser.add_argument("--questions", type=str, required=True)
    parser.add_argument("--device", type=str, default="cpu")
    parser.add_argument("--max-samples", type=int, default=100)
    parser.add_argument("--output", type=str, default="data/processed/residual_grid_results.json")
    parser.add_argument("--top-k", type=int, default=5)
    args = parser.parse_args()

    with open(args.config) as f:
        cfg = yaml.safe_load(f)

    # Grid
    betas = [5.0, 10.0, 20.0, 50.0, 100.0]
    lambdas = [0.5, 0.7, 0.8, 0.9, 0.95]
    max_iters_list = [1, 3, 5, 8]
    top_k = args.top_k

    total_configs = len(betas) * len(lambdas) * len(max_iters_list)
    logger.info("=" * 60)
    logger.info("Residual Hopfield Grid Search")
    logger.info("  betas: %s", betas)
    logger.info("  lambdas: %s", lambdas)
    logger.info("  max_iters: %s", max_iters_list)
    logger.info("  total configs: %d", total_configs)
    logger.info("=" * 60)

    t_start = time.time()

    # Load questions, memory bank, and models.
    questions, gold_answers = load_questions(args.questions, args.max_samples)
    n = len(questions)
    logger.info("Loaded %d questions", n)

    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
    mb.load(args.memory_bank, device=args.device)
    M = mb.embeddings  # (d, N)
    logger.info("Memory bank: %d passages, dim=%d", mb.size, mb.dim)

    encoder = Encoder(EncoderConfig(**cfg.get("encoder", {})), device=args.device)
    generator = Generator(GeneratorConfig(**cfg.get("generator", {})), device=args.device)

    logger.info("Encoding questions...")
    all_embs = []
    batch_size = cfg.get("encoder", {}).get("batch_size", 64)
    for i in range(0, n, batch_size):
        all_embs.append(encoder.encode(questions[i : i + batch_size]))
    Q = torch.cat(all_embs, dim=0)  # (n, d)
    logger.info("Encoded, shape=%s", Q.shape)

    # Phase 1: FAISS baseline (its answers seed the LLM cache for the grid).
    logger.info("Running FAISS baseline...")
    emb_np = mb.embeddings.T.cpu().numpy().astype(np.float32)
    faiss_ret = FAISSRetriever(top_k=top_k)
    faiss_ret.build_index(emb_np, mb.passages)

    faiss_indices: Dict[int, Tuple[int, ...]] = {}
    # Cache key is (question_idx, frozenset(passage indices)): answers are
    # reused whenever a config retrieves the same passage SET for a question.
    llm_cache: Dict[Tuple[int, FrozenSet[int]], str] = {}
    # Running EM/F1 accumulators: progress logging is O(1) per question
    # instead of re-scoring every previous answer each 20 questions.
    faiss_ems: List[float] = []
    faiss_f1s: List[float] = []

    for i in range(n):
        q_np = Q[i].cpu().numpy().astype(np.float32)
        result = faiss_ret.retrieve(q_np)
        idx_tuple = tuple(sorted(result.indices.tolist()))
        faiss_indices[i] = idx_tuple
        answer = generator.generate(questions[i], result.passages)
        llm_cache[(i, frozenset(idx_tuple))] = answer
        faiss_ems.append(exact_match(answer, gold_answers[i]))
        faiss_f1s.append(f1_score(answer, gold_answers[i]))
        if (i + 1) % 20 == 0:
            logger.info("  FAISS %d/%d: EM=%.3f F1=%.3f", i + 1, n, np.mean(faiss_ems), np.mean(faiss_f1s))

    # Cast to plain float: np.float64 is not JSON serializable.
    faiss_em = float(np.mean(faiss_ems))
    faiss_f1 = float(np.mean(faiss_f1s))
    logger.info("FAISS baseline: EM=%.4f F1=%.4f", faiss_em, faiss_f1)

    # Phase 2: Retrieve all configs (batched, fast)
    logger.info("Phase 2: Retrieving all %d configs...", total_configs)
    t_ret = time.time()

    # config key -> list of (sorted_indices_tuple, entropy) per question
    retrieval_data: Dict[Tuple[float, float, int], List[Tuple[Tuple[int, ...], float]]] = {}

    for beta, lam, max_iter in product(betas, lambdas, max_iters_list):
        indices, _scores, entropy = residual_hopfield_retrieve(
            Q, M, beta=beta, lam=lam, max_iter=max_iter, top_k=top_k
        )
        # NOTE: `entropy` is one batch-averaged value; it is repeated for
        # every question so Phase 4 can average it uniformly per config.
        retrieval_data[(beta, lam, max_iter)] = [
            (tuple(sorted(indices[i].tolist())), entropy) for i in range(n)
        ]

    logger.info("Retrieval done in %.1fs", time.time() - t_ret)

    # Phase 3: Dedup retrieval sets, then generate only the unique ones.
    needed: Dict[Tuple[int, FrozenSet[int]], Tuple[int, Tuple[int, ...]]] = {}
    for per_q in retrieval_data.values():
        for i, (idx_tuple, _) in enumerate(per_q):
            cache_key = (i, frozenset(idx_tuple))
            if cache_key not in llm_cache and cache_key not in needed:
                needed[cache_key] = (i, idx_tuple)

    total_grid_evals = total_configs * n
    logger.info(
        "Unique LLM calls needed: %d / %d grid evals (%.1f%% saving)",
        len(needed), total_grid_evals,
        (1 - len(needed) / total_grid_evals) * 100,
    )

    t_gen = time.time()
    for call_idx, (cache_key, (q_idx, idx_tuple)) in enumerate(needed.items()):
        passages = mb.get_passages_by_indices(torch.tensor(list(idx_tuple), dtype=torch.long))
        llm_cache[cache_key] = generator.generate(questions[q_idx], passages)
        if (call_idx + 1) % 50 == 0:
            # Guard against a zero elapsed time on very fast runs.
            elapsed = max(time.time() - t_gen, 1e-6)
            rate = (call_idx + 1) / elapsed
            logger.info(
                "  Generated %d/%d (%.1f/s, ~%.0fs left)",
                call_idx + 1, len(needed), rate, (len(needed) - call_idx - 1) / rate,
            )
    logger.info("Generation done: %d calls in %.1fs", len(needed), time.time() - t_gen)

    # Phase 4: Evaluate every config against the gold answers.
    logger.info("Phase 4: Evaluating...")
    results = []
    for beta, lam, max_iter in product(betas, lambdas, max_iters_list):
        per_q = retrieval_data[(beta, lam, max_iter)]
        ems, f1s, overlaps, entropies = [], [], [], []
        for i, (idx_tuple, ent) in enumerate(per_q):
            answer = llm_cache[(i, frozenset(idx_tuple))]
            ems.append(exact_match(answer, gold_answers[i]))
            f1s.append(f1_score(answer, gold_answers[i]))
            # Fraction of this config's top-k shared with the FAISS top-k.
            overlaps.append(len(set(idx_tuple) & set(faiss_indices[i])) / top_k)
            entropies.append(ent)

        # float(...) casts throughout: np.float64 breaks json.dump below.
        em, f1 = float(np.mean(ems)), float(np.mean(f1s))
        r = {
            "beta": beta, "lambda": lam, "max_iter": max_iter,
            "em": round(em, 4), "f1": round(f1, 4),
            "avg_faiss_overlap": round(float(np.mean(overlaps)), 4),
            "avg_entropy": round(float(np.mean(entropies)), 4),
        }
        results.append(r)
        # Only log configs within 0.01 F1 of the baseline; *** marks a beat.
        if f1 >= faiss_f1 - 0.01:
            marker = " ***" if f1 > faiss_f1 else ""
            logger.info(
                "  β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f%s",
                beta, lam, max_iter, em, f1, np.mean(overlaps), marker,
            )

    # Sort by F1, best first.
    results.sort(key=lambda x: x["f1"], reverse=True)
    best = results[0]

    t_total = time.time() - t_start

    output = {
        "meta": {
            "n_questions": n,
            "total_configs": total_configs,
            "unique_llm_calls": len(needed),
            "faiss_llm_calls": n,
            "total_time_s": round(t_total, 1),
        },
        "faiss_baseline": {"em": round(faiss_em, 4), "f1": round(faiss_f1, 4)},
        "grid_results": results,
        "best_config": best,
        "top10": results[:10],
    }

    Path(args.output).parent.mkdir(parents=True, exist_ok=True)
    with open(args.output, "w") as f:
        json.dump(output, f, indent=2)

    logger.info("=" * 60)
    logger.info("RESULTS")
    logger.info("  FAISS: EM=%.4f F1=%.4f", faiss_em, faiss_f1)
    logger.info(
        "  Best: β=%.1f λ=%.2f iter=%d => EM=%.4f F1=%.4f",
        best["beta"], best["lambda"], best["max_iter"], best["em"], best["f1"],
    )
    logger.info("  Top 5:")
    for i, r in enumerate(results[:5]):
        logger.info(
            "    %d. β=%5.1f λ=%.2f iter=%d => EM=%.3f F1=%.3f overlap=%.3f",
            i + 1, r["beta"], r["lambda"], r["max_iter"], r["em"], r["f1"], r["avg_faiss_overlap"],
        )
    logger.info("  Total time: %.1fs", t_total)
    logger.info("  Saved to: %s", args.output)
    logger.info("=" * 60)


if __name__ == "__main__":
    main()
