-rw-r--r--  .gitignore                         8
-rw-r--r--  CLAUDE.md                        870
-rw-r--r--  README.md                         40
-rw-r--r--  TBD.md                            23
-rw-r--r--  configs/2wikimhqa.yaml            22
-rw-r--r--  configs/default.yaml              21
-rw-r--r--  configs/hotpotqa.yaml             22
-rw-r--r--  configs/musique.yaml              22
-rw-r--r--  hag/__init__.py                   21
-rw-r--r--  hag/config.py                     50
-rw-r--r--  hag/datatypes.py                  37
-rw-r--r--  hag/encoder.py                    88
-rw-r--r--  hag/energy.py                     83
-rw-r--r--  hag/generator.py                  87
-rw-r--r--  hag/hopfield.py                  124
-rw-r--r--  hag/memory_bank.py                93
-rw-r--r--  hag/metrics.py                   113
-rw-r--r--  hag/pipeline.py                  107
-rw-r--r--  hag/retriever_faiss.py            73
-rw-r--r--  hag/retriever_hopfield.py         77
-rw-r--r--  notebooks/demo.ipynb              62
-rw-r--r--  pyproject.toml                    36
-rw-r--r--  scripts/analyze_energy.py         78
-rw-r--r--  scripts/build_memory_bank.py      72
-rw-r--r--  scripts/run_baseline.py           75
-rw-r--r--  scripts/run_eval.py               90
-rw-r--r--  scripts/run_hag.py                79
-rw-r--r--  scripts/visualize_trajectory.py   80
-rw-r--r--  tests/__init__.py                  0
-rw-r--r--  tests/test_energy.py              80
-rw-r--r--  tests/test_hopfield.py           147
-rw-r--r--  tests/test_memory_bank.py         65
-rw-r--r--  tests/test_metrics.py             54
-rw-r--r--  tests/test_pipeline.py           132
-rw-r--r--  tests/test_retriever.py          149
35 files changed, 3180 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..0741f54
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,8 @@
+__pycache__/
+*.pyc
+*.egg-info/
+dist/
+build/
+.pytest_cache/
+*.pt
+.claude/
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..c99b846
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,870 @@
+# CLAUDE.md — Hopfield-Augmented Generation (HAG)
+
+## Project Overview
+
+We are building **HAG (Hopfield-Augmented Generation)**, a retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**. This is a research project targeting a top ML venue (ICML/NeurIPS).
+
+### Core Idea
+
+Standard RAG does:
+```
+y = LLM(x, top_k_passages(FAISS(Enc(x))))
+```
+
+HAG does:
+```
+q₀ = Enc(x)
+qₜ₊₁ = M · softmax(β · Mᵀ · qₜ) # iterate T steps
+α_T = softmax(β · Mᵀ · q_T) # final attention weights
+retrieved_passages = top_k(α_T) # use weights to select passages
+y = LLM(x, retrieved_passages)
+```
+
+Where:
+- `M ∈ ℝ^{d × N}` is a memory bank of passage embeddings (same encoder as vanilla RAG)
+- `q` is the query embedding
+- `β` is an inverse temperature parameter
+- The iterative update is the retrieval dynamics of a **Modern Continuous Hopfield Network** (Ramsauer et al., 2021)
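+
+For concreteness, a minimal PyTorch sketch of this loop (illustrative only; the spec below adds convergence checks, batching, and energy tracking):
+
+```python
+import torch
+
+d, N, beta, T = 64, 100, 2.0, 5
+M = torch.nn.functional.normalize(torch.randn(d, N), dim=0)   # memory bank (d, N)
+q = torch.nn.functional.normalize(torch.randn(1, d), dim=-1)  # query (1, d)
+
+for _ in range(T):
+    alpha = torch.softmax(beta * (q @ M), dim=-1)  # attention over memories (1, N)
+    q = alpha @ M.T                                # refined query (1, d)
+
+top_indices = alpha.topk(5).indices  # passages to hand to the LLM
+```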
+
+### Why This Matters
+
+1. **Iterative refinement**: RAG retrieves once. HAG refines the query in embedding space over T steps, enabling implicit multi-hop reasoning.
+2. **Energy-based theory**: Retrieval is energy minimization with convergence guarantees (energy decreases monotonically).
+3. **Differentiable**: Unlike FAISS top-k, the entire retrieval is differentiable — enabling future end-to-end training.
+4. **No structural preprocessing**: Unlike GraphRAG methods (LinearRAG, HippoRAG), no NER, no graph construction, no PageRank. Just an embedding matrix.
+
+### What We Are NOT Doing (Scope Boundaries)
+
+- NOT modifying the LLM architecture (no cross-attention injection)
+- NOT training the memory bank in phase 1 (M is frozen embeddings)
+- NOT building a graph structure
+- NOT doing anything that requires the LLM during retrieval
+- The LLM is frozen and used only for final generation
+
+---
+
+## Repository Structure
+
+```
+hag/
+├── CLAUDE.md # This file
+├── pyproject.toml # Project config, dependencies
+├── README.md # Project README
+├── hag/
+│ ├── __init__.py
+│ ├── config.py # All hyperparameters and configs (dataclass-based)
+│ ├── encoder.py # Wrapper for encoding queries/passages
+│ ├── memory_bank.py # Memory bank construction and management
+│ ├── hopfield.py # Core: Hopfield retrieval module
+│ ├── retriever_faiss.py # Baseline: FAISS top-k retriever
+│ ├── retriever_hopfield.py # HAG retriever (wraps hopfield.py + passage selection)
+│ ├── generator.py # LLM generation wrapper
+│ ├── pipeline.py # End-to-end RAG/HAG pipeline
+│ ├── metrics.py # Evaluation metrics (EM, F1, retrieval recall)
+│ └── energy.py # Energy computation and analysis utilities
+├── scripts/
+│ ├── build_memory_bank.py # Offline: encode corpus → memory bank
+│ ├── run_eval.py # Run evaluation on a dataset
+│ ├── run_baseline.py # Run vanilla RAG baseline
+│ ├── run_hag.py # Run HAG
+│ ├── analyze_energy.py # Analyze energy curves and convergence
+│ └── visualize_trajectory.py # UMAP visualization of query trajectory
+├── tests/
+│ ├── test_hopfield.py # Unit tests for Hopfield module
+│ ├── test_memory_bank.py # Unit tests for memory bank
+│ ├── test_retriever.py # Unit tests for both retrievers
+│ ├── test_pipeline.py # Integration tests for pipeline
+│ ├── test_metrics.py # Unit tests for metrics
+│ └── test_energy.py # Unit tests for energy module
+├── configs/
+│ ├── default.yaml # Default experiment config
+│ ├── hotpotqa.yaml # HotpotQA-specific config
+│ ├── musique.yaml # MuSiQue-specific config
+│ └── 2wikimhqa.yaml # 2WikiMultiHopQA-specific config
+└── notebooks/
+ └── demo.ipynb # Quick demo notebook
+```
+
+---
+
+## Module Specifications
+
+### 1. `hag/config.py`
+
+Use Python dataclasses. All hyperparameters in one place.
+
+```python
+@dataclass
+class HopfieldConfig:
+ beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval
+ max_iter: int = 5 # Maximum Hopfield iteration steps
+ conv_threshold: float = 1e-4 # Convergence: stop if ||q_{t+1} - q_t|| < threshold
+ top_k: int = 5 # Number of passages to retrieve from final attention weights
+
+@dataclass
+class MemoryBankConfig:
+ embedding_dim: int = 768 # Must match encoder output dim
+ normalize: bool = True # L2-normalize embeddings in memory bank
+
+@dataclass
+class EncoderConfig:
+ model_name: str = "facebook/contriever-msmarco" # Default encoder
+ max_length: int = 512
+ batch_size: int = 64
+
+@dataclass
+class GeneratorConfig:
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: int = 128
+ temperature: float = 0.0 # Greedy decoding for reproducibility
+
+@dataclass
+class PipelineConfig:
+ hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
+ memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
+ encoder: EncoderConfig = field(default_factory=EncoderConfig)
+ generator: GeneratorConfig = field(default_factory=GeneratorConfig)
+ retriever_type: str = "hopfield" # "hopfield" or "faiss"
+```
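+
+The repo ships per-dataset YAML files under `configs/`. No loader is specified here, so the following is a hypothetical sketch of how the YAML could be mapped onto `PipelineConfig` (the YAML files also carry a `dataset` key that is not a `PipelineConfig` field, which is why each section is unpacked individually):
+
+```python
+import yaml
+
+def load_pipeline_config(path: str) -> PipelineConfig:
+    """Hypothetical helper: build a PipelineConfig from a YAML file."""
+    with open(path) as f:
+        raw = yaml.safe_load(f)
+    return PipelineConfig(
+        hopfield=HopfieldConfig(**raw.get("hopfield", {})),
+        memory=MemoryBankConfig(**raw.get("memory", {})),
+        encoder=EncoderConfig(**raw.get("encoder", {})),
+        generator=GeneratorConfig(**raw.get("generator", {})),
+        retriever_type=raw.get("retriever_type", "hopfield"),
+    )
+```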
+
+### 2. `hag/hopfield.py` — THE CORE MODULE
+
+This is the most important file. Implement the Modern Continuous Hopfield Network retrieval.
+
+```python
+class HopfieldRetrieval:
+ """
+ Modern Continuous Hopfield Network for memory retrieval.
+
+ Given memory bank M ∈ ℝ^{d × N} and query q ∈ ℝ^d:
+ 1. Compute attention: α = softmax(β * Mᵀ @ q)
+ 2. Update query: q_new = M @ α
+ 3. Repeat until convergence or max_iter
+
+ The energy function is:
+ E(q) = -1/β * log(Σ_i exp(β * q^T m_i)) + 1/2 * ||q||^2
+
+ Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
+ """
+
+ def __init__(self, config: HopfieldConfig):
+ ...
+
+ def retrieve(
+ self,
+ query: torch.Tensor, # (d,) or (batch, d)
+ memory: torch.Tensor, # (d, N) — the memory bank
+ return_trajectory: bool = False,
+ return_energy: bool = False,
+ ) -> HopfieldResult:
+ """
+ Run iterative Hopfield retrieval.
+
+ Returns HopfieldResult with:
+ - attention_weights: (N,) or (batch, N) — final α_T
+ - converged_query: (d,) or (batch, d) — final q_T
+ - num_steps: int — actual steps taken
+ - trajectory: optional list of q_t at each step
+ - energy_curve: optional list of E(q_t) at each step
+ """
+ ...
+
+ def compute_energy(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||^2
+
+ This is the log-sum-exp energy of the Modern Hopfield Network.
+ """
+ ...
+```
+
+**CRITICAL implementation details for `retrieve()`:**
+
+```python
+def retrieve(self, query, memory, return_trajectory=False, return_energy=False):
+ # Ensure correct shapes
+ # query: (batch, d) — unsqueeze if needed
+ # memory: (d, N)
+
+ q = query.clone()
+ trajectory = [q.clone()] if return_trajectory else None
+ energies = [self.compute_energy(q, memory)] if return_energy else None
+
+    num_steps = 0
+    for t in range(self.config.max_iter):
+        # Core Hopfield update
+        logits = self.config.beta * (q @ memory)  # (batch, N)
+        alpha = torch.softmax(logits, dim=-1)     # (batch, N)
+        q_new = alpha @ memory.T                  # (batch, d)
+
+ # Check convergence
+ delta = torch.norm(q_new - q, dim=-1).max()
+ q = q_new
+
+ if return_trajectory:
+ trajectory.append(q.clone())
+ if return_energy:
+ energies.append(self.compute_energy(q, memory))
+
+        num_steps = t + 1
+
+        if delta < self.config.conv_threshold:
+            break
+
+ # Final attention weights (recompute to ensure consistency)
+    logits = self.config.beta * (q @ memory)  # (batch, N)
+ alpha = torch.softmax(logits, dim=-1)
+
+ return HopfieldResult(
+ attention_weights=alpha,
+ converged_query=q,
+        num_steps=num_steps,
+ trajectory=trajectory,
+ energy_curve=energies,
+ )
+```
+
+**Pay special attention to:**
+- Numerical stability: use `torch.logsumexp` in energy computation
+- Batch support: all operations should work for both single query and batched queries
+- Memory efficiency: don't store trajectory/energy unless explicitly requested
+
+### 3. `hag/memory_bank.py`
+
+```python
+class MemoryBank:
+ """
+ Stores passage embeddings and provides lookup from indices back to text.
+
+ The memory bank is M ∈ ℝ^{d × N} where each column is a passage embedding.
+ Also maintains a mapping from column index → passage text for final retrieval.
+ """
+
+ def __init__(self, config: MemoryBankConfig):
+ self.embeddings: torch.Tensor = None # (d, N)
+ self.passages: List[str] = [] # N passages
+ self.config = config
+
+ def build_from_embeddings(self, embeddings: torch.Tensor, passages: List[str]):
+ """
+ embeddings: (N, d) — note: input is (N, d), internally stored as (d, N)
+ passages: list of N passage strings
+ """
+ if self.config.normalize:
+ embeddings = F.normalize(embeddings, dim=-1)
+ self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
+ self.passages = passages
+
+ def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+ """Given top-k indices, return corresponding passage texts."""
+ ...
+
+ def save(self, path: str): ...
+ def load(self, path: str): ...
+
+ @property
+ def size(self) -> int:
+ return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+ @property
+ def dim(self) -> int:
+ return self.embeddings.shape[0] if self.embeddings is not None else 0
+```
+
+### 4. `hag/retriever_faiss.py`
+
+Vanilla RAG baseline retriever using FAISS.
+
+```python
+class FAISSRetriever:
+ """
+ Standard top-k retrieval using FAISS inner product search.
+ This is the baseline to compare against.
+ """
+
+ def __init__(self, top_k: int = 5):
+ self.index = None
+ self.passages: List[str] = []
+ self.top_k = top_k
+
+ def build_index(self, embeddings: np.ndarray, passages: List[str]):
+ """Build FAISS IndexFlatIP from (N, d) embeddings."""
+ ...
+
+ def retrieve(self, query: np.ndarray) -> RetrievalResult:
+ """
+ query: (d,) or (batch, d)
+ Returns RetrievalResult with top_k passage texts, scores, and indices.
+ """
+ ...
+```
+
+### 5. `hag/retriever_hopfield.py`
+
+```python
+class HopfieldRetriever:
+ """
+ Wraps HopfieldRetrieval + MemoryBank into a retriever interface.
+
+ The bridge between Hopfield's continuous retrieval and the discrete
+ passage selection needed for LLM prompting.
+ """
+
+ def __init__(self, hopfield: HopfieldRetrieval, memory_bank: MemoryBank, top_k: int = 5):
+ ...
+
+ def retrieve(self, query_embedding: torch.Tensor, return_analysis: bool = False) -> RetrievalResult:
+ """
+ 1. Run Hopfield iterative retrieval → get attention weights α_T
+ 2. Take top_k indices from α_T
+ 3. Look up corresponding passage texts from memory bank
+ 4. Optionally return trajectory and energy for analysis
+
+ Returns RetrievalResult with:
+ - passages: List[str]
+ - scores: attention weights for top-k
+ - indices: which memory slots were selected
+ - (optional) hopfield_result: full HopfieldResult for analysis
+ """
+ ...
+```
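+
+The continuous-to-discrete bridge in steps 2 and 3 is a `topk` over the final attention weights. A sketch, assuming a single `(1, N)` attention row (`torch.topk` already returns scores in descending order):
+
+```python
+import torch
+
+def select_passages(alpha: torch.Tensor, memory_bank: MemoryBank, k: int):
+    """Sketch: map final attention weights (1, N) to k passage texts."""
+    scores, indices = torch.topk(alpha.squeeze(0), k)  # (k,), sorted descending
+    passages = memory_bank.get_passages_by_indices(indices)
+    return passages, scores, indices
+```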
+
+### 6. `hag/pipeline.py`
+
+```python
+class RAGPipeline:
+ """
+ End-to-end pipeline: query → encode → retrieve → generate.
+ Supports both FAISS (baseline) and Hopfield (ours) retrieval.
+ """
+
+    def __init__(self, config: PipelineConfig, memory_bank: MemoryBank):
+        self.encoder = Encoder(config.encoder)
+        self.generator = Generator(config.generator)
+
+        if config.retriever_type == "faiss":
+            self.retriever = FAISSRetriever(top_k=config.hopfield.top_k)
+        elif config.retriever_type == "hopfield":
+            hopfield = HopfieldRetrieval(config.hopfield)
+            self.retriever = HopfieldRetriever(hopfield, memory_bank, top_k=config.hopfield.top_k)
+
+ def run(self, question: str) -> PipelineResult:
+ """
+ Full pipeline:
+ 1. Encode question → query embedding
+ 2. Retrieve passages (FAISS or Hopfield)
+ 3. Format prompt with question + retrieved passages
+ 4. Generate answer with LLM
+ """
+ ...
+
+ def run_batch(self, questions: List[str]) -> List[PipelineResult]:
+ ...
+```
+
+### 7. `hag/metrics.py`
+
+```python
+def exact_match(prediction: str, ground_truth: str) -> float:
+ """Normalized exact match. Lowercase, strip, remove articles/punctuation."""
+ ...
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+ """Token-level F1 between prediction and ground truth."""
+ ...
+
+def retrieval_recall_at_k(retrieved_indices: List[int], gold_indices: List[int], k: int) -> float:
+ """What fraction of gold passages appear in the retrieved top-k?"""
+ ...
+
+def evaluate_dataset(results: List[PipelineResult], gold_answers: List[str]) -> Dict[str, float]:
+ """Compute aggregate metrics over a dataset. Returns dict with EM, F1, etc."""
+ ...
+```
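+
+Worked example for `f1_score("the cat sat on the mat", "the cat sat")`: normalization drops the article "the", leaving prediction tokens `[cat, sat, on, mat]` and gold tokens `[cat, sat]`. The overlap is 2, so precision = 2/4, recall = 2/2, and F1 = 2 * (0.5 * 1.0) / (0.5 + 1.0) = 2/3, consistent with the `0.5 < score < 1.0` assertion in the tests.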
+
+### 8. `hag/energy.py`
+
+```python
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+ """Extract energy values at each iteration step."""
+ ...
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+ """
+ ΔE = E(q_0) - E(q_T)
+ Larger gap = more refinement happened.
+ """
+ ...
+
+def verify_monotonic_decrease(energy_curve: List[float]) -> bool:
+ """Check that E(q_{t+1}) <= E(q_t) for all t. Should always be True."""
+ ...
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+ """
+ H(α) = -Σ_i α_i log α_i
+ Low entropy = sharp retrieval (confident).
+ High entropy = diffuse retrieval (uncertain).
+ """
+ ...
+```
+
+---
+
+## Data Types
+
+Define these as dataclasses or NamedTuples:
+
+```python
+@dataclass
+class HopfieldResult:
+ attention_weights: torch.Tensor # (batch, N) or (N,)
+ converged_query: torch.Tensor # (batch, d) or (d,)
+ num_steps: int
+ trajectory: Optional[List[torch.Tensor]] = None # list of q_t
+ energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t)
+
+@dataclass
+class RetrievalResult:
+ passages: List[str]
+ scores: torch.Tensor # top-k scores
+ indices: torch.Tensor # top-k indices
+ hopfield_result: Optional[HopfieldResult] = None
+
+@dataclass
+class PipelineResult:
+ question: str
+ answer: str
+ retrieved_passages: List[str]
+ retrieval_result: RetrievalResult
+```
+
+---
+
+## Test Specifications
+
+**ALL TESTS MUST RUN ON CPU. NO GPU REQUIRED.**
+
+Use small synthetic data (random tensors) for unit tests. Do not download any models or datasets in tests.
+
+### `tests/test_hopfield.py`
+
+```python
+class TestHopfieldRetrieval:
+ """Test the core Hopfield module with synthetic data on CPU."""
+
+ def setup_method(self):
+ """Create small synthetic memory bank and queries."""
+ torch.manual_seed(42)
+ self.d = 64 # embedding dim
+ self.N = 100 # number of memories
+ self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N)
+ self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d)
+ self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
+ self.hopfield = HopfieldRetrieval(self.config)
+
+ def test_output_shapes(self):
+ """attention_weights should be (1, N), converged_query should be (1, d)."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert result.attention_weights.shape == (1, self.N)
+ assert result.converged_query.shape == (1, self.d)
+
+ def test_attention_weights_sum_to_one(self):
+ """softmax output must sum to 1."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert torch.allclose(result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5)
+
+ def test_attention_weights_non_negative(self):
+ """All attention weights must be >= 0."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert (result.attention_weights >= 0).all()
+
+ def test_energy_monotonic_decrease(self):
+ """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
+ result = self.hopfield.retrieve(self.query, self.memory, return_energy=True)
+ energies = [e.item() for e in result.energy_curve]
+ for i in range(len(energies) - 1):
+ assert energies[i+1] <= energies[i] + 1e-6, \
+ f"Energy increased at step {i}: {energies[i]} -> {energies[i+1]}"
+
+ def test_convergence(self):
+ """With enough iterations, query should converge (delta < threshold)."""
+ config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(self.query, self.memory, return_trajectory=True)
+ # Final two queries should be very close
+ q_last = result.trajectory[-1]
+ q_prev = result.trajectory[-2]
+ delta = torch.norm(q_last - q_prev)
+ assert delta < 1e-4, f"Did not converge: delta={delta}"
+
+ def test_high_beta_sharp_retrieval(self):
+ """Higher β should produce sharper (lower entropy) attention."""
+ low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
+ high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))
+
+ result_low = low_beta.retrieve(self.query, self.memory)
+ result_high = high_beta.retrieve(self.query, self.memory)
+
+ entropy_low = -(result_low.attention_weights * result_low.attention_weights.log()).sum()
+ entropy_high = -(result_high.attention_weights * result_high.attention_weights.log()).sum()
+
+ assert entropy_high < entropy_low, "Higher β should give lower entropy"
+
+ def test_single_memory_converges_to_it(self):
+ """With N=1, retrieval should converge to the single memory."""
+ single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
+ result = self.hopfield.retrieve(self.query, single_memory)
+ assert torch.allclose(result.attention_weights, torch.ones(1, 1), atol=1e-5)
+
+ def test_query_near_memory_retrieves_it(self):
+ """If query ≈ memory_i, attention should peak at index i."""
+ target_idx = 42
+ query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match
+ config = HopfieldConfig(beta=10.0, max_iter=5)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(query, self.memory)
+ top_idx = result.attention_weights.argmax(dim=-1).item()
+ assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"
+
+ def test_batch_retrieval(self):
+ """Should handle batch of queries."""
+ batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
+ result = self.hopfield.retrieve(batch_query, self.memory)
+ assert result.attention_weights.shape == (8, self.N)
+ assert result.converged_query.shape == (8, self.d)
+
+ def test_iteration_refines_query(self):
+ """
+ Multi-hop test: query starts far from target, iteration should bring it closer.
+
+ Setup: memory has two clusters. Query is near cluster A but the "answer"
+ is in cluster B, reachable through an intermediate memory that bridges both.
+ After iteration, the query should drift toward cluster B.
+ """
+ torch.manual_seed(0)
+ d = 32
+
+ # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
+ # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
+ # Bridge memory 10: between A and B
+ center_a = torch.zeros(d); center_a[0] = 1.0
+ center_b = torch.zeros(d); center_b[1] = 1.0
+ bridge = F.normalize(center_a + center_b, dim=0)
+
+ memories = []
+ for _ in range(5):
+ memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
+ for _ in range(5):
+ memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
+ memories.append(bridge)
+
+ M = torch.stack(memories, dim=1) # (d, 11)
+ q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)
+
+ config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(q0, M, return_trajectory=True)
+
+ # After iteration, query should have drifted: its dot product with center_b
+ # should be higher than the initial query's dot product with center_b
+ initial_sim_b = (q0.squeeze() @ center_b).item()
+ final_sim_b = (result.converged_query.squeeze() @ center_b).item()
+ assert final_sim_b > initial_sim_b, \
+ f"Iteration should pull query toward cluster B: initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
+```
+
+### `tests/test_memory_bank.py`
+
+```python
+class TestMemoryBank:
+
+ def test_build_and_size(self):
+ """Build memory bank and verify size."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(100, 64)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(embeddings, passages)
+ assert mb.size == 100
+ assert mb.dim == 64
+
+ def test_normalization(self):
+ """If normalize=True, stored embeddings should have unit norm."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(50, 64) * 5 # non-unit norm
+ mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
+ norms = torch.norm(mb.embeddings, dim=0)
+ assert torch.allclose(norms, torch.ones(50), atol=1e-5)
+
+ def test_get_passages_by_indices(self):
+ """Index → passage text lookup."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=False)
+ mb = MemoryBank(config)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(torch.randn(100, 64), passages)
+ result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
+ assert result == ["passage 0", "passage 50", "passage 99"]
+
+ def test_save_and_load(self, tmp_path):
+ """Save and reload memory bank, verify contents match."""
+ config = MemoryBankConfig(embedding_dim=32, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(20, 32)
+ passages = [f"text {i}" for i in range(20)]
+ mb.build_from_embeddings(embeddings, passages)
+
+ save_path = str(tmp_path / "mb.pt")
+ mb.save(save_path)
+
+ mb2 = MemoryBank(config)
+ mb2.load(save_path)
+ assert mb2.size == 20
+ assert mb2.passages == passages
+ assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)
+```
+
+### `tests/test_retriever.py`
+
+```python
+class TestHopfieldRetriever:
+
+ def test_retrieves_correct_number_of_passages(self):
+ """top_k=3 should return exactly 3 passages."""
+ ...
+
+ def test_scores_are_sorted_descending(self):
+ """Returned scores should be in descending order."""
+ ...
+
+ def test_passages_match_indices(self):
+ """Returned passage texts should correspond to returned indices."""
+ ...
+
+class TestFAISSRetriever:
+
+ def test_retrieves_correct_number(self):
+ ...
+
+ def test_nearest_neighbor_is_self(self):
+ """If query = passage_i's embedding, top-1 should be passage_i."""
+ ...
+
+class TestRetrieverComparison:
+
+ def test_same_top1_for_obvious_query(self):
+ """
+ When query is very close to one memory, both FAISS and Hopfield
+ should agree on the top-1 result.
+ """
+ ...
+```
+
+### `tests/test_metrics.py`
+
+```python
+class TestMetrics:
+
+ def test_exact_match_basic(self):
+ assert exact_match("Paris", "paris") == 1.0
+ assert exact_match("Paris", "London") == 0.0
+
+ def test_exact_match_normalization(self):
+ """Should strip whitespace, lowercase, remove articles."""
+ assert exact_match(" The Paris ", "paris") == 1.0
+ assert exact_match("A dog", "dog") == 1.0
+
+ def test_f1_score_perfect(self):
+ assert f1_score("the cat sat", "the cat sat") == 1.0
+
+ def test_f1_score_partial(self):
+ score = f1_score("the cat sat on the mat", "the cat sat")
+ assert 0.5 < score < 1.0
+
+ def test_f1_score_no_overlap(self):
+ assert f1_score("hello world", "foo bar") == 0.0
+
+ def test_retrieval_recall(self):
+ assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2/3
+```
+
+### `tests/test_energy.py`
+
+```python
+class TestEnergy:
+
+ def test_energy_computation_matches_formula(self):
+ """
+ Manually compute E(q) = -1/β * log(Σ exp(β * qᵀ m_i)) + 1/2 * ||q||^2
+ and verify it matches the module output.
+ """
+ torch.manual_seed(42)
+ d, N = 16, 10
+ memory = torch.randn(d, N)
+ query = torch.randn(1, d)
+ beta = 2.0
+
+ config = HopfieldConfig(beta=beta)
+ hopfield = HopfieldRetrieval(config)
+ computed = hopfield.compute_energy(query, memory)
+
+ # Manual computation
+ logits = beta * (memory.T @ query.T).T # (1, N)
+ expected = -1.0/beta * torch.logsumexp(logits, dim=-1) + 0.5 * (query ** 2).sum(dim=-1)
+ assert torch.allclose(computed, expected, atol=1e-5)
+
+ def test_verify_monotonic(self):
+ energies = [10.0, 8.0, 6.5, 6.4, 6.39]
+ assert verify_monotonic_decrease(energies) == True
+
+ def test_verify_monotonic_fails(self):
+ energies = [10.0, 8.0, 9.0, 6.0]
+ assert verify_monotonic_decrease(energies) == False
+
+ def test_attention_entropy(self):
+ """Uniform attention should have max entropy, one-hot should have zero."""
+ uniform = torch.ones(100) / 100
+ one_hot = torch.zeros(100); one_hot[0] = 1.0
+ assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot)
+ assert compute_attention_entropy(one_hot) < 0.01
+```
+
+### `tests/test_pipeline.py`
+
+```python
+class TestPipeline:
+ """
+ Integration tests using MOCK encoder and generator.
+ These test the wiring, not the model quality.
+ """
+
+ def test_pipeline_runs_end_to_end_hopfield(self):
+ """Mock encoder + mock LLM + Hopfield retriever → produces an answer string."""
+ ...
+
+ def test_pipeline_runs_end_to_end_faiss(self):
+ """Same but with FAISS retriever."""
+ ...
+
+ def test_pipeline_retrieval_result_included(self):
+ """PipelineResult should contain retrieval metadata."""
+ ...
+```
+
+**For pipeline tests, use mock/fake encoder and generator:**
+```python
+class FakeEncoder:
+ def encode(self, text: str) -> torch.Tensor:
+ # Deterministic hash-based embedding for testing
+ torch.manual_seed(hash(text) % 2**32)
+ return F.normalize(torch.randn(1, 64), dim=-1)
+
+class FakeGenerator:
+ def generate(self, prompt: str) -> str:
+ return "mock answer"
+```
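+
+Note: Python salts `str` hashing per process (`PYTHONHASHSEED`), so `FakeEncoder` is deterministic within a single test run but not across runs. That is sufficient here, since each test re-encodes its own inputs in one process.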
+
+---
+
+## Dependencies
+
+```toml
+[project]
+name = "hag"
+version = "0.1.0"
+requires-python = ">=3.10"
+
+dependencies = [
+ "torch>=2.0",
+ "numpy>=1.24",
+ "faiss-cpu>=1.7",
+ "transformers>=4.36",
+ "datasets>=2.14",
+ "pyyaml>=6.0",
+ "tqdm>=4.65",
+]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=7.0",
+ "pytest-cov>=4.0",
+]
+eval = [
+ "scikit-learn>=1.3",
+ "umap-learn>=0.5",
+ "matplotlib>=3.7",
+]
+```
+
+---
+
+## Implementation Order
+
+**Phase 1: Core module (do this first)**
+1. `config.py` — all dataclasses
+2. `hopfield.py` — the core Hopfield retrieval
+3. `energy.py` — energy utilities
+4. `tests/test_hopfield.py` — run and pass all tests
+5. `tests/test_energy.py` — run and pass all tests
+
+**Phase 2: Memory and retrievers**
+6. `memory_bank.py`
+7. `retriever_hopfield.py`
+8. `retriever_faiss.py`
+9. `tests/test_memory_bank.py`
+10. `tests/test_retriever.py`
+
+**Phase 3: Metrics and pipeline**
+11. `metrics.py`
+12. `pipeline.py` (with mock encoder/generator for now)
+13. `tests/test_metrics.py`
+14. `tests/test_pipeline.py`
+
+**Phase 4: Scripts (need GPU/models — implement structure only)**
+15. `scripts/build_memory_bank.py`
+16. `scripts/run_eval.py`
+17. `scripts/run_baseline.py`
+18. `scripts/run_hag.py`
+
+---
+
+## Coding Conventions
+
+- **Type hints everywhere.** Every function signature must have full type hints.
+- **Docstrings on every public method.** Include shapes in docstrings for tensor arguments.
+- **Tensor shapes in comments.** Any line with a tensor operation should have a shape comment: `# (batch, N)`
+- **No global state.** Everything parameterized through config dataclasses.
+- **CPU-first.** Default device is CPU. GPU is opt-in via `.to(device)`.
+- **Deterministic tests.** Always set `torch.manual_seed()` in tests.
+- **Use `torch.no_grad()` in retrieval** when not training (phase 1).
+- Use `logging` module, not print statements.
+
+---
+
+## Key Mathematical References
+
+The Modern Continuous Hopfield Network update rule and energy function come from:
+
+> Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Key equations:
+- **Update**: `q_{t+1} = M · softmax(β · Mᵀ · qₜ)`
+- **Energy**: `E(q) = -1/β · log(Σᵢ exp(β · qᵀmᵢ)) + 1/2 · ||q||²`
+- **Property**: Energy decreases monotonically: `E(q_{t+1}) ≤ E(q_t)`
+- **Property**: Fixed points are local minima of E
+- **Capacity**: Exponential in d (unlike classical Hopfield's ~0.14N)
+
+The connection to attention: the Hopfield update is mathematically equivalent to one step of cross-attention where queries attend to stored memories. But the ITERATIVE application (multiple steps) and the energy-based analysis are what differentiate this from standard attention.
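+
+To make the equivalence concrete, here is one Hopfield update written as one standard attention step (a sketch; Q = q, K = V = Mᵀ, and the softmax scale is β rather than 1/√d):
+
+```python
+import torch
+
+d, N, beta = 64, 100, 2.0
+M = torch.randn(d, N)   # stored memories as columns
+q = torch.randn(1, d)   # query
+
+Q, K, V = q, M.T, M.T
+attn_out = torch.softmax(beta * (Q @ K.T), dim=-1) @ V      # one attention step
+hopfield_out = torch.softmax(beta * (q @ M), dim=-1) @ M.T  # one Hopfield update
+assert torch.allclose(attn_out, hopfield_out)               # identical outputs
+```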
+
+---
+
+## What NOT To Do
+
+- Do NOT use `torch.cuda` in any test
+- Do NOT download models/datasets in tests — use synthetic data and mocks
+- Do NOT implement the encoder or generator with real models yet — provide the interface and use mocks in tests
+- Do NOT add wandb, tensorboard, or other logging frameworks yet
+- Do NOT try to optimize for speed yet — correctness first
+- Do NOT implement the training loop (memory bank training) — that is phase 2 of the research and out of scope for this codebase
diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7bd7a72
--- /dev/null
+++ b/README.md
@@ -0,0 +1,40 @@
+# HAG: Hopfield-Augmented Generation
+
+A retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**.
+
+## Quick Start
+
+```bash
+pip install -e ".[dev]"
+pytest
+```
+
+## Core Idea
+
+Standard RAG retrieves passages once via FAISS similarity. HAG iteratively refines the query in embedding space using Modern Continuous Hopfield Network dynamics:
+
+```
+q_{t+1} = M * softmax(beta * M^T * q_t)
+```
+
+This enables implicit multi-hop reasoning through energy minimization with convergence guarantees.
+
+## Project Structure
+
+- `hag/hopfield.py` — Core Hopfield retrieval module
+- `hag/memory_bank.py` — Passage embedding storage
+- `hag/retriever_hopfield.py` — Hopfield-based retriever
+- `hag/retriever_faiss.py` — FAISS baseline retriever
+- `hag/pipeline.py` — End-to-end RAG/HAG pipeline
+- `hag/metrics.py` — Evaluation metrics (EM, F1, retrieval recall)
+- `hag/energy.py` — Energy analysis utilities
+- `scripts/` — Evaluation and analysis scripts
+- `tests/` — Comprehensive test suite
+
+## Running Tests
+
+```bash
+pytest tests/ -v
+```
+
+All tests run on CPU with synthetic data — no GPU or model downloads required.
diff --git a/TBD.md b/TBD.md
new file mode 100644
index 0000000..93b2771
--- /dev/null
+++ b/TBD.md
@@ -0,0 +1,23 @@
+# TBD — Remaining work once GPU-ready
+
+## 1. Add GPU support to the Encoder
+- Add a `device` parameter to the `Encoder` class in `hag/encoder.py`
+- Call `.to(device)` in `_load_model()` and move input tensors to the GPU at inference time
+- Move batch-encoded results back with `.cpu()` before returning
+
+## 2. Data preparation script
+- Create `scripts/prepare_hotpotqa.py`
+- Download HotpotQA from HuggingFace and extract all context passages
+- Deduplicate and store them in a format the memory bank can consume (JSONL, or build memory_bank.pt directly)
+- Also export questions + gold answers + gold passage indices for evaluation
+
+## 3. Comparison evaluation script
+- Create `scripts/run_comparison.py`
+- Run both the FAISS and Hopfield retrievers against the same memory bank
+- Generate answers with the same LLM
+- Output a comparison table: EM / F1 / Retrieval Recall@k
+- Support sweeps over different beta and max_iter values
+
+## 4. Add the FAISS path to run_eval.py
+- `run_eval.py` currently only accepts a memory_bank; FAISS mode needs to additionally build a FAISS index from the memory_bank
+- Add the automatic conversion logic
diff --git a/configs/2wikimhqa.yaml b/configs/2wikimhqa.yaml
new file mode 100644
index 0000000..0048cf2
--- /dev/null
+++ b/configs/2wikimhqa.yaml
@@ -0,0 +1,22 @@
+hopfield:
+ beta: 2.0
+ max_iter: 8
+ conv_threshold: 1.0e-4
+ top_k: 5
+
+memory:
+ embedding_dim: 768
+ normalize: true
+
+encoder:
+ model_name: "facebook/contriever-msmarco"
+ max_length: 512
+ batch_size: 64
+
+generator:
+ model_name: "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: 128
+ temperature: 0.0
+
+retriever_type: "hopfield"
+dataset: "2wikimultihopqa"
diff --git a/configs/default.yaml b/configs/default.yaml
new file mode 100644
index 0000000..cb76d34
--- /dev/null
+++ b/configs/default.yaml
@@ -0,0 +1,21 @@
+hopfield:
+ beta: 1.0
+ max_iter: 5
+ conv_threshold: 1.0e-4
+ top_k: 5
+
+memory:
+ embedding_dim: 768
+ normalize: true
+
+encoder:
+ model_name: "facebook/contriever-msmarco"
+ max_length: 512
+ batch_size: 64
+
+generator:
+ model_name: "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: 128
+ temperature: 0.0
+
+retriever_type: "hopfield"
diff --git a/configs/hotpotqa.yaml b/configs/hotpotqa.yaml
new file mode 100644
index 0000000..99ea9d7
--- /dev/null
+++ b/configs/hotpotqa.yaml
@@ -0,0 +1,22 @@
+hopfield:
+ beta: 2.0
+ max_iter: 8
+ conv_threshold: 1.0e-4
+ top_k: 5
+
+memory:
+ embedding_dim: 768
+ normalize: true
+
+encoder:
+ model_name: "facebook/contriever-msmarco"
+ max_length: 512
+ batch_size: 64
+
+generator:
+ model_name: "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: 128
+ temperature: 0.0
+
+retriever_type: "hopfield"
+dataset: "hotpotqa"
diff --git a/configs/musique.yaml b/configs/musique.yaml
new file mode 100644
index 0000000..52dd7c4
--- /dev/null
+++ b/configs/musique.yaml
@@ -0,0 +1,22 @@
+hopfield:
+ beta: 2.0
+ max_iter: 8
+ conv_threshold: 1.0e-4
+ top_k: 5
+
+memory:
+ embedding_dim: 768
+ normalize: true
+
+encoder:
+ model_name: "facebook/contriever-msmarco"
+ max_length: 512
+ batch_size: 64
+
+generator:
+ model_name: "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: 128
+ temperature: 0.0
+
+retriever_type: "hopfield"
+dataset: "musique"
diff --git a/hag/__init__.py b/hag/__init__.py
new file mode 100644
index 0000000..18496e9
--- /dev/null
+++ b/hag/__init__.py
@@ -0,0 +1,21 @@
+"""HAG: Hopfield-Augmented Generation."""
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.datatypes import HopfieldResult, PipelineResult, RetrievalResult
+
+__all__ = [
+ "HopfieldConfig",
+ "MemoryBankConfig",
+ "EncoderConfig",
+ "GeneratorConfig",
+ "PipelineConfig",
+ "HopfieldResult",
+ "RetrievalResult",
+ "PipelineResult",
+]
diff --git a/hag/config.py b/hag/config.py
new file mode 100644
index 0000000..793e3a6
--- /dev/null
+++ b/hag/config.py
@@ -0,0 +1,50 @@
+"""All hyperparameters and configuration dataclasses for HAG."""
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class HopfieldConfig:
+ """Configuration for the Hopfield retrieval module."""
+
+ beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval
+ max_iter: int = 5 # Maximum Hopfield iteration steps
+ conv_threshold: float = 1e-4 # Stop if ||q_{t+1} - q_t|| < threshold
+ top_k: int = 5 # Number of passages to retrieve from final attention weights
+
+
+@dataclass
+class MemoryBankConfig:
+ """Configuration for the memory bank."""
+
+ embedding_dim: int = 768 # Must match encoder output dim
+ normalize: bool = True # L2-normalize embeddings in memory bank
+
+
+@dataclass
+class EncoderConfig:
+ """Configuration for the query/passage encoder."""
+
+ model_name: str = "facebook/contriever-msmarco"
+ max_length: int = 512
+ batch_size: int = 64
+
+
+@dataclass
+class GeneratorConfig:
+ """Configuration for the LLM generator."""
+
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: int = 128
+ temperature: float = 0.0 # Greedy decoding for reproducibility
+
+
+@dataclass
+class PipelineConfig:
+ """Top-level pipeline configuration."""
+
+ hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
+ memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
+ encoder: EncoderConfig = field(default_factory=EncoderConfig)
+ generator: GeneratorConfig = field(default_factory=GeneratorConfig)
+ retriever_type: str = "hopfield" # "hopfield" or "faiss"
diff --git a/hag/datatypes.py b/hag/datatypes.py
new file mode 100644
index 0000000..0f4254d
--- /dev/null
+++ b/hag/datatypes.py
@@ -0,0 +1,37 @@
+"""Data types used across HAG modules."""
+
+from dataclasses import dataclass, field
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class HopfieldResult:
+ """Result from Hopfield iterative retrieval."""
+
+ attention_weights: torch.Tensor # (batch, N) or (N,)
+ converged_query: torch.Tensor # (batch, d) or (d,)
+ num_steps: int
+ trajectory: Optional[List[torch.Tensor]] = None # list of q_t
+ energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t)
+
+
+@dataclass
+class RetrievalResult:
+ """Result from a retriever (FAISS or Hopfield)."""
+
+ passages: List[str]
+ scores: torch.Tensor # top-k scores
+ indices: torch.Tensor # top-k indices
+ hopfield_result: Optional[HopfieldResult] = None
+
+
+@dataclass
+class PipelineResult:
+ """Result from the full RAG/HAG pipeline."""
+
+ question: str
+ answer: str
+ retrieved_passages: List[str]
+ retrieval_result: RetrievalResult
diff --git a/hag/encoder.py b/hag/encoder.py
new file mode 100644
index 0000000..7e103f3
--- /dev/null
+++ b/hag/encoder.py
@@ -0,0 +1,88 @@
+"""Wrapper for encoding queries and passages into embeddings."""
+
+import logging
+from typing import List, Union
+
+import torch
+
+from hag.config import EncoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Encoder:
+ """Encodes text queries/passages into dense embeddings.
+
+ Uses a HuggingFace transformer model (e.g., Contriever).
+ For testing, use FakeEncoder instead.
+ """
+
+ def __init__(self, config: EncoderConfig) -> None:
+ self.config = config
+ self._tokenizer = None
+ self._model = None
+
+ def _load_model(self) -> None:
+ """Lazy-load the model and tokenizer."""
+ from transformers import AutoModel, AutoTokenizer
+
+ logger.info("Loading encoder model: %s", self.config.model_name)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+ self._model = AutoModel.from_pretrained(self.config.model_name)
+ self._model.eval()
+
+ @torch.no_grad()
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+ """Encode text(s) into embedding(s).
+
+ Args:
+ texts: single string or list of strings
+
+ Returns:
+ (1, d) tensor for single input, (N, d) for list input.
+ """
+ if self._model is None:
+ self._load_model()
+
+ if isinstance(texts, str):
+ texts = [texts]
+
+ inputs = self._tokenizer(
+ texts,
+ max_length=self.config.max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ )
+ outputs = self._model(**inputs)
+ # Mean pooling over token embeddings
+ embeddings = outputs.last_hidden_state.mean(dim=1) # (N, d)
+ return embeddings
+
+
+class FakeEncoder:
+ """Deterministic hash-based encoder for testing. No model download needed."""
+
+ def __init__(self, dim: int = 64) -> None:
+ self.dim = dim
+
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+ """Produce deterministic embeddings based on text hash.
+
+ Args:
+ texts: single string or list of strings
+
+ Returns:
+ (1, d) or (N, d) normalized tensor.
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+ embeddings = []
+ for text in texts:
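+            # NOTE: Python salts str hashing per process (PYTHONHASHSEED), so
+            # these embeddings are deterministic within a run, not across runs.
+            # Reseeding the global RNG here is fine for single-process tests.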
+ torch.manual_seed(hash(text) % 2**32)
+ emb = torch.randn(1, self.dim)
+ embeddings.append(emb)
+
+ result = torch.cat(embeddings, dim=0) # (N, d)
+ return torch.nn.functional.normalize(result, dim=-1)
diff --git a/hag/energy.py b/hag/energy.py
new file mode 100644
index 0000000..62a39e9
--- /dev/null
+++ b/hag/energy.py
@@ -0,0 +1,83 @@
+"""Energy computation and analysis utilities for Hopfield retrieval."""
+
+import logging
+from typing import List
+
+import torch
+
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+ """Extract energy values at each iteration step.
+
+ Args:
+ hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True
+
+ Returns:
+ List of energy values (floats) at each step.
+ """
+ if hopfield_result.energy_curve is None:
+ return []
+ return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]
+
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+ """Compute the energy gap: Delta_E = E(q_0) - E(q_T).
+
+ Larger gap means more refinement happened during iteration.
+
+ Args:
+ energy_curve: list of energy values at each step
+
+ Returns:
+ Energy gap (float). Positive if energy decreased.
+ """
+ if len(energy_curve) < 2:
+ return 0.0
+ return energy_curve[0] - energy_curve[-1]
+
+
+def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
+ """Check that E(q_{t+1}) <= E(q_t) for all t.
+
+ This should always be True for the Modern Hopfield Network.
+
+ Args:
+ energy_curve: list of energy values at each step
+ tol: numerical tolerance for comparison
+
+ Returns:
+ True if energy decreases monotonically (within tolerance).
+ """
+ for i in range(len(energy_curve) - 1):
+ if energy_curve[i + 1] > energy_curve[i] + tol:
+ return False
+ return True
+
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+ """Compute the entropy of attention weights.
+
+ H(alpha) = -sum_i alpha_i * log(alpha_i)
+
+ Low entropy = sharp retrieval (confident).
+ High entropy = diffuse retrieval (uncertain).
+
+ Args:
+ attention_weights: (N,) or (batch, N) — attention distribution
+
+ Returns:
+ Entropy value (float). Averaged over batch if batched.
+ """
+ if attention_weights.dim() == 1:
+ attention_weights = attention_weights.unsqueeze(0) # (1, N)
+
+ # Clamp to avoid log(0)
+ eps = 1e-12
+ alpha = attention_weights.clamp(min=eps)
+ entropy = -(alpha * alpha.log()).sum(dim=-1) # (batch,)
+
+ return entropy.mean().item()
diff --git a/hag/generator.py b/hag/generator.py
new file mode 100644
index 0000000..2142e0c
--- /dev/null
+++ b/hag/generator.py
@@ -0,0 +1,87 @@
+"""LLM generation wrapper for producing answers from retrieved context."""
+
+import logging
+from typing import List
+
+from hag.config import GeneratorConfig
+
+logger = logging.getLogger(__name__)
+
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+
+Context:
+{context}
+
+Question: {question}
+
+Answer:"""
+
+
+class Generator:
+ """LLM-based answer generator.
+
+ Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct).
+ For testing, use FakeGenerator instead.
+ """
+
+ def __init__(self, config: GeneratorConfig) -> None:
+ self.config = config
+ self._tokenizer = None
+ self._model = None
+
+ def _load_model(self) -> None:
+ """Lazy-load the model and tokenizer."""
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ logger.info("Loading generator model: %s", self.config.model_name)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+ self._model = AutoModelForCausalLM.from_pretrained(
+ self.config.model_name,
+ torch_dtype="auto",
+ )
+ self._model.eval()
+
+ def generate(self, question: str, passages: List[str]) -> str:
+ """Generate an answer given a question and retrieved passages.
+
+ Args:
+ question: the user question
+ passages: list of retrieved passage texts
+
+ Returns:
+ Generated answer string.
+ """
+ if self._model is None:
+ self._load_model()
+
+ context = "\n\n".join(
+ f"[{i+1}] {p}" for i, p in enumerate(passages)
+ )
+ prompt = PROMPT_TEMPLATE.format(context=context, question=question)
+
+ inputs = self._tokenizer(prompt, return_tensors="pt")
+ outputs = self._model.generate(
+ **inputs,
+ max_new_tokens=self.config.max_new_tokens,
+ temperature=self.config.temperature if self.config.temperature > 0 else None,
+ do_sample=self.config.temperature > 0,
+ )
+ # Decode only the generated tokens (skip the prompt)
+ generated = outputs[0][inputs["input_ids"].shape[1]:]
+ return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+
+class FakeGenerator:
+ """Deterministic mock generator for testing. No model download needed."""
+
+ def generate(self, question: str, passages: List[str]) -> str:
+ """Return a mock answer.
+
+ Args:
+ question: the user question
+ passages: list of retrieved passages
+
+ Returns:
+ Mock answer string.
+ """
+ return "mock answer"
diff --git a/hag/hopfield.py b/hag/hopfield.py
new file mode 100644
index 0000000..287e4af
--- /dev/null
+++ b/hag/hopfield.py
@@ -0,0 +1,124 @@
+"""Core Modern Continuous Hopfield Network retrieval module.
+
+Implements the iterative retrieval dynamics from:
+ Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
+Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+"""
+
+import logging
+from typing import Optional
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetrieval:
+ """Modern Continuous Hopfield Network for memory retrieval.
+
+ Given memory bank M in R^{d x N} and query q in R^d:
+ 1. Compute attention: alpha = softmax(beta * M^T @ q)
+ 2. Update query: q_new = M @ alpha
+ 3. Repeat until convergence or max_iter
+
+ The energy function is:
+ E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+ Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
+ """
+
+ def __init__(self, config: HopfieldConfig) -> None:
+ self.config = config
+
+ @torch.no_grad()
+ def retrieve(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ return_trajectory: bool = False,
+ return_energy: bool = False,
+ ) -> HopfieldResult:
+ """Run iterative Hopfield retrieval.
+
+ Args:
+ query: (d,) or (batch, d) — query embedding(s)
+ memory: (d, N) — memory bank of passage embeddings
+ return_trajectory: if True, store q_t at each step
+ return_energy: if True, store E(q_t) at each step
+
+ Returns:
+ HopfieldResult with attention_weights, converged_query, num_steps,
+ and optionally trajectory and energy_curve.
+ """
+ # Ensure query is 2D: (batch, d)
+ if query.dim() == 1:
+ query = query.unsqueeze(0) # (1, d)
+
+ q = query.clone() # (batch, d)
+
+ trajectory = [q.clone()] if return_trajectory else None
+ energies = [self.compute_energy(q, memory)] if return_energy else None
+
+ num_steps = 0
+ for t in range(self.config.max_iter):
+ # Core Hopfield update
+ logits = self.config.beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+ q_new = alpha @ memory.T # (batch, d)
+
+ # Check convergence
+ delta = torch.norm(q_new - q, dim=-1).max() # scalar
+ q = q_new
+
+ if return_trajectory:
+ trajectory.append(q.clone())
+ if return_energy:
+ energies.append(self.compute_energy(q, memory))
+
+ num_steps = t + 1
+
+ if delta < self.config.conv_threshold:
+ break
+
+ # Final attention weights (recompute to ensure consistency)
+ logits = self.config.beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+
+ return HopfieldResult(
+ attention_weights=alpha,
+ converged_query=q,
+ num_steps=num_steps,
+ trajectory=trajectory,
+ energy_curve=energies,
+ )
+
+ def compute_energy(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute the Hopfield energy function.
+
+ E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+ Args:
+ query: (batch, d) or (d,) — query embedding(s)
+ memory: (d, N) — memory bank
+
+ Returns:
+ Energy scalar or (batch,) tensor.
+ """
+ if query.dim() == 1:
+ query = query.unsqueeze(0) # (1, d)
+
+ logits = self.config.beta * (query @ memory) # (batch, N)
+ lse = torch.logsumexp(logits, dim=-1) # (batch,)
+ norm_sq = 0.5 * (query**2).sum(dim=-1) # (batch,)
+ energy = -1.0 / self.config.beta * lse + norm_sq # (batch,)
+
+ return energy
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
new file mode 100644
index 0000000..42dcc73
--- /dev/null
+++ b/hag/memory_bank.py
@@ -0,0 +1,93 @@
+"""Memory bank construction and management for passage embeddings."""
+
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import MemoryBankConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryBank:
+ """Stores passage embeddings and provides lookup from indices back to text.
+
+ The memory bank is M in R^{d x N} where each column is a passage embedding.
+ Also maintains a mapping from column index to passage text for final retrieval.
+ """
+
+ def __init__(self, config: MemoryBankConfig) -> None:
+ self.config = config
+ self.embeddings: Optional[torch.Tensor] = None # (d, N)
+ self.passages: List[str] = []
+
+ def build_from_embeddings(
+ self, embeddings: torch.Tensor, passages: List[str]
+ ) -> None:
+ """Build memory bank from precomputed embeddings.
+
+ Args:
+ embeddings: (N, d) — passage embeddings (note: input is N x d)
+ passages: list of N passage strings
+ """
+ assert embeddings.shape[0] == len(passages), (
+ f"Number of embeddings ({embeddings.shape[0]}) must match "
+ f"number of passages ({len(passages)})"
+ )
+ if self.config.normalize:
+ embeddings = F.normalize(embeddings, dim=-1)
+ self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
+ self.passages = list(passages)
+ logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+
+ def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+ """Given top-k indices, return corresponding passage texts.
+
+ Args:
+ indices: (k,) or (batch, k) tensor of integer indices
+
+ Returns:
+ List of passage strings.
+ """
+ flat_indices = indices.flatten().tolist()
+ return [self.passages[i] for i in flat_indices]
+
+ def save(self, path: str) -> None:
+ """Save memory bank to disk.
+
+ Args:
+ path: file path for saving (e.g., 'memory_bank.pt')
+ """
+ data: Dict = {
+ "embeddings": self.embeddings,
+ "passages": self.passages,
+ "config": {
+ "embedding_dim": self.config.embedding_dim,
+ "normalize": self.config.normalize,
+ },
+ }
+ torch.save(data, path)
+ logger.info("Saved memory bank to %s", path)
+
+ def load(self, path: str) -> None:
+ """Load memory bank from disk.
+
+ Args:
+ path: file path to load from
+ """
+ data = torch.load(path, weights_only=False)
+ self.embeddings = data["embeddings"]
+ self.passages = data["passages"]
+ logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+
+ @property
+ def size(self) -> int:
+ """Number of passages in the memory bank."""
+ return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+ @property
+ def dim(self) -> int:
+ """Embedding dimensionality."""
+ return self.embeddings.shape[0] if self.embeddings is not None else 0
diff --git a/hag/metrics.py b/hag/metrics.py
new file mode 100644
index 0000000..6a196df
--- /dev/null
+++ b/hag/metrics.py
@@ -0,0 +1,113 @@
+"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""
+
+import logging
+import re
+import string
+from collections import Counter
+from typing import Dict, List
+
+from hag.datatypes import PipelineResult
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_answer(text: str) -> str:
+ """Normalize answer text: lowercase, strip, remove articles and punctuation."""
+ text = text.lower().strip()
+ # Remove articles
+ text = re.sub(r"\b(a|an|the)\b", " ", text)
+ # Remove punctuation
+ text = text.translate(str.maketrans("", "", string.punctuation))
+ # Collapse whitespace
+ text = " ".join(text.split())
+ return text
+
+
+def exact_match(prediction: str, ground_truth: str) -> float:
+ """Normalized exact match.
+
+ Args:
+ prediction: predicted answer string
+ ground_truth: gold answer string
+
+ Returns:
+ 1.0 if normalized strings match, 0.0 otherwise.
+ """
+ return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))
+
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+ """Token-level F1 between prediction and ground truth.
+
+ Args:
+ prediction: predicted answer string
+ ground_truth: gold answer string
+
+ Returns:
+ F1 score between 0.0 and 1.0.
+ """
+ pred_tokens = _normalize_answer(prediction).split()
+ gold_tokens = _normalize_answer(ground_truth).split()
+
+ if not pred_tokens and not gold_tokens:
+ return 1.0
+ if not pred_tokens or not gold_tokens:
+ return 0.0
+
+ common = Counter(pred_tokens) & Counter(gold_tokens)
+ num_same = sum(common.values())
+
+ if num_same == 0:
+ return 0.0
+
+ precision = num_same / len(pred_tokens)
+ recall = num_same / len(gold_tokens)
+ f1 = 2 * precision * recall / (precision + recall)
+ return f1
+
+
+def retrieval_recall_at_k(
+ retrieved_indices: List[int], gold_indices: List[int], k: int
+) -> float:
+ """What fraction of gold passages appear in the retrieved top-k?
+
+ Args:
+ retrieved_indices: list of retrieved passage indices (top-k)
+ gold_indices: list of gold/relevant passage indices
+ k: number of retrieved passages to consider
+
+ Returns:
+ Recall score between 0.0 and 1.0.
+ """
+ if not gold_indices:
+ return 1.0
+ retrieved_set = set(retrieved_indices[:k])
+ gold_set = set(gold_indices)
+ return len(retrieved_set & gold_set) / len(gold_set)
+
+
+def evaluate_dataset(
+ results: List[PipelineResult], gold_answers: List[str]
+) -> Dict[str, float]:
+ """Compute aggregate metrics over a dataset.
+
+ Args:
+ results: list of PipelineResult from the pipeline
+ gold_answers: list of gold answer strings
+
+ Returns:
+ Dict with keys 'em', 'f1' containing averaged scores.
+ """
+ assert len(results) == len(gold_answers)
+
+ em_scores = []
+ f1_scores = []
+
+ for result, gold in zip(results, gold_answers):
+ em_scores.append(exact_match(result.answer, gold))
+ f1_scores.append(f1_score(result.answer, gold))
+
+ return {
+ "em": sum(em_scores) / len(em_scores) if em_scores else 0.0,
+ "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
+ }
diff --git a/hag/pipeline.py b/hag/pipeline.py
new file mode 100644
index 0000000..1fefb84
--- /dev/null
+++ b/hag/pipeline.py
@@ -0,0 +1,107 @@
+"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""
+
+import logging
+from typing import List, Optional, Protocol, Union
+
+import numpy as np
+import torch
+
+from hag.config import PipelineConfig
+from hag.datatypes import PipelineResult, RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class EncoderProtocol(Protocol):
+ """Protocol for encoder interface."""
+
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor: ...
+
+
+class GeneratorProtocol(Protocol):
+ """Protocol for generator interface."""
+
+ def generate(self, question: str, passages: List[str]) -> str: ...
+
+
+class RAGPipeline:
+ """End-to-end pipeline: query -> encode -> retrieve -> generate.
+
+ Supports both FAISS (baseline) and Hopfield (ours) retrieval.
+ """
+
+ def __init__(
+ self,
+ config: PipelineConfig,
+ encoder: EncoderProtocol,
+ generator: GeneratorProtocol,
+ memory_bank: Optional[MemoryBank] = None,
+ faiss_retriever: Optional[FAISSRetriever] = None,
+ ) -> None:
+ self.config = config
+ self.encoder = encoder
+ self.generator = generator
+
+ if config.retriever_type == "faiss":
+ assert faiss_retriever is not None, "FAISSRetriever required for faiss mode"
+ self.retriever_type = "faiss"
+ self.faiss_retriever = faiss_retriever
+ self.hopfield_retriever: Optional[HopfieldRetriever] = None
+ elif config.retriever_type == "hopfield":
+ assert memory_bank is not None, "MemoryBank required for hopfield mode"
+ hopfield = HopfieldRetrieval(config.hopfield)
+ self.retriever_type = "hopfield"
+ self.hopfield_retriever = HopfieldRetriever(
+ hopfield, memory_bank, top_k=config.hopfield.top_k
+ )
+ self.faiss_retriever = None
+ else:
+ raise ValueError(f"Unknown retriever_type: {config.retriever_type}")
+
+ def run(self, question: str) -> PipelineResult:
+ """Run the full pipeline on a single question.
+
+ 1. Encode question -> query embedding
+ 2. Retrieve passages (FAISS or Hopfield)
+ 3. Generate answer with LLM
+
+ Args:
+ question: input question string
+
+ Returns:
+ PipelineResult with answer and retrieval metadata.
+ """
+ # Encode
+ query_emb = self.encoder.encode(question) # (1, d)
+
+ # Retrieve
+ if self.retriever_type == "hopfield":
+ retrieval_result = self.hopfield_retriever.retrieve(query_emb)
+ else:
+            # .cpu() keeps this safe when the encoder runs on a GPU device
+            query_np = query_emb.detach().cpu().numpy().astype(np.float32)
+ retrieval_result = self.faiss_retriever.retrieve(query_np)
+
+ # Generate
+ answer = self.generator.generate(question, retrieval_result.passages)
+
+ return PipelineResult(
+ question=question,
+ answer=answer,
+ retrieved_passages=retrieval_result.passages,
+ retrieval_result=retrieval_result,
+ )
+
+ def run_batch(self, questions: List[str]) -> List[PipelineResult]:
+ """Run pipeline on a batch of questions.
+
+ Args:
+ questions: list of question strings
+
+ Returns:
+ List of PipelineResult, one per question.
+ """
+ return [self.run(q) for q in questions]
diff --git a/hag/retriever_faiss.py b/hag/retriever_faiss.py
new file mode 100644
index 0000000..cd54a85
--- /dev/null
+++ b/hag/retriever_faiss.py
@@ -0,0 +1,73 @@
+"""Baseline FAISS top-k retriever for vanilla RAG."""
+
+import logging
+from typing import List, Optional
+
+import faiss
+import numpy as np
+import torch
+
+from hag.datatypes import RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+class FAISSRetriever:
+ """Standard top-k retrieval using FAISS inner product search.
+
+ This is the baseline to compare against Hopfield retrieval.
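+
+    Example (sketch; `embeddings` is an (N, d) float32 array and `passages`
+    a list of N strings)::
+
+        retriever = FAISSRetriever(top_k=5)
+        retriever.build_index(embeddings, passages)
+        result = retriever.retrieve(query)  # query: (d,) float32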
+ """
+
+ def __init__(self, top_k: int = 5) -> None:
+ self.index: Optional[faiss.IndexFlatIP] = None
+ self.passages: List[str] = []
+ self.top_k = top_k
+
+ def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
+ """Build FAISS IndexFlatIP from embeddings.
+
+ Args:
+ embeddings: (N, d) numpy array of passage embeddings
+ passages: list of N passage strings
+ """
+ assert embeddings.shape[0] == len(passages)
+ d = embeddings.shape[1]
+ self.index = faiss.IndexFlatIP(d)
+        # FAISS expects contiguous float32. Casting onto a fresh copy also
+        # keeps the in-place L2 normalization (cosine via inner product)
+        # from mutating the caller's array.
+        embeddings = np.ascontiguousarray(embeddings.astype(np.float32, copy=True))
+        faiss.normalize_L2(embeddings)
+        self.index.add(embeddings)
+ self.passages = list(passages)
+ logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)
+
+ def retrieve(self, query: np.ndarray) -> RetrievalResult:
+ """Retrieve top-k passages for a query.
+
+ Args:
+ query: (d,) or (batch, d) numpy array
+
+ Returns:
+ RetrievalResult with passages, scores, and indices.
+ """
+ assert self.index is not None, "Index not built. Call build_index first."
+
+ if query.ndim == 1:
+ query = query.reshape(1, -1) # (1, d)
+
+        # Normalize a float32 copy of the query for cosine similarity
+        query_copy = np.ascontiguousarray(query.astype(np.float32, copy=True))
+        faiss.normalize_L2(query_copy)
+
+ scores, indices = self.index.search(query_copy, self.top_k) # (batch, k)
+
+ # Flatten for single query case
+ if scores.shape[0] == 1:
+ scores = scores[0] # (k,)
+ indices = indices[0] # (k,)
+
+ passages = [self.passages[i] for i in indices.flatten().tolist()]
+
+ return RetrievalResult(
+ passages=passages,
+ scores=torch.from_numpy(scores).float(),
+ indices=torch.from_numpy(indices).long(),
+ )
diff --git a/hag/retriever_hopfield.py b/hag/retriever_hopfield.py
new file mode 100644
index 0000000..1cb6968
--- /dev/null
+++ b/hag/retriever_hopfield.py
@@ -0,0 +1,77 @@
+"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""
+
+import logging
+
+import torch
+
+from hag.datatypes import RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetriever:
+ """Wraps HopfieldRetrieval + MemoryBank into a retriever interface.
+
+    This is the bridge between the continuous Hopfield attention distribution
+    and the discrete passage selection needed for LLM prompting.
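+
+    Example (sketch; assumes a HopfieldRetrieval instance and a loaded
+    MemoryBank)::
+
+        retriever = HopfieldRetriever(hopfield, memory_bank, top_k=5)
+        result = retriever.retrieve(query_emb, return_analysis=True)
+        print(result.passages)
+        print(result.hopfield_result.energy_curve)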
+ """
+
+ def __init__(
+ self,
+ hopfield: HopfieldRetrieval,
+ memory_bank: MemoryBank,
+ top_k: int = 5,
+ ) -> None:
+ self.hopfield = hopfield
+ self.memory_bank = memory_bank
+ self.top_k = top_k
+
+ def retrieve(
+ self,
+ query_embedding: torch.Tensor,
+ return_analysis: bool = False,
+ ) -> RetrievalResult:
+ """Retrieve top-k passages using iterative Hopfield retrieval.
+
+ 1. Run Hopfield iterative retrieval -> get attention weights alpha_T
+ 2. Take top_k indices from alpha_T
+ 3. Look up corresponding passage texts from memory bank
+ 4. Optionally return trajectory and energy for analysis
+
+ Args:
+ query_embedding: (d,) or (batch, d) — query embedding
+ return_analysis: if True, include full HopfieldResult
+
+ Returns:
+ RetrievalResult with passages, scores, indices, and optionally
+ the full hopfield_result.
+ """
+ hopfield_result = self.hopfield.retrieve(
+ query_embedding,
+ self.memory_bank.embeddings,
+ return_trajectory=return_analysis,
+ return_energy=return_analysis,
+ )
+
+ alpha = hopfield_result.attention_weights # (batch, N) or (1, N)
+
+ # Get top-k indices and scores
+ k = min(self.top_k, alpha.shape[-1])
+ scores, indices = torch.topk(alpha, k, dim=-1) # (batch, k)
+
+ # Flatten for single-query case
+ if scores.shape[0] == 1:
+ scores = scores.squeeze(0) # (k,)
+ indices = indices.squeeze(0) # (k,)
+
+ passages = self.memory_bank.get_passages_by_indices(indices)
+
+ return RetrievalResult(
+ passages=passages,
+ scores=scores,
+ indices=indices,
+ hopfield_result=hopfield_result if return_analysis else None,
+ )
diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb
new file mode 100644
index 0000000..caace5b
--- /dev/null
+++ b/notebooks/demo.ipynb
@@ -0,0 +1,62 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "# HAG: Hopfield-Augmented Generation Demo\n",
+ "\n",
+ "This notebook demonstrates the core Hopfield retrieval mechanism with synthetic data."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import torch\n",
+ "import torch.nn.functional as F\n",
+ "\n",
+ "from hag.config import HopfieldConfig\n",
+ "from hag.hopfield import HopfieldRetrieval\n",
+ "from hag.energy import compute_energy_curve, verify_monotonic_decrease, compute_attention_entropy"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create synthetic memory bank and query\n",
+ "torch.manual_seed(42)\n",
+ "d, N = 64, 200\n",
+ "memory = F.normalize(torch.randn(d, N), dim=0)\n",
+ "query = F.normalize(torch.randn(1, d), dim=-1)\n",
+ "\n",
+ "# Run Hopfield retrieval with different beta values\n",
+ "for beta in [0.5, 1.0, 2.0, 5.0]:\n",
+ " config = HopfieldConfig(beta=beta, max_iter=20, conv_threshold=1e-6)\n",
+ " hopfield = HopfieldRetrieval(config)\n",
+ " result = hopfield.retrieve(query, memory, return_energy=True)\n",
+ " curve = compute_energy_curve(result)\n",
+ " entropy = compute_attention_entropy(result.attention_weights)\n",
+ " print(f'beta={beta}: steps={result.num_steps}, entropy={entropy:.4f}, monotonic={verify_monotonic_decrease(curve)}')"
+ ]
+ }
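+  ,
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Sketch: plot one energy curve to see the monotone decrease directly.\n",
+    "# Assumes matplotlib is available (it ships with the optional 'eval' extra).\n",
+    "import matplotlib.pyplot as plt\n",
+    "\n",
+    "config = HopfieldConfig(beta=2.0, max_iter=20, conv_threshold=1e-6)\n",
+    "result = HopfieldRetrieval(config).retrieve(query, memory, return_energy=True)\n",
+    "plt.plot(compute_energy_curve(result), marker='o')\n",
+    "plt.xlabel('iteration t')\n",
+    "plt.ylabel('energy E(q_t)')\n",
+    "plt.title('Hopfield retrieval: energy vs. iteration')\n",
+    "plt.show()"
+   ]
+  }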
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.10.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+}
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000..24f7806
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,36 @@
+[build-system]
+requires = ["setuptools>=68.0", "wheel"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "hag"
+version = "0.1.0"
+requires-python = ">=3.10"
+description = "Hopfield-Augmented Generation: RAG with iterative Modern Hopfield Network retrieval"
+
+dependencies = [
+ "torch>=2.0",
+ "numpy>=1.24",
+ "faiss-cpu>=1.7",
+ "transformers>=4.36",
+ "datasets>=2.14",
+ "pyyaml>=6.0",
+ "tqdm>=4.65",
+]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=7.0",
+ "pytest-cov>=4.0",
+]
+eval = [
+ "scikit-learn>=1.3",
+ "umap-learn>=0.5",
+ "matplotlib>=3.7",
+]
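+
+# Typical editable install for development + evaluation (suggested, not
+# enforced by this file):
+#   pip install -e ".[dev,eval]"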
+
+[tool.setuptools.packages.find]
+include = ["hag*"]
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py
new file mode 100644
index 0000000..fd044a4
--- /dev/null
+++ b/scripts/analyze_energy.py
@@ -0,0 +1,78 @@
+"""Analyze energy curves and convergence properties of Hopfield retrieval.
+
+Usage:
+ python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import (
+ compute_attention_entropy,
+ compute_energy_curve,
+ compute_energy_gap,
+ verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--questions", type=str, required=True)
+ parser.add_argument("--output", type=str, default="energy_analysis.json")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+ # Load memory bank
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank)
+
+ # Load questions
+ with open(args.questions) as f:
+ questions = [json.loads(line)["question"] for line in f]
+
+ encoder = Encoder(encoder_config)
+ hopfield = HopfieldRetrieval(hopfield_config)
+
+ analyses = []
+ for q in questions:
+ query_emb = encoder.encode(q) # (1, d)
+ result = hopfield.retrieve(
+ query_emb, mb.embeddings, return_energy=True, return_trajectory=True
+ )
+
+ curve = compute_energy_curve(result)
+ analyses.append({
+ "question": q,
+ "energy_curve": curve,
+ "energy_gap": compute_energy_gap(curve),
+ "monotonic": verify_monotonic_decrease(curve),
+ "num_steps": result.num_steps,
+ "attention_entropy": compute_attention_entropy(result.attention_weights),
+ })
+
+ with open(args.output, "w") as f:
+ json.dump(analyses, f, indent=2)
+ logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses))
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
new file mode 100644
index 0000000..2aff828
--- /dev/null
+++ b/scripts/build_memory_bank.py
@@ -0,0 +1,72 @@
+"""Offline script: encode corpus passages into a memory bank.
+
+Usage:
+ python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+from tqdm import tqdm
+
+from hag.config import EncoderConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def load_corpus(path: str) -> list[str]:
+ """Load passages from a JSONL file (one JSON object per line with 'text' field)."""
+ passages = []
+ with open(path) as f:
+ for line in f:
+ obj = json.loads(line)
+ passages.append(obj["text"])
+ return passages
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Build memory bank from corpus")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--corpus", type=str, required=True)
+ parser.add_argument("--output", type=str, required=True)
+ parser.add_argument("--device", type=str, default="cpu")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+
+ # Load corpus
+ logger.info("Loading corpus from %s", args.corpus)
+ passages = load_corpus(args.corpus)
+ logger.info("Loaded %d passages", len(passages))
+
+ # Encode passages in batches
+ encoder = Encoder(encoder_config)
+ all_embeddings = []
+
+ for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"):
+ batch = passages[i : i + encoder_config.batch_size]
+ emb = encoder.encode(batch) # (batch_size, d)
+ all_embeddings.append(emb.cpu())
+
+ embeddings = torch.cat(all_embeddings, dim=0) # (N, d)
+ logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)
+
+ # Build and save memory bank
+ mb = MemoryBank(memory_config)
+ mb.build_from_embeddings(embeddings, passages)
+ mb.save(args.output)
+ logger.info("Memory bank saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py
new file mode 100644
index 0000000..74c4710
--- /dev/null
+++ b/scripts/run_baseline.py
@@ -0,0 +1,75 @@
+"""Run vanilla RAG baseline with FAISS retrieval.
+
+Usage:
+ python scripts/run_baseline.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import numpy as np
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.pipeline import RAGPipeline
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run vanilla RAG baseline")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--top-k", type=int, default=5)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ # Override retriever type to faiss
+ pipeline_config = PipelineConfig(
+ hopfield=HopfieldConfig(**{**cfg.get("hopfield", {}), "top_k": args.top_k}),
+ encoder=EncoderConfig(**cfg.get("encoder", {})),
+ generator=GeneratorConfig(**cfg.get("generator", {})),
+ retriever_type="faiss",
+ )
+
+    # Load memory bank to get embeddings for FAISS
+    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
+ mb.load(args.memory_bank)
+
+ # Build FAISS index from memory bank embeddings
+ embeddings_np = mb.embeddings.T.numpy().astype(np.float32) # (N, d)
+ faiss_ret = FAISSRetriever(top_k=args.top_k)
+ faiss_ret.build_index(embeddings_np, mb.passages)
+
+ encoder = Encoder(pipeline_config.encoder)
+ generator = Generator(pipeline_config.generator)
+
+ pipeline = RAGPipeline(
+ config=pipeline_config,
+ encoder=encoder,
+ generator=generator,
+ faiss_retriever=faiss_ret,
+ )
+
+ result = pipeline.run(args.question)
+ print(f"\nQuestion: {result.question}")
+ print(f"Answer: {result.answer}")
+ print(f"\nRetrieved passages:")
+ for i, p in enumerate(result.retrieved_passages):
+ print(f" [{i+1}] {p[:200]}...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
new file mode 100644
index 0000000..713b3c2
--- /dev/null
+++ b/scripts/run_eval.py
@@ -0,0 +1,90 @@
+"""Run evaluation on a dataset with either FAISS or Hopfield retrieval.
+
+Usage:
+ python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+
+import numpy as np
+import yaml
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--dataset", type=str, default="hotpotqa")
+ parser.add_argument("--split", type=str, default="validation")
+ parser.add_argument("--max-samples", type=int, default=500)
+ parser.add_argument("--output", type=str, default="results.json")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ pipeline_config = PipelineConfig(
+ hopfield=HopfieldConfig(**cfg.get("hopfield", {})),
+ memory=MemoryBankConfig(**cfg.get("memory", {})),
+ encoder=EncoderConfig(**cfg.get("encoder", {})),
+ generator=GeneratorConfig(**cfg.get("generator", {})),
+ retriever_type=cfg.get("retriever_type", "hopfield"),
+ )
+
+ # Load memory bank
+ mb = MemoryBank(pipeline_config.memory)
+ mb.load(args.memory_bank)
+
+    # Build pipeline. The FAISS baseline needs an index built from the
+    # memory bank's embeddings; the Hopfield path uses the bank directly.
+    encoder = Encoder(pipeline_config.encoder)
+    generator = Generator(pipeline_config.generator)
+    faiss_ret = None
+    if pipeline_config.retriever_type == "faiss":
+        faiss_ret = FAISSRetriever(top_k=pipeline_config.hopfield.top_k)
+        faiss_ret.build_index(
+            mb.embeddings.T.numpy().astype(np.float32), mb.passages
+        )
+    pipeline = RAGPipeline(
+        config=pipeline_config,
+        encoder=encoder,
+        generator=generator,
+        memory_bank=mb,
+        faiss_retriever=faiss_ret,
+    )
+
+ # Load dataset
+ from datasets import load_dataset
+
+ logger.info("Loading dataset: %s / %s", args.dataset, args.split)
+ ds = load_dataset(args.dataset, split=args.split)
+ if args.max_samples and len(ds) > args.max_samples:
+ ds = ds.select(range(args.max_samples))
+
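+    # Field names are dataset-dependent; 'question'/'answer' is an assumption
+    # that matches the HotpotQA-style schemas this script targets.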
+ questions = [ex["question"] for ex in ds]
+ gold_answers = [ex["answer"] for ex in ds]
+
+ # Run evaluation
+ logger.info("Running evaluation on %d questions", len(questions))
+ results = pipeline.run_batch(questions)
+ metrics = evaluate_dataset(results, gold_answers)
+
+ logger.info("Results: %s", metrics)
+
+ with open(args.output, "w") as f:
+ json.dump(metrics, f, indent=2)
+ logger.info("Results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/run_hag.py b/scripts/run_hag.py
new file mode 100644
index 0000000..4cacd1a
--- /dev/null
+++ b/scripts/run_hag.py
@@ -0,0 +1,79 @@
+"""Run HAG (Hopfield-Augmented Generation) on a question.
+
+Usage:
+ python scripts/run_hag.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import yaml
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Run HAG")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--beta", type=float, default=None)
+ parser.add_argument("--max-iter", type=int, default=None)
+ parser.add_argument("--top-k", type=int, default=None)
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_cfg = cfg.get("hopfield", {})
+ if args.beta is not None:
+ hopfield_cfg["beta"] = args.beta
+ if args.max_iter is not None:
+ hopfield_cfg["max_iter"] = args.max_iter
+ if args.top_k is not None:
+ hopfield_cfg["top_k"] = args.top_k
+
+ pipeline_config = PipelineConfig(
+ hopfield=HopfieldConfig(**hopfield_cfg),
+ memory=MemoryBankConfig(**cfg.get("memory", {})),
+ encoder=EncoderConfig(**cfg.get("encoder", {})),
+ generator=GeneratorConfig(**cfg.get("generator", {})),
+ retriever_type="hopfield",
+ )
+
+ mb = MemoryBank(pipeline_config.memory)
+ mb.load(args.memory_bank)
+
+ encoder = Encoder(pipeline_config.encoder)
+ generator = Generator(pipeline_config.generator)
+
+ pipeline = RAGPipeline(
+ config=pipeline_config,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+
+ result = pipeline.run(args.question)
+ print(f"\nQuestion: {result.question}")
+ print(f"Answer: {result.answer}")
+ print(f"\nRetrieved passages:")
+ for i, p in enumerate(result.retrieved_passages):
+ print(f" [{i+1}] {p[:200]}...")
+
+
+if __name__ == "__main__":
+ main()
diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py
new file mode 100644
index 0000000..e4ba902
--- /dev/null
+++ b/scripts/visualize_trajectory.py
@@ -0,0 +1,80 @@
+"""UMAP visualization of query trajectory in Hopfield retrieval.
+
+Usage:
+ python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
+"""
+
+import argparse
+import logging
+
+import numpy as np
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+ parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory")
+ parser.add_argument("--config", type=str, default="configs/default.yaml")
+ parser.add_argument("--memory-bank", type=str, required=True)
+ parser.add_argument("--question", type=str, required=True)
+ parser.add_argument("--output", type=str, default="trajectory.png")
+ args = parser.parse_args()
+
+ with open(args.config) as f:
+ cfg = yaml.safe_load(f)
+
+ hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+ memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+ encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+ mb = MemoryBank(memory_config)
+ mb.load(args.memory_bank)
+
+ encoder = Encoder(encoder_config)
+ hopfield = HopfieldRetrieval(hopfield_config)
+
+ query_emb = encoder.encode(args.question) # (1, d)
+ result = hopfield.retrieve(
+ query_emb, mb.embeddings, return_trajectory=True
+ )
+
+ # Gather all points for UMAP: memories + trajectory
+ memories_np = mb.embeddings.T.numpy() # (N, d)
+ trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d)
+ all_points = np.concatenate([memories_np, trajectory_np], axis=0)
+
+    # UMAP projection (deferred import: umap-learn ships in the optional 'eval' extra)
+ import umap
+
+ reducer = umap.UMAP(n_components=2, random_state=42)
+ projected = reducer.fit_transform(all_points)
+
+ mem_proj = projected[: len(memories_np)]
+ traj_proj = projected[len(memories_np) :]
+
+    # Plot (matplotlib is also part of the optional 'eval' extra)
+ import matplotlib.pyplot as plt
+
+ fig, ax = plt.subplots(figsize=(10, 8))
+ ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories")
+ ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory")
+ ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)")
+ ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)")
+
+ ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)")
+ ax.legend()
+ plt.tight_layout()
+ plt.savefig(args.output, dpi=150)
+ logger.info("Trajectory visualization saved to %s", args.output)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/tests/__init__.py b/tests/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/tests/__init__.py
diff --git a/tests/test_energy.py b/tests/test_energy.py
new file mode 100644
index 0000000..256f2d8
--- /dev/null
+++ b/tests/test_energy.py
@@ -0,0 +1,80 @@
+"""Unit tests for the energy computation module."""
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.energy import (
+ compute_attention_entropy,
+ compute_energy_curve,
+ compute_energy_gap,
+ verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+
+
+class TestEnergy:
+ """Tests for energy computation and analysis utilities."""
+
+ def test_energy_computation_matches_formula(self) -> None:
+ """Manually compute E(q) and verify it matches the module output."""
+ torch.manual_seed(42)
+ d, N = 16, 10
+ memory = torch.randn(d, N)
+ query = torch.randn(1, d)
+ beta = 2.0
+
+ config = HopfieldConfig(beta=beta)
+ hopfield = HopfieldRetrieval(config)
+ computed = hopfield.compute_energy(query, memory)
+
+ # Manual computation:
+ # E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+ logits = beta * (query @ memory) # (1, N)
+ expected = (
+ -1.0 / beta * torch.logsumexp(logits, dim=-1)
+ + 0.5 * (query**2).sum(dim=-1)
+ )
+ assert torch.allclose(computed, expected, atol=1e-5)
+
+ def test_verify_monotonic(self) -> None:
+ """Monotonically decreasing energy curve should pass."""
+ energies = [10.0, 8.0, 6.5, 6.4, 6.39]
+ assert verify_monotonic_decrease(energies) is True
+
+ def test_verify_monotonic_fails(self) -> None:
+ """Non-monotonic energy curve should fail."""
+ energies = [10.0, 8.0, 9.0, 6.0]
+ assert verify_monotonic_decrease(energies) is False
+
+ def test_attention_entropy(self) -> None:
+ """Uniform attention should have max entropy, one-hot should have zero."""
+ uniform = torch.ones(100) / 100
+ one_hot = torch.zeros(100)
+ one_hot[0] = 1.0
+ assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot)
+ assert compute_attention_entropy(one_hot) < 0.01
+
+ def test_energy_curve_extraction(self) -> None:
+ """Energy curve should be extractable from a HopfieldResult."""
+ torch.manual_seed(42)
+ d, N = 32, 50
+ memory = torch.randn(d, N)
+ query = torch.randn(1, d)
+ config = HopfieldConfig(beta=1.0, max_iter=5)
+ hopfield = HopfieldRetrieval(config)
+
+ result = hopfield.retrieve(query, memory, return_energy=True)
+ curve = compute_energy_curve(result)
+ assert len(curve) > 1
+ assert all(isinstance(v, float) for v in curve)
+
+ def test_energy_gap(self) -> None:
+ """Energy gap should be positive when energy decreases."""
+ energies = [10.0, 8.0, 6.0]
+ gap = compute_energy_gap(energies)
+ assert gap == 4.0
+
+ def test_energy_gap_single_point(self) -> None:
+ """Energy gap with fewer than 2 points should be 0."""
+ assert compute_energy_gap([5.0]) == 0.0
+ assert compute_energy_gap([]) == 0.0
diff --git a/tests/test_hopfield.py b/tests/test_hopfield.py
new file mode 100644
index 0000000..246b120
--- /dev/null
+++ b/tests/test_hopfield.py
@@ -0,0 +1,147 @@
+"""Unit tests for the core Hopfield retrieval module."""
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig
+from hag.hopfield import HopfieldRetrieval
+
+
+class TestHopfieldRetrieval:
+ """Test the core Hopfield module with synthetic data on CPU."""
+
+ def setup_method(self) -> None:
+ """Create small synthetic memory bank and queries."""
+ torch.manual_seed(42)
+ self.d = 64 # embedding dim
+ self.N = 100 # number of memories
+ self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N)
+ self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d)
+ self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
+ self.hopfield = HopfieldRetrieval(self.config)
+
+ def test_output_shapes(self) -> None:
+ """attention_weights should be (1, N), converged_query should be (1, d)."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert result.attention_weights.shape == (1, self.N)
+ assert result.converged_query.shape == (1, self.d)
+
+ def test_attention_weights_sum_to_one(self) -> None:
+ """softmax output must sum to 1."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert torch.allclose(
+ result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5
+ )
+
+ def test_attention_weights_non_negative(self) -> None:
+ """All attention weights must be >= 0."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert (result.attention_weights >= 0).all()
+
+ def test_energy_monotonic_decrease(self) -> None:
+ """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
+ result = self.hopfield.retrieve(
+ self.query, self.memory, return_energy=True
+ )
+ energies = [e.item() for e in result.energy_curve]
+ for i in range(len(energies) - 1):
+ assert energies[i + 1] <= energies[i] + 1e-6, (
+ f"Energy increased at step {i}: {energies[i]} -> {energies[i + 1]}"
+ )
+
+ def test_convergence(self) -> None:
+ """With enough iterations, query should converge (delta < threshold)."""
+ config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(
+ self.query, self.memory, return_trajectory=True
+ )
+ # Final two queries should be very close
+ q_last = result.trajectory[-1]
+ q_prev = result.trajectory[-2]
+ delta = torch.norm(q_last - q_prev)
+ assert delta < 1e-4, f"Did not converge: delta={delta}"
+
+ def test_high_beta_sharp_retrieval(self) -> None:
+ """Higher beta should produce sharper (lower entropy) attention."""
+ low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
+ high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))
+
+ result_low = low_beta.retrieve(self.query, self.memory)
+ result_high = high_beta.retrieve(self.query, self.memory)
+
+ entropy_low = -(
+ result_low.attention_weights * result_low.attention_weights.log()
+ ).sum()
+ entropy_high = -(
+ result_high.attention_weights * result_high.attention_weights.log()
+ ).sum()
+
+ assert entropy_high < entropy_low, "Higher beta should give lower entropy"
+
+ def test_single_memory_converges_to_it(self) -> None:
+ """With N=1, retrieval should converge to the single memory."""
+ single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
+ result = self.hopfield.retrieve(self.query, single_memory)
+ assert torch.allclose(
+ result.attention_weights, torch.ones(1, 1), atol=1e-5
+ )
+
+ def test_query_near_memory_retrieves_it(self) -> None:
+ """If query ~= memory_i, attention should peak at index i."""
+ target_idx = 42
+ query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match
+ config = HopfieldConfig(beta=10.0, max_iter=5)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(query, self.memory)
+ top_idx = result.attention_weights.argmax(dim=-1).item()
+ assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"
+
+ def test_batch_retrieval(self) -> None:
+ """Should handle batch of queries."""
+ batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
+ result = self.hopfield.retrieve(batch_query, self.memory)
+ assert result.attention_weights.shape == (8, self.N)
+ assert result.converged_query.shape == (8, self.d)
+
+ def test_iteration_refines_query(self) -> None:
+ """Multi-hop test: query starts far from target, iteration should bring it closer.
+
+ Setup: memory has two clusters. Query is near cluster A but the "answer"
+ is in cluster B, reachable through an intermediate memory that bridges both.
+ After iteration, the query should drift toward cluster B.
+ """
+ torch.manual_seed(0)
+ d = 32
+
+ # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
+ # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
+ # Bridge memory 10: between A and B
+ center_a = torch.zeros(d)
+ center_a[0] = 1.0
+ center_b = torch.zeros(d)
+ center_b[1] = 1.0
+ bridge = F.normalize(center_a + center_b, dim=0)
+
+ memories = []
+ for _ in range(5):
+ memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
+ for _ in range(5):
+ memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
+ memories.append(bridge)
+
+ M = torch.stack(memories, dim=1) # (d, 11)
+ q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)
+
+ config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(q0, M, return_trajectory=True)
+
+ # After iteration, query should have drifted: its dot product with center_b
+ # should be higher than the initial query's dot product with center_b
+ initial_sim_b = (q0.squeeze() @ center_b).item()
+ final_sim_b = (result.converged_query.squeeze() @ center_b).item()
+ assert final_sim_b > initial_sim_b, (
+ f"Iteration should pull query toward cluster B: "
+ f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
+ )
diff --git a/tests/test_memory_bank.py b/tests/test_memory_bank.py
new file mode 100644
index 0000000..0087bbd
--- /dev/null
+++ b/tests/test_memory_bank.py
@@ -0,0 +1,65 @@
+"""Unit tests for the memory bank module."""
+
+import torch
+
+from hag.config import MemoryBankConfig
+from hag.memory_bank import MemoryBank
+
+
+class TestMemoryBank:
+ """Tests for MemoryBank construction, lookup, and persistence."""
+
+ def test_build_and_size(self) -> None:
+ """Build memory bank and verify size."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(100, 64)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(embeddings, passages)
+ assert mb.size == 100
+ assert mb.dim == 64
+
+ def test_normalization(self) -> None:
+ """If normalize=True, stored embeddings should have unit norm."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(50, 64) * 5 # non-unit norm
+ mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
+ norms = torch.norm(mb.embeddings, dim=0)
+ assert torch.allclose(norms, torch.ones(50), atol=1e-5)
+
+ def test_no_normalization(self) -> None:
+ """If normalize=False, stored embeddings keep original norms."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=False)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(50, 64) * 5
+ original_norms = torch.norm(embeddings, dim=-1)
+ mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
+ stored_norms = torch.norm(mb.embeddings, dim=0)
+ assert torch.allclose(stored_norms, original_norms, atol=1e-5)
+
+ def test_get_passages_by_indices(self) -> None:
+ """Index -> passage text lookup."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=False)
+ mb = MemoryBank(config)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(torch.randn(100, 64), passages)
+ result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
+ assert result == ["passage 0", "passage 50", "passage 99"]
+
+ def test_save_and_load(self, tmp_path) -> None: # type: ignore[no-untyped-def]
+ """Save and reload memory bank, verify contents match."""
+ config = MemoryBankConfig(embedding_dim=32, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(20, 32)
+ passages = [f"text {i}" for i in range(20)]
+ mb.build_from_embeddings(embeddings, passages)
+
+ save_path = str(tmp_path / "mb.pt")
+ mb.save(save_path)
+
+ mb2 = MemoryBank(config)
+ mb2.load(save_path)
+ assert mb2.size == 20
+ assert mb2.passages == passages
+ assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)
diff --git a/tests/test_metrics.py b/tests/test_metrics.py
new file mode 100644
index 0000000..4a18bec
--- /dev/null
+++ b/tests/test_metrics.py
@@ -0,0 +1,54 @@
+"""Unit tests for evaluation metrics."""
+
+from hag.metrics import exact_match, f1_score, retrieval_recall_at_k
+
+
+class TestMetrics:
+ """Tests for EM, F1, and retrieval recall metrics."""
+
+ def test_exact_match_basic(self) -> None:
+ """Basic exact match: case insensitive."""
+ assert exact_match("Paris", "paris") == 1.0
+ assert exact_match("Paris", "London") == 0.0
+
+ def test_exact_match_normalization(self) -> None:
+ """Should strip whitespace, lowercase, remove articles."""
+ assert exact_match(" The Paris ", "paris") == 1.0
+ assert exact_match("A dog", "dog") == 1.0
+
+ def test_exact_match_empty(self) -> None:
+ """Empty strings should match."""
+ assert exact_match("", "") == 1.0
+
+ def test_f1_score_perfect(self) -> None:
+ """Identical strings should have F1 = 1.0."""
+ assert f1_score("the cat sat", "the cat sat") == 1.0
+
+ def test_f1_score_partial(self) -> None:
+ """Partial overlap should give 0 < F1 < 1."""
+ score = f1_score("the cat sat on the mat", "the cat sat")
+ assert 0.5 < score < 1.0
+
+ def test_f1_score_no_overlap(self) -> None:
+ """No common tokens should give F1 = 0."""
+ assert f1_score("hello world", "foo bar") == 0.0
+
+ def test_f1_score_empty(self) -> None:
+ """Two empty strings should have F1 = 1.0."""
+ assert f1_score("", "") == 1.0
+
+ def test_retrieval_recall(self) -> None:
+ """Standard retrieval recall computation."""
+ assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2 / 3
+
+ def test_retrieval_recall_perfect(self) -> None:
+ """All gold passages retrieved."""
+ assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0
+
+ def test_retrieval_recall_none(self) -> None:
+ """No gold passages retrieved."""
+ assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0
+
+ def test_retrieval_recall_empty_gold(self) -> None:
+ """No gold passages means perfect recall by convention."""
+ assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
new file mode 100644
index 0000000..3ac81ce
--- /dev/null
+++ b/tests/test_pipeline.py
@@ -0,0 +1,132 @@
+"""Integration tests for the RAG/HAG pipeline using mock encoder and generator."""
+
+import numpy as np
+import torch
+
+from hag.config import (
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.encoder import FakeEncoder
+from hag.generator import FakeGenerator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+
+def _make_memory_bank(n: int = 20, d: int = 64) -> MemoryBank:
+ """Create a small memory bank with random embeddings."""
+ torch.manual_seed(123)
+ config = MemoryBankConfig(embedding_dim=d, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(n, d)
+ passages = [f"passage {i} content" for i in range(n)]
+ mb.build_from_embeddings(embeddings, passages)
+ return mb
+
+
+def _make_faiss_retriever(n: int = 20, d: int = 64) -> FAISSRetriever:
+ """Create a small FAISS retriever with random embeddings."""
+ np.random.seed(123)
+ embeddings = np.random.randn(n, d).astype(np.float32)
+ passages = [f"passage {i} content" for i in range(n)]
+ retriever = FAISSRetriever(top_k=3)
+ retriever.build_index(embeddings, passages)
+ return retriever
+
+
+class TestPipeline:
+ """Integration tests using mock encoder and generator."""
+
+ def test_pipeline_runs_end_to_end_hopfield(self) -> None:
+ """Mock encoder + mock LLM + Hopfield retriever -> produces an answer."""
+ config = PipelineConfig(
+ hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3),
+ memory=MemoryBankConfig(embedding_dim=64, normalize=True),
+ retriever_type="hopfield",
+ )
+ mb = _make_memory_bank(n=20, d=64)
+ encoder = FakeEncoder(dim=64)
+ generator = FakeGenerator()
+
+ pipeline = RAGPipeline(
+ config=config,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+ result = pipeline.run("What is the capital of France?")
+
+ assert isinstance(result.answer, str)
+ assert len(result.answer) > 0
+ assert len(result.retrieved_passages) == 3
+
+ def test_pipeline_runs_end_to_end_faiss(self) -> None:
+ """Mock encoder + mock LLM + FAISS retriever -> produces an answer."""
+ config = PipelineConfig(
+ hopfield=HopfieldConfig(top_k=3),
+ retriever_type="faiss",
+ )
+ faiss_ret = _make_faiss_retriever(n=20, d=64)
+ encoder = FakeEncoder(dim=64)
+ generator = FakeGenerator()
+
+ pipeline = RAGPipeline(
+ config=config,
+ encoder=encoder,
+ generator=generator,
+ faiss_retriever=faiss_ret,
+ )
+ result = pipeline.run("What is the capital of France?")
+
+ assert isinstance(result.answer, str)
+ assert len(result.answer) > 0
+ assert len(result.retrieved_passages) == 3
+
+ def test_pipeline_retrieval_result_included(self) -> None:
+ """PipelineResult should contain retrieval metadata."""
+ config = PipelineConfig(
+ hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3),
+ memory=MemoryBankConfig(embedding_dim=64, normalize=True),
+ retriever_type="hopfield",
+ )
+ mb = _make_memory_bank(n=20, d=64)
+ encoder = FakeEncoder(dim=64)
+ generator = FakeGenerator()
+
+ pipeline = RAGPipeline(
+ config=config,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+ result = pipeline.run("Test question?")
+
+ assert result.retrieval_result is not None
+ assert result.retrieval_result.scores is not None
+ assert result.retrieval_result.indices is not None
+ assert result.question == "Test question?"
+
+ def test_pipeline_batch(self) -> None:
+ """run_batch should return results for each question."""
+ config = PipelineConfig(
+ hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3),
+ memory=MemoryBankConfig(embedding_dim=64, normalize=True),
+ retriever_type="hopfield",
+ )
+ mb = _make_memory_bank(n=20, d=64)
+ encoder = FakeEncoder(dim=64)
+ generator = FakeGenerator()
+
+ pipeline = RAGPipeline(
+ config=config,
+ encoder=encoder,
+ generator=generator,
+ memory_bank=mb,
+ )
+ questions = ["Q1?", "Q2?", "Q3?"]
+ results = pipeline.run_batch(questions)
+ assert len(results) == 3
+ assert all(r.answer == "mock answer" for r in results)
diff --git a/tests/test_retriever.py b/tests/test_retriever.py
new file mode 100644
index 0000000..fa96dca
--- /dev/null
+++ b/tests/test_retriever.py
@@ -0,0 +1,149 @@
+"""Unit tests for both FAISS and Hopfield retrievers."""
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from hag.config import HopfieldConfig, MemoryBankConfig
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+
+class TestHopfieldRetriever:
+ """Tests for the Hopfield-based retriever."""
+
+ def setup_method(self) -> None:
+ """Set up a small memory bank and Hopfield retriever."""
+ torch.manual_seed(42)
+ self.d = 64
+ self.N = 50
+ self.top_k = 3
+
+ config = MemoryBankConfig(embedding_dim=self.d, normalize=True)
+ self.mb = MemoryBank(config)
+ embeddings = torch.randn(self.N, self.d)
+ self.passages = [f"passage {i}" for i in range(self.N)]
+ self.mb.build_from_embeddings(embeddings, self.passages)
+
+ hopfield_config = HopfieldConfig(beta=2.0, max_iter=5)
+ hopfield = HopfieldRetrieval(hopfield_config)
+ self.retriever = HopfieldRetriever(hopfield, self.mb, top_k=self.top_k)
+
+ def test_retrieves_correct_number_of_passages(self) -> None:
+ """top_k=3 should return exactly 3 passages."""
+ query = F.normalize(torch.randn(1, self.d), dim=-1)
+ result = self.retriever.retrieve(query)
+ assert len(result.passages) == self.top_k
+ assert result.scores.shape == (self.top_k,)
+ assert result.indices.shape == (self.top_k,)
+
+ def test_scores_are_sorted_descending(self) -> None:
+ """Returned scores should be in descending order."""
+ query = F.normalize(torch.randn(1, self.d), dim=-1)
+ result = self.retriever.retrieve(query)
+ scores = result.scores.tolist()
+ assert scores == sorted(scores, reverse=True)
+
+ def test_passages_match_indices(self) -> None:
+ """Returned passage texts should correspond to returned indices."""
+ query = F.normalize(torch.randn(1, self.d), dim=-1)
+ result = self.retriever.retrieve(query)
+ expected_passages = [self.passages[i] for i in result.indices.tolist()]
+ assert result.passages == expected_passages
+
+ def test_analysis_mode_returns_hopfield_result(self) -> None:
+ """When return_analysis=True, hopfield_result should be populated."""
+ query = F.normalize(torch.randn(1, self.d), dim=-1)
+ result = self.retriever.retrieve(query, return_analysis=True)
+ assert result.hopfield_result is not None
+ assert result.hopfield_result.energy_curve is not None
+ assert result.hopfield_result.trajectory is not None
+
+
+class TestFAISSRetriever:
+ """Tests for the FAISS baseline retriever."""
+
+ def setup_method(self) -> None:
+ """Set up a small FAISS index."""
+ np.random.seed(42)
+ self.d = 64
+ self.N = 50
+ self.top_k = 3
+
+ self.embeddings = np.random.randn(self.N, self.d).astype(np.float32)
+ self.passages = [f"passage {i}" for i in range(self.N)]
+
+ self.retriever = FAISSRetriever(top_k=self.top_k)
+ self.retriever.build_index(self.embeddings.copy(), self.passages)
+
+ def test_retrieves_correct_number(self) -> None:
+ """top_k=3 should return exactly 3 passages."""
+ query = np.random.randn(self.d).astype(np.float32)
+ result = self.retriever.retrieve(query)
+ assert len(result.passages) == self.top_k
+
+ def test_nearest_neighbor_is_self(self) -> None:
+ """If query = passage_i's embedding, top-1 should be passage_i."""
+        # Rebuild the index from a defensive copy of the raw embeddings
+ embeddings = self.embeddings.copy()
+ retriever = FAISSRetriever(top_k=1)
+ retriever.build_index(embeddings.copy(), self.passages)
+
+        # Query with the raw (unnormalized) embedding: retrieve() normalizes it
+        # internally, so the cosine top-1 should be the passage itself
+ target_idx = 10
+ query = self.embeddings[target_idx].copy()
+ result = retriever.retrieve(query)
+ assert result.indices[0].item() == target_idx
+
+ def test_scores_sorted_descending(self) -> None:
+ """Returned scores should be in descending order."""
+ query = np.random.randn(self.d).astype(np.float32)
+ result = self.retriever.retrieve(query)
+ scores = result.scores.tolist()
+ assert scores == sorted(scores, reverse=True)
+
+
+class TestRetrieverComparison:
+ """Compare FAISS and Hopfield retrievers."""
+
+ def test_same_top1_for_obvious_query(self) -> None:
+ """When query is very close to one memory, both should agree on top-1."""
+ torch.manual_seed(42)
+ np.random.seed(42)
+ d = 64
+ N = 50
+ target_idx = 25
+
+ # Create embeddings
+ embeddings_np = np.random.randn(N, d).astype(np.float32)
+ # Normalize
+ norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
+ embeddings_np = embeddings_np / norms
+
+ # Query is exactly the target embedding
+ query_np = embeddings_np[target_idx].copy()
+ query_torch = torch.from_numpy(query_np).unsqueeze(0) # (1, d)
+
+ passages = [f"passage {i}" for i in range(N)]
+
+ # FAISS retriever
+ faiss_ret = FAISSRetriever(top_k=1)
+ faiss_ret.build_index(embeddings_np.copy(), passages)
+ faiss_result = faiss_ret.retrieve(query_np.copy())
+
+ # Hopfield retriever
+ mb_config = MemoryBankConfig(embedding_dim=d, normalize=False)
+ mb = MemoryBank(mb_config)
+ mb.build_from_embeddings(
+ torch.from_numpy(embeddings_np), passages
+ )
+ hop_config = HopfieldConfig(beta=10.0, max_iter=5)
+ hopfield = HopfieldRetrieval(hop_config)
+ hop_ret = HopfieldRetriever(hopfield, mb, top_k=1)
+ hop_result = hop_ret.retrieve(query_torch)
+
+ assert faiss_result.indices[0].item() == target_idx
+ assert hop_result.indices[0].item() == target_idx