path: root/CLAUDE.md
author    YurenHao0426 <blackhao0426@gmail.com>  2026-02-15 18:19:50 +0000
committer YurenHao0426 <blackhao0426@gmail.com>  2026-02-15 18:19:50 +0000
commit    c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch)
tree      43edac8013fec4e65a0b9cddec5314489b4aafc2 /CLAUDE.md
Initial implementation of HAG (Hopfield-Augmented Generation) (HEAD, master)
Core Hopfield retrieval module with energy-based convergence guarantees, memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end pipeline. All 45 tests passing on CPU with synthetic data. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'CLAUDE.md')
-rw-r--r--  CLAUDE.md  870
1 file changed, 870 insertions, 0 deletions
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..c99b846
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,870 @@
+# CLAUDE.md — Hopfield-Augmented Generation (HAG)
+
+## Project Overview
+
+We are building **HAG (Hopfield-Augmented Generation)**, a retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**. This is a research project targeting a top ML venue (ICML/NeurIPS).
+
+### Core Idea
+
+Standard RAG does:
+```
+y = LLM(x, top_k_passages(FAISS(Enc(x))))
+```
+
+HAG does:
+```
+q₀ = Enc(x)
+qₜ₊₁ = M · softmax(β · Mᵀ · qₜ) # iterate T steps
+α_T = softmax(β · Mᵀ · q_T) # final attention weights
+retrieved_passages = top_k(α_T) # use weights to select passages
+y = LLM(x, retrieved_passages)
+```
+
+Where:
+- `M ∈ ℝ^{d × N}` is a memory bank of passage embeddings (same encoder as vanilla RAG)
+- `q` is the query embedding
+- `β` is an inverse temperature parameter
+- The iterative update is the retrieval dynamics of a **Modern Continuous Hopfield Network** (Ramsauer et al., 2020)
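Concretely, the update loop is only a few lines of PyTorch. A minimal sketch (the name `hag_update` is ours; the real module specified below adds batching, convergence checks, and analysis outputs):

```python
import torch

def hag_update(q: torch.Tensor, M: torch.Tensor, beta: float = 1.0, T: int = 5):
    """Iterate the Hopfield update T times.

    q: (d,) query embedding; M: (d, N) memory bank.
    Returns the refined query q_T and the final attention weights alpha_T.
    """
    for _ in range(T):
        alpha = torch.softmax(beta * (M.T @ q), dim=-1)  # (N,)
        q = M @ alpha                                    # (d,)
    return q, alpha
```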
+
+### Why This Matters
+
+1. **Iterative refinement**: RAG retrieves once. HAG refines the query in embedding space over T steps, enabling implicit multi-hop reasoning.
+2. **Energy-based theory**: Retrieval is energy minimization with convergence guarantees (energy decreases monotonically).
+3. **Differentiable**: Unlike FAISS top-k, the entire retrieval is differentiable — enabling future end-to-end training.
+4. **No structural preprocessing**: Unlike GraphRAG methods (LinearRAG, HippoRAG), no NER, no graph construction, no PageRank. Just an embedding matrix.
+
+### What We Are NOT Doing (Scope Boundaries)
+
+- NOT modifying the LLM architecture (no cross-attention injection)
+- NOT training the memory bank in phase 1 (M is frozen embeddings)
+- NOT building a graph structure
+- NOT doing anything that requires the LLM during retrieval
+- The LLM is frozen and used only for final generation
+
+---
+
+## Repository Structure
+
+```
+hag/
+├── CLAUDE.md # This file
+├── pyproject.toml # Project config, dependencies
+├── README.md # Project README
+├── hag/
+│ ├── __init__.py
+│ ├── config.py # All hyperparameters and configs (dataclass-based)
+│ ├── encoder.py # Wrapper for encoding queries/passages
+│ ├── memory_bank.py # Memory bank construction and management
+│ ├── hopfield.py # Core: Hopfield retrieval module
+│ ├── retriever_faiss.py # Baseline: FAISS top-k retriever
+│ ├── retriever_hopfield.py # HAG retriever (wraps hopfield.py + passage selection)
+│ ├── generator.py # LLM generation wrapper
+│ ├── pipeline.py # End-to-end RAG/HAG pipeline
+│ ├── metrics.py # Evaluation metrics (EM, F1, retrieval recall)
+│ └── energy.py # Energy computation and analysis utilities
+├── scripts/
+│ ├── build_memory_bank.py # Offline: encode corpus → memory bank
+│ ├── run_eval.py # Run evaluation on a dataset
+│ ├── run_baseline.py # Run vanilla RAG baseline
+│ ├── run_hag.py # Run HAG
+│ ├── analyze_energy.py # Analyze energy curves and convergence
+│ └── visualize_trajectory.py # UMAP visualization of query trajectory
+├── tests/
+│ ├── test_hopfield.py # Unit tests for Hopfield module
+│ ├── test_memory_bank.py # Unit tests for memory bank
+│ ├── test_retriever.py # Unit tests for both retrievers
+│ ├── test_pipeline.py # Integration tests for pipeline
+│ ├── test_metrics.py # Unit tests for metrics
+│ └── test_energy.py # Unit tests for energy module
+├── configs/
+│ ├── default.yaml # Default experiment config
+│ ├── hotpotqa.yaml # HotpotQA-specific config
+│ ├── musique.yaml # MuSiQue-specific config
+│ └── 2wikimhqa.yaml # 2WikiMultiHopQA-specific config
+└── notebooks/
+ └── demo.ipynb # Quick demo notebook
+```
+
+---
+
+## Module Specifications
+
+### 1. `hag/config.py`
+
+Use Python dataclasses. All hyperparameters in one place.
+
+```python
+@dataclass
+class HopfieldConfig:
+ beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval
+ max_iter: int = 5 # Maximum Hopfield iteration steps
+ conv_threshold: float = 1e-4 # Convergence: stop if ||q_{t+1} - q_t|| < threshold
+ top_k: int = 5 # Number of passages to retrieve from final attention weights
+
+@dataclass
+class MemoryBankConfig:
+ embedding_dim: int = 768 # Must match encoder output dim
+ normalize: bool = True # L2-normalize embeddings in memory bank
+
+@dataclass
+class EncoderConfig:
+ model_name: str = "facebook/contriever-msmarco" # Default encoder
+ max_length: int = 512
+ batch_size: int = 64
+
+@dataclass
+class GeneratorConfig:
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: int = 128
+ temperature: float = 0.0 # Greedy decoding for reproducibility
+
+@dataclass
+class PipelineConfig:
+ hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
+ memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
+ encoder: EncoderConfig = field(default_factory=EncoderConfig)
+ generator: GeneratorConfig = field(default_factory=GeneratorConfig)
+ retriever_type: str = "hopfield" # "hopfield" or "faiss"
+```
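Because everything is dataclass-based, per-experiment overrides compose cleanly with `dataclasses.replace`. A sketch with a pared-down config (field set trimmed for illustration):

```python
from dataclasses import dataclass, replace

@dataclass
class HopfieldConfig:
    # Pared-down copy of the config above, just for illustration
    beta: float = 1.0
    max_iter: int = 5

base = HopfieldConfig()
sharp = replace(base, beta=5.0)  # new instance; base is left untouched
```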
+
+### 2. `hag/hopfield.py` — THE CORE MODULE
+
+This is the most important file. Implement the Modern Continuous Hopfield Network retrieval.
+
+```python
+class HopfieldRetrieval:
+ """
+ Modern Continuous Hopfield Network for memory retrieval.
+
+ Given memory bank M ∈ ℝ^{d × N} and query q ∈ ℝ^d:
+ 1. Compute attention: α = softmax(β * Mᵀ @ q)
+ 2. Update query: q_new = M @ α
+ 3. Repeat until convergence or max_iter
+
+ The energy function is:
+ E(q) = -1/β * log(Σ_i exp(β * q^T m_i)) + 1/2 * ||q||^2
+
+ Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
+ """
+
+ def __init__(self, config: HopfieldConfig):
+ ...
+
+ def retrieve(
+ self,
+ query: torch.Tensor, # (d,) or (batch, d)
+ memory: torch.Tensor, # (d, N) — the memory bank
+ return_trajectory: bool = False,
+ return_energy: bool = False,
+ ) -> HopfieldResult:
+ """
+ Run iterative Hopfield retrieval.
+
+ Returns HopfieldResult with:
+ - attention_weights: (N,) or (batch, N) — final α_T
+ - converged_query: (d,) or (batch, d) — final q_T
+ - num_steps: int — actual steps taken
+ - trajectory: optional list of q_t at each step
+ - energy_curve: optional list of E(q_t) at each step
+ """
+ ...
+
+ def compute_energy(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ ) -> torch.Tensor:
+ """
+ E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||^2
+
+ This is the log-sum-exp energy of the Modern Hopfield Network.
+ """
+ ...
+```
+
+**CRITICAL implementation details for `retrieve()`:**
+
+```python
+def retrieve(self, query, memory, return_trajectory=False, return_energy=False):
+ # Ensure correct shapes
+ # query: (batch, d) — unsqueeze if needed
+ # memory: (d, N)
+
+ q = query.clone()
+ trajectory = [q.clone()] if return_trajectory else None
+ energies = [self.compute_energy(q, memory)] if return_energy else None
+
+ for t in range(self.config.max_iter):
+ # Core Hopfield update
+        logits = self.config.beta * (q @ memory)  # (batch, N)
+        alpha = torch.softmax(logits, dim=-1)     # (batch, N)
+        q_new = alpha @ memory.T                  # (batch, d)
+
+ # Check convergence
+ delta = torch.norm(q_new - q, dim=-1).max()
+ q = q_new
+
+ if return_trajectory:
+ trajectory.append(q.clone())
+ if return_energy:
+ energies.append(self.compute_energy(q, memory))
+
+ if delta < self.config.conv_threshold:
+ break
+
+ # Final attention weights (recompute to ensure consistency)
+    logits = self.config.beta * (q @ memory)  # (batch, N)
+    alpha = torch.softmax(logits, dim=-1)     # (batch, N)
+
+ return HopfieldResult(
+ attention_weights=alpha,
+ converged_query=q,
+ num_steps=t + 1,
+ trajectory=trajectory,
+ energy_curve=energies,
+ )
+```
+
+**Pay special attention to:**
+- Numerical stability: use `torch.logsumexp` in energy computation
+- Batch support: all operations should work for both single query and batched queries
+- Memory efficiency: don't store trajectory/energy unless explicitly requested
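On the stability point: the energy can be computed entirely through `torch.logsumexp`. A standalone sketch of the formula (the free function form is ours; the module wraps this as a method):

```python
import torch

def compute_energy(q: torch.Tensor, M: torch.Tensor, beta: float = 1.0) -> torch.Tensor:
    """E(q) = -1/beta * logsumexp_i(beta * q^T m_i) + 1/2 ||q||^2.

    q: (batch, d); M: (d, N). A naive log(exp(...).sum()) overflows for
    large beta; torch.logsumexp subtracts the max logit first and stays finite.
    """
    logits = beta * (q @ M)                            # (batch, N)
    return -torch.logsumexp(logits, dim=-1) / beta + 0.5 * (q ** 2).sum(dim=-1)
```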
+
+### 3. `hag/memory_bank.py`
+
+```python
+class MemoryBank:
+ """
+ Stores passage embeddings and provides lookup from indices back to text.
+
+ The memory bank is M ∈ ℝ^{d × N} where each column is a passage embedding.
+ Also maintains a mapping from column index → passage text for final retrieval.
+ """
+
+ def __init__(self, config: MemoryBankConfig):
+        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
+ self.passages: List[str] = [] # N passages
+ self.config = config
+
+ def build_from_embeddings(self, embeddings: torch.Tensor, passages: List[str]):
+ """
+ embeddings: (N, d) — note: input is (N, d), internally stored as (d, N)
+ passages: list of N passage strings
+ """
+ if self.config.normalize:
+ embeddings = F.normalize(embeddings, dim=-1)
+ self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
+ self.passages = passages
+
+ def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+ """Given top-k indices, return corresponding passage texts."""
+ ...
+
+ def save(self, path: str): ...
+ def load(self, path: str): ...
+
+ @property
+ def size(self) -> int:
+ return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+ @property
+ def dim(self) -> int:
+ return self.embeddings.shape[0] if self.embeddings is not None else 0
+```
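The elided methods could be filled in along these lines (a sketch; `MemoryBankSketch` is a minimal stand-in, and the `torch.save` dict layout is an assumption, not a fixed format):

```python
from typing import List, Optional
import torch

class MemoryBankSketch:
    """Minimal stand-in for MemoryBank illustrating the elided methods."""

    def __init__(self) -> None:
        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
        self.passages: List[str] = []

    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
        # Column index -> passage text
        return [self.passages[i] for i in indices.tolist()]

    def save(self, path: str) -> None:
        torch.save({"embeddings": self.embeddings, "passages": self.passages}, path)

    def load(self, path: str) -> None:
        state = torch.load(path)
        self.embeddings = state["embeddings"]
        self.passages = state["passages"]
```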
+
+### 4. `hag/retriever_faiss.py`
+
+Vanilla RAG baseline retriever using FAISS.
+
+```python
+class FAISSRetriever:
+ """
+ Standard top-k retrieval using FAISS inner product search.
+ This is the baseline to compare against.
+ """
+
+ def __init__(self, top_k: int = 5):
+ self.index = None
+ self.passages: List[str] = []
+ self.top_k = top_k
+
+ def build_index(self, embeddings: np.ndarray, passages: List[str]):
+ """Build FAISS IndexFlatIP from (N, d) embeddings."""
+ ...
+
+ def retrieve(self, query: np.ndarray) -> RetrievalResult:
+ """
+ query: (d,) or (batch, d)
+ Returns RetrievalResult with top_k passage texts, scores, and indices.
+ """
+ ...
+```
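`IndexFlatIP` performs exact inner-product top-k. A NumPy sketch of the same computation (the helper name is ours) doubles as a CPU test oracle for the retriever-comparison tests:

```python
import numpy as np

def top_k_inner_product(queries: np.ndarray, corpus: np.ndarray, k: int):
    """Exact inner-product search, matching what faiss.IndexFlatIP returns.

    queries: (B, d); corpus: (N, d) — the same (N, d) layout passed to
    index.add(). Returns (scores, indices), each (B, k), sorted descending.
    """
    scores = queries @ corpus.T                    # (B, N)
    idx = np.argsort(-scores, axis=-1)[:, :k]      # (B, k), best first
    return np.take_along_axis(scores, idx, axis=-1), idx
```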
+
+### 5. `hag/retriever_hopfield.py`
+
+```python
+class HopfieldRetriever:
+ """
+ Wraps HopfieldRetrieval + MemoryBank into a retriever interface.
+
+ The bridge between Hopfield's continuous retrieval and the discrete
+ passage selection needed for LLM prompting.
+ """
+
+ def __init__(self, hopfield: HopfieldRetrieval, memory_bank: MemoryBank, top_k: int = 5):
+ ...
+
+ def retrieve(self, query_embedding: torch.Tensor, return_analysis: bool = False) -> RetrievalResult:
+ """
+ 1. Run Hopfield iterative retrieval → get attention weights α_T
+ 2. Take top_k indices from α_T
+ 3. Look up corresponding passage texts from memory bank
+ 4. Optionally return trajectory and energy for analysis
+
+ Returns RetrievalResult with:
+ - passages: List[str]
+ - scores: attention weights for top-k
+ - indices: which memory slots were selected
+ - (optional) hopfield_result: full HopfieldResult for analysis
+ """
+ ...
+```
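Steps 2–3 reduce to a single `torch.topk` call over the final attention weights; a sketch (`select_passages` is a placeholder name):

```python
from typing import List, Tuple
import torch

def select_passages(
    alpha: torch.Tensor, passages: List[str], top_k: int = 5
) -> Tuple[List[str], torch.Tensor, torch.Tensor]:
    """alpha: (N,) final Hopfield attention weights.

    torch.topk returns values sorted descending, which satisfies the
    sorted-scores contract of RetrievalResult for free.
    """
    scores, indices = torch.topk(alpha, k=top_k)  # each (top_k,)
    return [passages[i] for i in indices.tolist()], scores, indices
```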
+
+### 6. `hag/pipeline.py`
+
+```python
+class RAGPipeline:
+ """
+ End-to-end pipeline: query → encode → retrieve → generate.
+ Supports both FAISS (baseline) and Hopfield (ours) retrieval.
+ """
+
+    def __init__(self, config: PipelineConfig, memory_bank: MemoryBank):
+        self.encoder = Encoder(config.encoder)
+        self.generator = Generator(config.generator)
+        self.memory_bank = memory_bank
+
+        if config.retriever_type == "faiss":
+            self.retriever = FAISSRetriever(top_k=config.hopfield.top_k)
+        elif config.retriever_type == "hopfield":
+            hopfield = HopfieldRetrieval(config.hopfield)
+            self.retriever = HopfieldRetriever(hopfield, memory_bank, top_k=config.hopfield.top_k)
+        else:
+            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")
+
+ def run(self, question: str) -> PipelineResult:
+ """
+ Full pipeline:
+ 1. Encode question → query embedding
+ 2. Retrieve passages (FAISS or Hopfield)
+ 3. Format prompt with question + retrieved passages
+ 4. Generate answer with LLM
+ """
+ ...
+
+ def run_batch(self, questions: List[str]) -> List[PipelineResult]:
+ ...
+```
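Step 3 (prompt formatting) might look like the following; the exact template is an open design choice, not part of the spec:

```python
from typing import List

def format_prompt(question: str, passages: List[str]) -> str:
    """Assemble the generation prompt from the question and retrieved passages."""
    context = "\n\n".join(f"[{i + 1}] {p}" for i, p in enumerate(passages))
    return (
        "Answer the question using only the context below.\n\n"
        f"Context:\n{context}\n\n"
        f"Question: {question}\nAnswer:"
    )
```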
+
+### 7. `hag/metrics.py`
+
+```python
+def exact_match(prediction: str, ground_truth: str) -> float:
+ """Normalized exact match. Lowercase, strip, remove articles/punctuation."""
+ ...
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+ """Token-level F1 between prediction and ground truth."""
+ ...
+
+def retrieval_recall_at_k(retrieved_indices: List[int], gold_indices: List[int], k: int) -> float:
+ """What fraction of gold passages appear in the retrieved top-k?"""
+ ...
+
+def evaluate_dataset(results: List[PipelineResult], gold_answers: List[str]) -> Dict[str, float]:
+ """Compute aggregate metrics over a dataset. Returns dict with EM, F1, etc."""
+ ...
+```
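The answer metrics follow the usual SQuAD-style normalization (lowercase, strip punctuation, drop articles, collapse whitespace). A sketch consistent with the cases in `tests/test_metrics.py` (the exact ordering of normalization steps is a choice):

```python
import re
import string
from collections import Counter
from typing import List

def _normalize(s: str) -> str:
    s = s.lower()
    s = "".join(ch for ch in s if ch not in string.punctuation)
    s = re.sub(r"\b(a|an|the)\b", " ", s)  # drop articles
    return " ".join(s.split())

def exact_match(prediction: str, ground_truth: str) -> float:
    return float(_normalize(prediction) == _normalize(ground_truth))

def f1_score(prediction: str, ground_truth: str) -> float:
    pred, gold = _normalize(prediction).split(), _normalize(ground_truth).split()
    common = Counter(pred) & Counter(gold)  # multiset token overlap
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(gold)
    return 2 * precision * recall / (precision + recall)

def retrieval_recall_at_k(retrieved: List[int], gold: List[int], k: int) -> float:
    return len(set(retrieved[:k]) & set(gold)) / len(gold)
```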
+
+### 8. `hag/energy.py`
+
+```python
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+ """Extract energy values at each iteration step."""
+ ...
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+ """
+ ΔE = E(q_0) - E(q_T)
+ Larger gap = more refinement happened.
+ """
+ ...
+
+def verify_monotonic_decrease(energy_curve: List[float]) -> bool:
+ """Check that E(q_{t+1}) <= E(q_t) for all t. Should always be True."""
+ ...
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+ """
+ H(α) = -Σ_i α_i log α_i
+ Low entropy = sharp retrieval (confident).
+ High entropy = diffuse retrieval (uncertain).
+ """
+ ...
+```
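Sketches of these utilities (the tolerance and the `eps` guard against `log(0)` are our additions):

```python
from typing import List
import torch

def compute_energy_gap(energy_curve: List[float]) -> float:
    # Delta E = E(q_0) - E(q_T); larger gap = more refinement happened
    return energy_curve[0] - energy_curve[-1]

def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
    # E(q_{t+1}) <= E(q_t) up to a small numerical tolerance
    return all(b <= a + tol for a, b in zip(energy_curve, energy_curve[1:]))

def compute_attention_entropy(alpha: torch.Tensor, eps: float = 1e-12) -> float:
    # H(alpha) = -sum_i alpha_i log alpha_i; eps guards one-hot weights
    return float(-(alpha * (alpha + eps).log()).sum())
```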
+
+---
+
+## Data Types
+
+Define these as dataclasses or NamedTuples:
+
+```python
+@dataclass
+class HopfieldResult:
+ attention_weights: torch.Tensor # (batch, N) or (N,)
+ converged_query: torch.Tensor # (batch, d) or (d,)
+ num_steps: int
+ trajectory: Optional[List[torch.Tensor]] = None # list of q_t
+ energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t)
+
+@dataclass
+class RetrievalResult:
+ passages: List[str]
+ scores: torch.Tensor # top-k scores
+ indices: torch.Tensor # top-k indices
+ hopfield_result: Optional[HopfieldResult] = None
+
+@dataclass
+class PipelineResult:
+ question: str
+ answer: str
+ retrieved_passages: List[str]
+ retrieval_result: RetrievalResult
+```
+
+---
+
+## Test Specifications
+
+**ALL TESTS MUST RUN ON CPU. NO GPU REQUIRED.**
+
+Use small synthetic data (random tensors) for unit tests. Do not download any models or datasets in tests.
+
+### `tests/test_hopfield.py`
+
+```python
+class TestHopfieldRetrieval:
+ """Test the core Hopfield module with synthetic data on CPU."""
+
+ def setup_method(self):
+ """Create small synthetic memory bank and queries."""
+ torch.manual_seed(42)
+ self.d = 64 # embedding dim
+ self.N = 100 # number of memories
+ self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N)
+ self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d)
+ self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
+ self.hopfield = HopfieldRetrieval(self.config)
+
+ def test_output_shapes(self):
+ """attention_weights should be (1, N), converged_query should be (1, d)."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert result.attention_weights.shape == (1, self.N)
+ assert result.converged_query.shape == (1, self.d)
+
+ def test_attention_weights_sum_to_one(self):
+ """softmax output must sum to 1."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert torch.allclose(result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5)
+
+ def test_attention_weights_non_negative(self):
+ """All attention weights must be >= 0."""
+ result = self.hopfield.retrieve(self.query, self.memory)
+ assert (result.attention_weights >= 0).all()
+
+ def test_energy_monotonic_decrease(self):
+ """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
+ result = self.hopfield.retrieve(self.query, self.memory, return_energy=True)
+ energies = [e.item() for e in result.energy_curve]
+ for i in range(len(energies) - 1):
+ assert energies[i+1] <= energies[i] + 1e-6, \
+ f"Energy increased at step {i}: {energies[i]} -> {energies[i+1]}"
+
+ def test_convergence(self):
+ """With enough iterations, query should converge (delta < threshold)."""
+ config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(self.query, self.memory, return_trajectory=True)
+ # Final two queries should be very close
+ q_last = result.trajectory[-1]
+ q_prev = result.trajectory[-2]
+ delta = torch.norm(q_last - q_prev)
+ assert delta < 1e-4, f"Did not converge: delta={delta}"
+
+ def test_high_beta_sharp_retrieval(self):
+ """Higher β should produce sharper (lower entropy) attention."""
+ low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
+ high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))
+
+ result_low = low_beta.retrieve(self.query, self.memory)
+ result_high = high_beta.retrieve(self.query, self.memory)
+
+ entropy_low = -(result_low.attention_weights * result_low.attention_weights.log()).sum()
+ entropy_high = -(result_high.attention_weights * result_high.attention_weights.log()).sum()
+
+ assert entropy_high < entropy_low, "Higher β should give lower entropy"
+
+ def test_single_memory_converges_to_it(self):
+ """With N=1, retrieval should converge to the single memory."""
+ single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
+ result = self.hopfield.retrieve(self.query, single_memory)
+ assert torch.allclose(result.attention_weights, torch.ones(1, 1), atol=1e-5)
+
+ def test_query_near_memory_retrieves_it(self):
+ """If query ≈ memory_i, attention should peak at index i."""
+ target_idx = 42
+ query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match
+ config = HopfieldConfig(beta=10.0, max_iter=5)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(query, self.memory)
+ top_idx = result.attention_weights.argmax(dim=-1).item()
+ assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"
+
+ def test_batch_retrieval(self):
+ """Should handle batch of queries."""
+ batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
+ result = self.hopfield.retrieve(batch_query, self.memory)
+ assert result.attention_weights.shape == (8, self.N)
+ assert result.converged_query.shape == (8, self.d)
+
+ def test_iteration_refines_query(self):
+ """
+ Multi-hop test: query starts far from target, iteration should bring it closer.
+
+ Setup: memory has two clusters. Query is near cluster A but the "answer"
+ is in cluster B, reachable through an intermediate memory that bridges both.
+ After iteration, the query should drift toward cluster B.
+ """
+ torch.manual_seed(0)
+ d = 32
+
+ # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
+ # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
+ # Bridge memory 10: between A and B
+ center_a = torch.zeros(d); center_a[0] = 1.0
+ center_b = torch.zeros(d); center_b[1] = 1.0
+ bridge = F.normalize(center_a + center_b, dim=0)
+
+ memories = []
+ for _ in range(5):
+ memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
+ for _ in range(5):
+ memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
+ memories.append(bridge)
+
+ M = torch.stack(memories, dim=1) # (d, 11)
+ q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)
+
+ config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
+ hopfield = HopfieldRetrieval(config)
+ result = hopfield.retrieve(q0, M, return_trajectory=True)
+
+ # After iteration, query should have drifted: its dot product with center_b
+ # should be higher than the initial query's dot product with center_b
+ initial_sim_b = (q0.squeeze() @ center_b).item()
+ final_sim_b = (result.converged_query.squeeze() @ center_b).item()
+ assert final_sim_b > initial_sim_b, \
+ f"Iteration should pull query toward cluster B: initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
+```
+
+### `tests/test_memory_bank.py`
+
+```python
+class TestMemoryBank:
+
+ def test_build_and_size(self):
+ """Build memory bank and verify size."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(100, 64)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(embeddings, passages)
+ assert mb.size == 100
+ assert mb.dim == 64
+
+ def test_normalization(self):
+ """If normalize=True, stored embeddings should have unit norm."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(50, 64) * 5 # non-unit norm
+ mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
+ norms = torch.norm(mb.embeddings, dim=0)
+ assert torch.allclose(norms, torch.ones(50), atol=1e-5)
+
+ def test_get_passages_by_indices(self):
+ """Index → passage text lookup."""
+ config = MemoryBankConfig(embedding_dim=64, normalize=False)
+ mb = MemoryBank(config)
+ passages = [f"passage {i}" for i in range(100)]
+ mb.build_from_embeddings(torch.randn(100, 64), passages)
+ result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
+ assert result == ["passage 0", "passage 50", "passage 99"]
+
+ def test_save_and_load(self, tmp_path):
+ """Save and reload memory bank, verify contents match."""
+ config = MemoryBankConfig(embedding_dim=32, normalize=True)
+ mb = MemoryBank(config)
+ embeddings = torch.randn(20, 32)
+ passages = [f"text {i}" for i in range(20)]
+ mb.build_from_embeddings(embeddings, passages)
+
+ save_path = str(tmp_path / "mb.pt")
+ mb.save(save_path)
+
+ mb2 = MemoryBank(config)
+ mb2.load(save_path)
+ assert mb2.size == 20
+ assert mb2.passages == passages
+ assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)
+```
+
+### `tests/test_retriever.py`
+
+```python
+class TestHopfieldRetriever:
+
+ def test_retrieves_correct_number_of_passages(self):
+ """top_k=3 should return exactly 3 passages."""
+ ...
+
+ def test_scores_are_sorted_descending(self):
+ """Returned scores should be in descending order."""
+ ...
+
+ def test_passages_match_indices(self):
+ """Returned passage texts should correspond to returned indices."""
+ ...
+
+class TestFAISSRetriever:
+
+ def test_retrieves_correct_number(self):
+ ...
+
+ def test_nearest_neighbor_is_self(self):
+ """If query = passage_i's embedding, top-1 should be passage_i."""
+ ...
+
+class TestRetrieverComparison:
+
+ def test_same_top1_for_obvious_query(self):
+ """
+ When query is very close to one memory, both FAISS and Hopfield
+ should agree on the top-1 result.
+ """
+ ...
+```
+
+### `tests/test_metrics.py`
+
+```python
+class TestMetrics:
+
+ def test_exact_match_basic(self):
+ assert exact_match("Paris", "paris") == 1.0
+ assert exact_match("Paris", "London") == 0.0
+
+ def test_exact_match_normalization(self):
+ """Should strip whitespace, lowercase, remove articles."""
+ assert exact_match(" The Paris ", "paris") == 1.0
+ assert exact_match("A dog", "dog") == 1.0
+
+ def test_f1_score_perfect(self):
+ assert f1_score("the cat sat", "the cat sat") == 1.0
+
+ def test_f1_score_partial(self):
+ score = f1_score("the cat sat on the mat", "the cat sat")
+ assert 0.5 < score < 1.0
+
+ def test_f1_score_no_overlap(self):
+ assert f1_score("hello world", "foo bar") == 0.0
+
+ def test_retrieval_recall(self):
+ assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2/3
+```
+
+### `tests/test_energy.py`
+
+```python
+class TestEnergy:
+
+ def test_energy_computation_matches_formula(self):
+ """
+ Manually compute E(q) = -1/β * log(Σ exp(β * qᵀ m_i)) + 1/2 * ||q||^2
+ and verify it matches the module output.
+ """
+ torch.manual_seed(42)
+ d, N = 16, 10
+ memory = torch.randn(d, N)
+ query = torch.randn(1, d)
+ beta = 2.0
+
+ config = HopfieldConfig(beta=beta)
+ hopfield = HopfieldRetrieval(config)
+ computed = hopfield.compute_energy(query, memory)
+
+ # Manual computation
+ logits = beta * (memory.T @ query.T).T # (1, N)
+ expected = -1.0/beta * torch.logsumexp(logits, dim=-1) + 0.5 * (query ** 2).sum(dim=-1)
+ assert torch.allclose(computed, expected, atol=1e-5)
+
+ def test_verify_monotonic(self):
+ energies = [10.0, 8.0, 6.5, 6.4, 6.39]
+        assert verify_monotonic_decrease(energies)
+
+ def test_verify_monotonic_fails(self):
+ energies = [10.0, 8.0, 9.0, 6.0]
+        assert not verify_monotonic_decrease(energies)
+
+ def test_attention_entropy(self):
+ """Uniform attention should have max entropy, one-hot should have zero."""
+ uniform = torch.ones(100) / 100
+ one_hot = torch.zeros(100); one_hot[0] = 1.0
+ assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot)
+ assert compute_attention_entropy(one_hot) < 0.01
+```
+
+### `tests/test_pipeline.py`
+
+```python
+class TestPipeline:
+ """
+ Integration tests using MOCK encoder and generator.
+ These test the wiring, not the model quality.
+ """
+
+ def test_pipeline_runs_end_to_end_hopfield(self):
+ """Mock encoder + mock LLM + Hopfield retriever → produces an answer string."""
+ ...
+
+ def test_pipeline_runs_end_to_end_faiss(self):
+ """Same but with FAISS retriever."""
+ ...
+
+ def test_pipeline_retrieval_result_included(self):
+ """PipelineResult should contain retrieval metadata."""
+ ...
+```
+
+**For pipeline tests, use mock/fake encoder and generator:**
+```python
+class FakeEncoder:
+    def encode(self, text: str) -> torch.Tensor:
+        # Deterministic hash-based embedding for testing.
+        # zlib.crc32 is stable across runs; the built-in hash() is salted
+        # per process, which would make tests flaky.
+        torch.manual_seed(zlib.crc32(text.encode("utf-8")))
+        return F.normalize(torch.randn(1, 64), dim=-1)
+
+class FakeGenerator:
+ def generate(self, prompt: str) -> str:
+ return "mock answer"
+```
+
+---
+
+## Dependencies
+
+```toml
+[project]
+name = "hag"
+version = "0.1.0"
+requires-python = ">=3.10"
+
+dependencies = [
+ "torch>=2.0",
+ "numpy>=1.24",
+ "faiss-cpu>=1.7",
+ "transformers>=4.36",
+ "datasets>=2.14",
+ "pyyaml>=6.0",
+ "tqdm>=4.65",
+]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=7.0",
+ "pytest-cov>=4.0",
+]
+eval = [
+ "scikit-learn>=1.3",
+ "umap-learn>=0.5",
+ "matplotlib>=3.7",
+]
+```
+
+---
+
+## Implementation Order
+
+**Phase 1: Core module (do this first)**
+1. `config.py` — all dataclasses
+2. `hopfield.py` — the core Hopfield retrieval
+3. `energy.py` — energy utilities
+4. `tests/test_hopfield.py` — run and pass all tests
+5. `tests/test_energy.py` — run and pass all tests
+
+**Phase 2: Memory and retrievers**
+6. `memory_bank.py`
+7. `retriever_hopfield.py`
+8. `retriever_faiss.py`
+9. `tests/test_memory_bank.py`
+10. `tests/test_retriever.py`
+
+**Phase 3: Metrics and pipeline**
+11. `metrics.py`
+12. `pipeline.py` (with mock encoder/generator for now)
+13. `tests/test_metrics.py`
+14. `tests/test_pipeline.py`
+
+**Phase 4: Scripts (need GPU/models — implement structure only)**
+15. `scripts/build_memory_bank.py`
+16. `scripts/run_eval.py`
+17. `scripts/run_baseline.py`
+18. `scripts/run_hag.py`
+
+---
+
+## Coding Conventions
+
+- **Type hints everywhere.** Every function signature must have full type hints.
+- **Docstrings on every public method.** Include shapes in docstrings for tensor arguments.
+- **Tensor shapes in comments.** Any line with a tensor operation should have a shape comment: `# (batch, N)`
+- **No global state.** Everything parameterized through config dataclasses.
+- **CPU-first.** Default device is CPU. GPU is opt-in via `.to(device)`.
+- **Deterministic tests.** Always set `torch.manual_seed()` in tests.
+- **Use `torch.no_grad()` in retrieval** when not training (phase 1).
+- Use `logging` module, not print statements.
+
+---
+
+## Key Mathematical References
+
+The Modern Continuous Hopfield Network update rule and energy function come from:
+
+> Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Key equations:
+- **Update**: `q_{t+1} = M · softmax(β · Mᵀ · qₜ)`
+- **Energy**: `E(q) = -1/β · log(Σᵢ exp(β · qᵀmᵢ)) + 1/2 · ||q||²`
+- **Property**: Energy decreases monotonically: `E(q_{t+1}) ≤ E(q_t)`
+- **Property**: Fixed points are local minima of E
+- **Capacity**: Exponential in d (vs. the classical Hopfield limit of ~0.14d patterns for d neurons)
+
+The connection to attention: the Hopfield update is mathematically equivalent to one step of cross-attention where queries attend to stored memories. But the ITERATIVE application (multiple steps) and the energy-based analysis are what differentiate this from standard attention.
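The one-step equivalence is easy to check numerically; in this sketch `attention` is our own helper, not a library call:

```python
import torch

def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, scale: float) -> torch.Tensor:
    """Plain cross-attention: softmax(scale * Q K^T) V."""
    return torch.softmax(scale * Q @ K.T, dim=-1) @ V

torch.manual_seed(0)
d, N, beta = 16, 32, 2.0
M = torch.randn(d, N)  # memory bank, columns are memories
q = torch.randn(1, d)  # one query, as a row vector

# One Hopfield update step, written out exactly as in the update rule
hopfield_step = (M @ torch.softmax(beta * M.T @ q.T, dim=0)).T  # (1, d)

# The same step as cross-attention with Q = q, K = V = M^T, scale = beta
attn_step = attention(q, M.T, M.T, beta)                        # (1, d)

assert torch.allclose(hopfield_step, attn_step, atol=1e-6)
```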
+
+---
+
+## What NOT To Do
+
+- Do NOT use `torch.cuda` in any test
+- Do NOT download models/datasets in tests — use synthetic data and mocks
+- Do NOT implement the encoder or generator with real models yet — provide the interface and use mocks in tests
+- Do NOT add wandb, tensorboard, or other logging frameworks yet
+- Do NOT try to optimize for speed yet — correctness first
+- Do NOT implement the training loop (memory bank training) — that's phase 2 of the research, not this codebase yet