| author    | YurenHao0426 <blackhao0426@gmail.com>            | 2026-02-15 18:19:50 +0000 |
| committer | YurenHao0426 <blackhao0426@gmail.com>            | 2026-02-15 18:19:50 +0000 |
| commit    | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)  | |
| tree      | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /scripts | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'scripts')
| -rw-r--r-- | scripts/analyze_energy.py       | 78 |
| -rw-r--r-- | scripts/build_memory_bank.py    | 72 |
| -rw-r--r-- | scripts/run_baseline.py         | 75 |
| -rw-r--r-- | scripts/run_eval.py             | 90 |
| -rw-r--r-- | scripts/run_hag.py              | 79 |
| -rw-r--r-- | scripts/visualize_trajectory.py | 80 |
6 files changed, 474 insertions, 0 deletions
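
The commit message's "energy-based convergence guarantees" are not spelled out in this diff. Assuming the standard modern Hopfield formulation (Ramsauer et al., 2020), which the hag.energy helpers used below appear to track, retrieval iterates on an energy of the form

    % X = [x_1 ... x_N]: memory matrix, one stored pattern per column; q: query state
    E(q) = -\frac{1}{\beta} \log \sum_{i=1}^{N} \exp\bigl(\beta \, x_i^\top q\bigr) + \frac{1}{2} \lVert q \rVert^2

with the update

    q^{(t+1)} = X \, \mathrm{softmax}\bigl(\beta \, X^\top q^{(t)}\bigr)

which provably never increases E; that monotone decrease is what verify_monotonic_decrease in analyze_energy.py checks empirically on each question's energy curve.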
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py
new file mode 100644
index 0000000..fd044a4
--- /dev/null
+++ b/scripts/analyze_energy.py
@@ -0,0 +1,78 @@
+"""Analyze energy curves and convergence properties of Hopfield retrieval.
+
+Usage:
+    python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import (
+    compute_attention_entropy,
+    compute_energy_curve,
+    compute_energy_gap,
+    verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--questions", type=str, required=True)
+    parser.add_argument("--output", type=str, default="energy_analysis.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+    # Load memory bank
+    mb = MemoryBank(memory_config)
+    mb.load(args.memory_bank)
+
+    # Load questions
+    with open(args.questions) as f:
+        questions = [json.loads(line)["question"] for line in f]
+
+    encoder = Encoder(encoder_config)
+    hopfield = HopfieldRetrieval(hopfield_config)
+
+    analyses = []
+    for q in questions:
+        query_emb = encoder.encode(q)  # (1, d)
+        result = hopfield.retrieve(
+            query_emb, mb.embeddings, return_energy=True, return_trajectory=True
+        )
+
+        curve = compute_energy_curve(result)
+        analyses.append({
+            "question": q,
+            "energy_curve": curve,
+            "energy_gap": compute_energy_gap(curve),
+            "monotonic": verify_monotonic_decrease(curve),
+            "num_steps": result.num_steps,
+            "attention_entropy": compute_attention_entropy(result.attention_weights),
+        })
+
+    with open(args.output, "w") as f:
+        json.dump(analyses, f, indent=2)
+    logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses))
+
+
+if __name__ == "__main__":
+    main()
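
The hag.energy helpers imported above live outside this diff. A minimal sketch of what compute_energy_gap and verify_monotonic_decrease could look like, assuming compute_energy_curve returns the per-iteration energies as a plain list of floats (the tol parameter is illustrative):

    def compute_energy_gap(curve: list[float]) -> float:
        # Total energy drop from the initial query q_0 to the converged state q_T.
        return curve[0] - curve[-1]

    def verify_monotonic_decrease(curve: list[float], tol: float = 1e-6) -> bool:
        # True iff the energy never rises by more than tol between consecutive steps.
        return all(later <= earlier + tol for earlier, later in zip(curve, curve[1:]))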
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
new file mode 100644
index 0000000..2aff828
--- /dev/null
+++ b/scripts/build_memory_bank.py
@@ -0,0 +1,72 @@
+"""Offline script: encode corpus passages into a memory bank.
+
+Usage:
+    python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+from tqdm import tqdm
+
+from hag.config import EncoderConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def load_corpus(path: str) -> list[str]:
+    """Load passages from a JSONL file (one JSON object per line with 'text' field)."""
+    passages = []
+    with open(path) as f:
+        for line in f:
+            obj = json.loads(line)
+            passages.append(obj["text"])
+    return passages
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Build memory bank from corpus")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--corpus", type=str, required=True)
+    parser.add_argument("--output", type=str, required=True)
+    parser.add_argument("--device", type=str, default="cpu")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+
+    # Load corpus
+    logger.info("Loading corpus from %s", args.corpus)
+    passages = load_corpus(args.corpus)
+    logger.info("Loaded %d passages", len(passages))
+
+    # Encode passages in batches
+    encoder = Encoder(encoder_config)
+    all_embeddings = []
+
+    for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"):
+        batch = passages[i : i + encoder_config.batch_size]
+        emb = encoder.encode(batch)  # (batch_size, d)
+        all_embeddings.append(emb.cpu())
+
+    embeddings = torch.cat(all_embeddings, dim=0)  # (N, d)
+    logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)
+
+    # Build and save memory bank
+    mb = MemoryBank(memory_config)
+    mb.build_from_embeddings(embeddings, passages)
+    mb.save(args.output)
+    logger.info("Memory bank saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()
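
load_corpus expects one JSON object per line with a "text" field, so a minimal data/corpus.jsonl (contents illustrative) looks like:

    {"text": "Hamlet is a tragedy written by William Shakespeare, first performed around 1600."}
    {"text": "The Globe Theatre was a playhouse in London associated with Shakespeare's company."}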
+""" + +import argparse +import logging + +import torch +import yaml + +from hag.config import EncoderConfig, GeneratorConfig, HopfieldConfig, PipelineConfig +from hag.encoder import Encoder +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.retriever_faiss import FAISSRetriever +from hag.pipeline import RAGPipeline + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run vanilla RAG baseline") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--top-k", type=int, default=5) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + # Override retriever type to faiss + pipeline_config = PipelineConfig( + hopfield=HopfieldConfig(**{**cfg.get("hopfield", {}), "top_k": args.top_k}), + encoder=EncoderConfig(**cfg.get("encoder", {})), + generator=GeneratorConfig(**cfg.get("generator", {})), + retriever_type="faiss", + ) + + # Load memory bank to get embeddings for FAISS + from hag.config import MemoryBankConfig + + mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {}))) + mb.load(args.memory_bank) + + # Build FAISS index from memory bank embeddings + import numpy as np + + embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d) + faiss_ret = FAISSRetriever(top_k=args.top_k) + faiss_ret.build_index(embeddings_np, mb.passages) + + encoder = Encoder(pipeline_config.encoder) + generator = Generator(pipeline_config.generator) + + pipeline = RAGPipeline( + config=pipeline_config, + encoder=encoder, + generator=generator, + faiss_retriever=faiss_ret, + ) + + result = pipeline.run(args.question) + print(f"\nQuestion: {result.question}") + print(f"Answer: {result.answer}") + print(f"\nRetrieved passages:") + for i, p in enumerate(result.retrieved_passages): + print(f" [{i+1}] {p[:200]}...") + + +if __name__ == "__main__": + main() diff --git a/scripts/run_eval.py b/scripts/run_eval.py new file mode 100644 index 0000000..713b3c2 --- /dev/null +++ b/scripts/run_eval.py @@ -0,0 +1,90 @@ +"""Run evaluation on a dataset with either FAISS or Hopfield retrieval. 
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
new file mode 100644
index 0000000..713b3c2
--- /dev/null
+++ b/scripts/run_eval.py
@@ -0,0 +1,90 @@
+"""Run evaluation on a dataset with either FAISS or Hopfield retrieval.
+
+Usage:
+    python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--dataset", type=str, default="hotpotqa")
+    parser.add_argument("--split", type=str, default="validation")
+    parser.add_argument("--max-samples", type=int, default=500)
+    parser.add_argument("--output", type=str, default="results.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    pipeline_config = PipelineConfig(
+        hopfield=HopfieldConfig(**cfg.get("hopfield", {})),
+        memory=MemoryBankConfig(**cfg.get("memory", {})),
+        encoder=EncoderConfig(**cfg.get("encoder", {})),
+        generator=GeneratorConfig(**cfg.get("generator", {})),
+        retriever_type=cfg.get("retriever_type", "hopfield"),
+    )
+
+    # Load memory bank
+    mb = MemoryBank(pipeline_config.memory)
+    mb.load(args.memory_bank)
+
+    # Build pipeline
+    encoder = Encoder(pipeline_config.encoder)
+    generator = Generator(pipeline_config.generator)
+    pipeline = RAGPipeline(
+        config=pipeline_config,
+        encoder=encoder,
+        generator=generator,
+        memory_bank=mb,
+    )
+
+    # Load dataset
+    from datasets import load_dataset
+
+    logger.info("Loading dataset: %s / %s", args.dataset, args.split)
+    ds = load_dataset(args.dataset, split=args.split)
+    if args.max_samples and len(ds) > args.max_samples:
+        ds = ds.select(range(args.max_samples))
+
+    questions = [ex["question"] for ex in ds]
+    gold_answers = [ex["answer"] for ex in ds]
+
+    # Run evaluation
+    logger.info("Running evaluation on %d questions", len(questions))
+    results = pipeline.run_batch(questions)
+    metrics = evaluate_dataset(results, gold_answers)
+
+    logger.info("Results: %s", metrics)
+
+    with open(args.output, "w") as f:
+        json.dump(metrics, f, indent=2)
+    logger.info("Results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()
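
evaluate_dataset comes from hag.metrics, which this diff does not touch. QA evaluations of this kind conventionally report SQuAD-style normalized exact match; a sketch of that normalization follows, not necessarily what hag.metrics implements:

    import re
    import string

    def normalize_answer(s: str) -> str:
        # Lowercase, drop punctuation and articles, collapse whitespace (SQuAD convention).
        s = "".join(ch for ch in s.lower() if ch not in set(string.punctuation))
        s = re.sub(r"\b(a|an|the)\b", " ", s)
        return " ".join(s.split())

    def exact_match(prediction: str, gold: str) -> float:
        return float(normalize_answer(prediction) == normalize_answer(gold))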
+""" + +import argparse +import logging + +import yaml + +from hag.config import ( + EncoderConfig, + GeneratorConfig, + HopfieldConfig, + MemoryBankConfig, + PipelineConfig, +) +from hag.encoder import Encoder +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.pipeline import RAGPipeline + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run HAG") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--beta", type=float, default=None) + parser.add_argument("--max-iter", type=int, default=None) + parser.add_argument("--top-k", type=int, default=None) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + hopfield_cfg = cfg.get("hopfield", {}) + if args.beta is not None: + hopfield_cfg["beta"] = args.beta + if args.max_iter is not None: + hopfield_cfg["max_iter"] = args.max_iter + if args.top_k is not None: + hopfield_cfg["top_k"] = args.top_k + + pipeline_config = PipelineConfig( + hopfield=HopfieldConfig(**hopfield_cfg), + memory=MemoryBankConfig(**cfg.get("memory", {})), + encoder=EncoderConfig(**cfg.get("encoder", {})), + generator=GeneratorConfig(**cfg.get("generator", {})), + retriever_type="hopfield", + ) + + mb = MemoryBank(pipeline_config.memory) + mb.load(args.memory_bank) + + encoder = Encoder(pipeline_config.encoder) + generator = Generator(pipeline_config.generator) + + pipeline = RAGPipeline( + config=pipeline_config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + + result = pipeline.run(args.question) + print(f"\nQuestion: {result.question}") + print(f"Answer: {result.answer}") + print(f"\nRetrieved passages:") + for i, p in enumerate(result.retrieved_passages): + print(f" [{i+1}] {p[:200]}...") + + +if __name__ == "__main__": + main() diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py new file mode 100644 index 0000000..e4ba902 --- /dev/null +++ b/scripts/visualize_trajectory.py @@ -0,0 +1,80 @@ +"""UMAP visualization of query trajectory in Hopfield retrieval. + +Usage: + python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?" 
+""" + +import argparse +import logging + +import numpy as np +import torch +import yaml + +from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig +from hag.encoder import Encoder +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--output", type=str, default="trajectory.png") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + hopfield_config = HopfieldConfig(**cfg.get("hopfield", {})) + memory_config = MemoryBankConfig(**cfg.get("memory", {})) + encoder_config = EncoderConfig(**cfg.get("encoder", {})) + + mb = MemoryBank(memory_config) + mb.load(args.memory_bank) + + encoder = Encoder(encoder_config) + hopfield = HopfieldRetrieval(hopfield_config) + + query_emb = encoder.encode(args.question) # (1, d) + result = hopfield.retrieve( + query_emb, mb.embeddings, return_trajectory=True + ) + + # Gather all points for UMAP: memories + trajectory + memories_np = mb.embeddings.T.numpy() # (N, d) + trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d) + all_points = np.concatenate([memories_np, trajectory_np], axis=0) + + # UMAP projection + import umap + + reducer = umap.UMAP(n_components=2, random_state=42) + projected = reducer.fit_transform(all_points) + + mem_proj = projected[: len(memories_np)] + traj_proj = projected[len(memories_np) :] + + # Plot + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 8)) + ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories") + ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory") + ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)") + ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)") + + ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)") + ax.legend() + plt.tight_layout() + plt.savefig(args.output, dpi=150) + logger.info("Trajectory visualization saved to %s", args.output) + + +if __name__ == "__main__": + main() |
