author     YurenHao0426 <blackhao0426@gmail.com>  2026-02-15 18:19:50 +0000
committer  YurenHao0426 <blackhao0426@gmail.com>  2026-02-15 18:19:50 +0000
commit     c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree       43edac8013fec4e65a0b9cddec5314489b4aafc2 /scripts
Initial implementation of HAG (Hopfield-Augmented Generation) (HEAD, master)
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
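For orientation, here is a minimal sketch of the retrieval dynamics the commit
message refers to, assuming the standard modern-Hopfield update rule
(hag/hopfield.py itself is outside this scripts/ diff, so the name and
signature below are illustrative, not the repo's API):

    import torch

    def hopfield_retrieve(q: torch.Tensor, M: torch.Tensor, beta: float = 8.0,
                          max_iter: int = 10, tol: float = 1e-4) -> torch.Tensor:
        """Iterate q <- M^T softmax(beta * M q); q is (d,), memories M are (N, d).

        Each step is a CCCP update of the modern-Hopfield energy, so the energy
        never increases; iteration stops once the query stops moving.
        """
        for _ in range(max_iter):
            q_next = M.T @ torch.softmax(beta * (M @ q), dim=0)
            if torch.norm(q_next - q) < tol:  # converged to a fixed point
                return q_next
            q = q_next
        return q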
Diffstat (limited to 'scripts')
-rw-r--r--   scripts/analyze_energy.py         78
-rw-r--r--   scripts/build_memory_bank.py      72
-rw-r--r--   scripts/run_baseline.py           75
-rw-r--r--   scripts/run_eval.py               90
-rw-r--r--   scripts/run_hag.py                79
-rw-r--r--   scripts/visualize_trajectory.py   80
6 files changed, 474 insertions, 0 deletions
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py
new file mode 100644
index 0000000..fd044a4
--- /dev/null
+++ b/scripts/analyze_energy.py
@@ -0,0 +1,78 @@
+"""Analyze energy curves and convergence properties of Hopfield retrieval.
+
+Usage:
+ python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import (
+ compute_attention_entropy,
+ compute_energy_curve,
+ compute_energy_gap,
+ verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--output", type=str, default="energy_analysis.json")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+ # Load memory bank
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank)
+
+ # Load questions
+ with open(args.questions) as f:
+ questions = [json.loads(line)["question"] for line in f]
+
+ encoder = Encoder(encoder_config)
+ hopfield = HopfieldRetrieval(hopfield_config)
+
+ analyses = []
+ for q in questions:
+ query_emb = encoder.encode(q) # (1, d)
+ result = hopfield.retrieve(
+ query_emb, mb.embeddings, return_energy=True, return_trajectory=True
+ )
+
+ curve = compute_energy_curve(result)
+ analyses.append({
+ "question": q,
+ "energy_curve": curve,
+ "energy_gap": compute_energy_gap(curve),
+ "monotonic": verify_monotonic_decrease(curve),
+ "num_steps": result.num_steps,
+ "attention_entropy": compute_attention_entropy(result.attention_weights),
+ })
+
+ with open(args.output, "w") as f:
+ json.dump(analyses, f, indent=2)
+ logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses))
+
+
+if __name__ == "__main__":
+ main()
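A note on the helpers above: hag/energy.py is not part of this diff. As a
hedged sketch, the reported quantities could be computed from the
modern-Hopfield energy (up to additive constants) roughly as follows; the
names mirror the imports, but the bodies are illustrative:

    import torch

    def hopfield_energy(q: torch.Tensor, M: torch.Tensor, beta: float) -> float:
        """E(q) = 0.5 * ||q||^2 - (1/beta) * logsumexp(beta * M q), up to constants."""
        lse = torch.logsumexp(beta * (M @ q), dim=0) / beta
        return (0.5 * q.dot(q) - lse).item()

    def verify_monotonic_decrease(curve: list[float], tol: float = 1e-6) -> bool:
        """The CCCP update never increases E; verify within numerical tolerance."""
        return all(b <= a + tol for a, b in zip(curve, curve[1:]))

    def compute_energy_gap(curve: list[float]) -> float:
        """Total energy decrease from the initial query to the fixed point."""
        return curve[0] - curve[-1]

    def compute_attention_entropy(p: torch.Tensor) -> float:
        """Shannon entropy of the attention weights; low entropy = sharp retrieval."""
        return float(-(p * torch.log(p.clamp_min(1e-12))).sum())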
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
new file mode 100644
index 0000000..2aff828
--- /dev/null
+++ b/scripts/build_memory_bank.py
@@ -0,0 +1,72 @@
+"""Offline script: encode corpus passages into a memory bank.
+
+Usage:
+ python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+from tqdm import tqdm
+
+from hag.config import EncoderConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def load_corpus(path: str) -> list[str]:
+ """Load passages from a JSONL file (one JSON object per line with 'text' field)."""
+ passages = []
+ with open(path) as f:
+ for line in f:
+ obj = json.loads(line)
+ passages.append(obj["text"])
+ return passages
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Build memory bank from corpus")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--corpus", type=str, required=True)
+ parser.add_argument("--output", type=str, required=True)
+ parser.add_argument("--device", type=str, default="cpu")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+    encoder_cfg = cfg.get("encoder", {})
+    if args.device is not None:
+        # Wire the CLI flag through (assumes EncoderConfig exposes a `device` field).
+        encoder_cfg["device"] = args.device
+    encoder_config = EncoderConfig(**encoder_cfg)
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+
+ # Load corpus
+ logger.info("Loading corpus from %s", args.corpus)
+ passages = load_corpus(args.corpus)
+ logger.info("Loaded %d passages", len(passages))
+
+ # Encode passages in batches
+ encoder = Encoder(encoder_config)
+ all_embeddings = []
+
+ for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"):
+ batch = passages[i : i + encoder_config.batch_size]
+ emb = encoder.encode(batch) # (batch_size, d)
+ all_embeddings.append(emb.cpu())
+
+ embeddings = torch.cat(all_embeddings, dim=0) # (N, d)
+ logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)
+
+ # Build and save memory bank
+ mb = MemoryBank(memory_config)
+ mb.build_from_embeddings(embeddings, passages)
+ mb.save(args.output)
+ logger.info("Memory bank saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
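load_corpus expects one JSON object per line with a "text" field, so a
two-passage data/corpus.jsonl would look like:

    {"text": "Hamlet is a tragedy written by William Shakespeare between 1599 and 1601."}
    {"text": "The Globe Theatre was a London playhouse associated with Shakespeare's company."}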
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py
new file mode 100644
index 0000000..74c4710
--- /dev/null
+++ b/scripts/run_baseline.py
@@ -0,0 +1,75 @@
+"""Run vanilla RAG baseline with FAISS retrieval.
+
+Usage:
+ python scripts/run_baseline.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import numpy as np
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run vanilla RAG baseline")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--top-k", type=int, default=5)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ # Override retriever type to faiss
+ pipeline_config = PipelineConfig(
+ hopfield=HopfieldConfig(**{**cfg.get("hopfield", {}), "top_k": args.top_k}),
+ encoder=EncoderConfig(**cfg.get("encoder", {})),
+ generator=GeneratorConfig(**cfg.get("generator", {})),
+ retriever_type="faiss",
+ )
+
+    # Load memory bank to get embeddings for FAISS
+    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
+    mb.load(args.memory_bank)
+
+    # Build FAISS index from memory bank embeddings
+    embeddings_np = mb.embeddings.T.numpy().astype(np.float32)  # (N, d)
+    faiss_ret = FAISSRetriever(top_k=args.top_k)
+    faiss_ret.build_index(embeddings_np, mb.passages)
+
+ encoder = Encoder(pipeline_config.encoder)
+ generator = Generator(pipeline_config.generator)
+
+ pipeline = RAGPipeline(
+ config=pipeline_config,
+ encoder=encoder,
+ generator=generator,
+ faiss_retriever=faiss_ret,
+ )
+
+ result = pipeline.run(args.question)
+ print(f"\nQuestion: {result.question}")
+ print(f"Answer: {result.answer}")
+ print(f"\nRetrieved passages:")
+ for i, p in enumerate(result.retrieved_passages):
+ print(f" [{i+1}] {p[:200]}...")
+
+
+if __name__ == "__main__":
+ main()
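FAISSRetriever comes from hag/retriever_faiss.py, which this diff does not
include. A minimal flat inner-product index in the same spirit (an
illustration under assumed conventions, not the repo's implementation):

    import faiss
    import numpy as np

    class FlatIPRetriever:
        """Exact inner-product search; cosine similarity after L2 normalization."""

        def __init__(self, top_k: int = 5):
            self.top_k = top_k
            self.index = None
            self.passages: list[str] = []

        def build_index(self, embeddings: np.ndarray, passages: list[str]) -> None:
            faiss.normalize_L2(embeddings)  # in-place; expects float32 of shape (N, d)
            self.index = faiss.IndexFlatIP(embeddings.shape[1])
            self.index.add(embeddings)
            self.passages = passages

        def retrieve(self, query: np.ndarray) -> list[str]:
            faiss.normalize_L2(query)  # query shape (1, d)
            _, idx = self.index.search(query, self.top_k)
            return [self.passages[i] for i in idx[0]]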
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
new file mode 100644
index 0000000..713b3c2
--- /dev/null
+++ b/scripts/run_eval.py
@@ -0,0 +1,90 @@
+"""Run evaluation on a dataset with either FAISS or Hopfield retrieval.
+
+Usage:
+ python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--dataset", type=str, default="hotpotqa")
+ parser.add_argument("--split", type=str, default="validation")
+ parser.add_argument("--max-samples", type=int, default=500)
+ parser.add_argument("--output", type=str, default="results.json")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ pipeline_config = PipelineConfig(
+ hopfield=HopfieldConfig(**cfg.get("hopfield", {})),
+ memory=MemoryBankConfig(**cfg.get("memory", {})),
+ encoder=EncoderConfig(**cfg.get("encoder", {})),
+ generator=GeneratorConfig(**cfg.get("generator", {})),
+ retriever_type=cfg.get("retriever_type", "hopfield"),
+ )
+
+ # Load memory bank
+ mb = MemoryBank(pipeline_config.memory)
+ mb.load(args.memory_bank)
+
+ # Build pipeline
+ encoder = Encoder(pipeline_config.encoder)
+ generator = Generator(pipeline_config.generator)
+ pipeline = RAGPipeline(
+ config=pipeline_config,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+
+ # Load dataset
+ from datasets import load_dataset
+
+ logger.info("Loading dataset: %s / %s", args.dataset, args.split)
+ ds = load_dataset(args.dataset, split=args.split)
+ if args.max_samples and len(ds) > args.max_samples:
+ ds = ds.select(range(args.max_samples))
+
+ questions = [ex["question"] for ex in ds]
+ gold_answers = [ex["answer"] for ex in ds]
+
+ # Run evaluation
+ logger.info("Running evaluation on %d questions", len(questions))
+ results = pipeline.run_batch(questions)
+ metrics = evaluate_dataset(results, gold_answers)
+
+ logger.info("Results: %s", metrics)
+
+ with open(args.output, "w") as f:
+ json.dump(metrics, f, indent=2)
+ logger.info("Results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
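evaluate_dataset lives in hag/metrics.py, also outside this diff. For QA
datasets like HotpotQA the customary metrics are exact match and token-level
F1; a sketch in the style of the standard SQuAD scoring script (illustrative,
not necessarily what hag.metrics implements):

    import re
    import string
    from collections import Counter

    def normalize(s: str) -> str:
        """Lowercase, strip punctuation and articles, collapse whitespace."""
        s = "".join(ch for ch in s.lower() if ch not in string.punctuation)
        s = re.sub(r"\b(a|an|the)\b", " ", s)
        return " ".join(s.split())

    def exact_match(pred: str, gold: str) -> float:
        return float(normalize(pred) == normalize(gold))

    def f1_score(pred: str, gold: str) -> float:
        pred_toks, gold_toks = normalize(pred).split(), normalize(gold).split()
        if not pred_toks or not gold_toks:
            return float(pred_toks == gold_toks)
        overlap = sum((Counter(pred_toks) & Counter(gold_toks)).values())
        if overlap == 0:
            return 0.0
        precision = overlap / len(pred_toks)
        recall = overlap / len(gold_toks)
        return 2 * precision * recall / (precision + recall)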
diff --git a/scripts/run_hag.py b/scripts/run_hag.py
new file mode 100644
index 0000000..4cacd1a
--- /dev/null
+++ b/scripts/run_hag.py
@@ -0,0 +1,79 @@
+"""Run HAG (Hopfield-Augmented Generation) on a question.
+
+Usage:
+ python scripts/run_hag.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import yaml
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run HAG")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--beta", type=float, default=None)
+ parser.add_argument("--max-iter", type=int, default=None)
+ parser.add_argument("--top-k", type=int, default=None)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_cfg = cfg.get("hopfield", {})
+ if args.beta is not None:
+ hopfield_cfg["beta"] = args.beta
+ if args.max_iter is not None:
+ hopfield_cfg["max_iter"] = args.max_iter
+ if args.top_k is not None:
+ hopfield_cfg["top_k"] = args.top_k
+
+ pipeline_config = PipelineConfig(
+ hopfield=HopfieldConfig(**hopfield_cfg),
+ memory=MemoryBankConfig(**cfg.get("memory", {})),
+ encoder=EncoderConfig(**cfg.get("encoder", {})),
+ generator=GeneratorConfig(**cfg.get("generator", {})),
+ retriever_type="hopfield",
+ )
+
+ mb = MemoryBank(pipeline_config.memory)
+ mb.load(args.memory_bank)
+
+ encoder = Encoder(pipeline_config.encoder)
+ generator = Generator(pipeline_config.generator)
+
+ pipeline = RAGPipeline(
+ config=pipeline_config,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+
+ result = pipeline.run(args.question)
+ print(f"\nQuestion: {result.question}")
+ print(f"Answer: {result.answer}")
+ print(f"\nRetrieved passages:")
+ for i, p in enumerate(result.retrieved_passages):
+ print(f" [{i+1}] {p[:200]}...")
+
+
+if __name__ == "__main__":
+ main()
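The --beta, --max-iter, and --top-k flags override the YAML config, which makes
parameter sweeps convenient without editing configs/default.yaml, e.g.:

    python scripts/run_hag.py --config configs/default.yaml \
        --memory-bank data/memory_bank.pt \
        --question "Who wrote Hamlet?" \
        --beta 4.0 --max-iter 20 --top-k 3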
diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py
new file mode 100644
index 0000000..e4ba902
--- /dev/null
+++ b/scripts/visualize_trajectory.py
@@ -0,0 +1,80 @@
+"""UMAP visualization of query trajectory in Hopfield retrieval.
+
+Usage:
+ python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import numpy as np
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--output", type=str, default="trajectory.png")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank)
+
+ encoder = Encoder(encoder_config)
+ hopfield = HopfieldRetrieval(hopfield_config)
+
+ query_emb = encoder.encode(args.question) # (1, d)
+ result = hopfield.retrieve(
+ query_emb, mb.embeddings, return_trajectory=True
+ )
+
+ # Gather all points for UMAP: memories + trajectory
+ memories_np = mb.embeddings.T.numpy() # (N, d)
+ trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d)
+ all_points = np.concatenate([memories_np, trajectory_np], axis=0)
+
+ # UMAP projection
+ import umap
+
+ reducer = umap.UMAP(n_components=2, random_state=42)
+ projected = reducer.fit_transform(all_points)
+
+ mem_proj = projected[: len(memories_np)]
+ traj_proj = projected[len(memories_np) :]
+
+ # Plot
+ import matplotlib.pyplot as plt
+
+ fig, ax = plt.subplots(figsize=(10, 8))
+ ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories")
+ ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory")
+ ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)")
+ ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)")
+
+ ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)")
+ ax.legend()
+ plt.tight_layout()
+ plt.savefig(args.output, dpi=150)
+ logger.info("Trajectory visualization saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
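visualize_trajectory.py imports umap and matplotlib lazily inside main() so the
other scripts do not need them installed; note that the umap module is provided
by the umap-learn package:

    pip install umap-learn matplotlib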