| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /CLAUDE.md | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'CLAUDE.md')
| -rw-r--r-- | CLAUDE.md | 870 |
1 file changed, 870 insertions, 0 deletions
diff --git a/CLAUDE.md b/CLAUDE.md
new file mode 100644
index 0000000..c99b846
--- /dev/null
+++ b/CLAUDE.md
@@ -0,0 +1,870 @@

# CLAUDE.md — Hopfield-Augmented Generation (HAG)

## Project Overview

We are building **HAG (Hopfield-Augmented Generation)**, a retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**. This is a research project targeting a top ML venue (ICML/NeurIPS).

### Core Idea

Standard RAG does:

```
y = LLM(x, top_k_passages(FAISS(Enc(x))))
```

HAG does:

```
q₀ = Enc(x)
qₜ₊₁ = M · softmax(β · Mᵀ · qₜ)   # iterate T steps
α_T = softmax(β · Mᵀ · q_T)       # final attention weights
retrieved_passages = top_k(α_T)   # use weights to select passages
y = LLM(x, retrieved_passages)
```

Where:

- `M ∈ ℝ^{d × N}` is a memory bank of passage embeddings (same encoder as vanilla RAG)
- `q` is the query embedding
- `β` is an inverse temperature parameter
- The iterative update is the retrieval dynamics of a **Modern Continuous Hopfield Network** (Ramsauer et al., 2020)

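To make the pseudocode concrete, here is a minimal, self-contained sketch of these dynamics on random data. It is illustrative only; the real module is specified in `hag/hopfield.py` below, and the sizes, `beta`, and `T` here are arbitrary.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, N, beta, T = 64, 1000, 2.0, 5

M = F.normalize(torch.randn(d, N), dim=0)  # memory bank: one unit-norm passage embedding per column
q = F.normalize(torch.randn(d), dim=0)     # q0 = Enc(x), also unit norm

for _ in range(T):
    alpha = torch.softmax(beta * (M.T @ q), dim=-1)  # (N,) attention over memories
    q = M @ alpha                                    # refined query for the next step

retrieved = torch.topk(alpha, k=5).indices  # passage indices to hand to the LLM
```
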
### Why This Matters

1. **Iterative refinement**: RAG retrieves once. HAG refines the query in embedding space over T steps, enabling implicit multi-hop reasoning.
2. **Energy-based theory**: Retrieval is energy minimization with convergence guarantees (energy decreases monotonically).
3. **Differentiable**: Unlike FAISS top-k, the entire retrieval is differentiable — enabling future end-to-end training.
4. **No structural preprocessing**: Unlike GraphRAG methods (LinearRAG, HippoRAG), no NER, no graph construction, no PageRank. Just an embedding matrix.

### What We Are NOT Doing (Scope Boundaries)

- NOT modifying the LLM architecture (no cross-attention injection)
- NOT training the memory bank in phase 1 (M is frozen embeddings)
- NOT building a graph structure
- NOT doing anything that requires the LLM during retrieval
- The LLM is frozen and used only for final generation

---

## Repository Structure

```
hag/
├── CLAUDE.md                   # This file
├── pyproject.toml              # Project config, dependencies
├── README.md                   # Project README
├── hag/
│   ├── __init__.py
│   ├── config.py               # All hyperparameters and configs (dataclass-based)
│   ├── encoder.py              # Wrapper for encoding queries/passages
│   ├── memory_bank.py          # Memory bank construction and management
│   ├── hopfield.py             # Core: Hopfield retrieval module
│   ├── retriever_faiss.py      # Baseline: FAISS top-k retriever
│   ├── retriever_hopfield.py   # HAG retriever (wraps hopfield.py + passage selection)
│   ├── generator.py            # LLM generation wrapper
│   ├── pipeline.py             # End-to-end RAG/HAG pipeline
│   ├── metrics.py              # Evaluation metrics (EM, F1, retrieval recall)
│   └── energy.py               # Energy computation and analysis utilities
├── scripts/
│   ├── build_memory_bank.py    # Offline: encode corpus → memory bank
│   ├── run_eval.py             # Run evaluation on a dataset
│   ├── run_baseline.py         # Run vanilla RAG baseline
│   ├── run_hag.py              # Run HAG
│   ├── analyze_energy.py       # Analyze energy curves and convergence
│   └── visualize_trajectory.py # UMAP visualization of query trajectory
├── tests/
│   ├── test_hopfield.py        # Unit tests for Hopfield module
│   ├── test_memory_bank.py     # Unit tests for memory bank
│   ├── test_retriever.py       # Unit tests for both retrievers
│   ├── test_pipeline.py        # Integration tests for pipeline
│   ├── test_metrics.py         # Unit tests for metrics
│   └── test_energy.py          # Unit tests for energy module
├── configs/
│   ├── default.yaml            # Default experiment config
│   ├── hotpotqa.yaml           # HotpotQA-specific config
│   ├── musique.yaml            # MuSiQue-specific config
│   └── 2wikimhqa.yaml          # 2WikiMultiHopQA-specific config
└── notebooks/
    └── demo.ipynb              # Quick demo notebook
```

---

## Module Specifications

### 1. `hag/config.py`

Use Python dataclasses. All hyperparameters in one place.

```python
@dataclass
class HopfieldConfig:
    beta: float = 1.0             # Inverse temperature. Higher = sharper retrieval
    max_iter: int = 5             # Maximum Hopfield iteration steps
    conv_threshold: float = 1e-4  # Convergence: stop if ||q_{t+1} - q_t|| < threshold
    top_k: int = 5                # Number of passages to retrieve from final attention weights

@dataclass
class MemoryBankConfig:
    embedding_dim: int = 768      # Must match encoder output dim
    normalize: bool = True        # L2-normalize embeddings in memory bank

@dataclass
class EncoderConfig:
    model_name: str = "facebook/contriever-msmarco"  # Default encoder
    max_length: int = 512
    batch_size: int = 64

@dataclass
class GeneratorConfig:
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    max_new_tokens: int = 128
    temperature: float = 0.0      # Greedy decoding for reproducibility

@dataclass
class PipelineConfig:
    hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
    memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
    encoder: EncoderConfig = field(default_factory=EncoderConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    retriever_type: str = "hopfield"  # "hopfield" or "faiss"
```

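For reference, composing a run configuration in code might look like the following. How the YAML files in `configs/` map onto these dataclasses is left unspecified here, so this sketch shows direct construction only:

```python
from hag.config import HopfieldConfig, PipelineConfig

# Override only what differs from the defaults; nested configs keep theirs.
config = PipelineConfig(
    hopfield=HopfieldConfig(beta=2.0, max_iter=10, top_k=5),
    retriever_type="hopfield",
)
assert config.encoder.model_name == "facebook/contriever-msmarco"
```
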
### 2. `hag/hopfield.py` — THE CORE MODULE

This is the most important file. Implement the Modern Continuous Hopfield Network retrieval.

```python
class HopfieldRetrieval:
    """
    Modern Continuous Hopfield Network for memory retrieval.

    Given memory bank M ∈ ℝ^{d × N} and query q ∈ ℝ^d:
      1. Compute attention: α = softmax(β * Mᵀ @ q)
      2. Update query: q_new = M @ α
      3. Repeat until convergence or max_iter

    The energy function is:
      E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||²

    Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
    """

    def __init__(self, config: HopfieldConfig):
        ...

    def retrieve(
        self,
        query: torch.Tensor,    # (d,) or (batch, d)
        memory: torch.Tensor,   # (d, N) — the memory bank
        return_trajectory: bool = False,
        return_energy: bool = False,
    ) -> HopfieldResult:
        """
        Run iterative Hopfield retrieval.

        Returns HopfieldResult with:
          - attention_weights: (N,) or (batch, N) — final α_T
          - converged_query: (d,) or (batch, d) — final q_T
          - num_steps: int — actual steps taken
          - trajectory: optional list of q_t at each step
          - energy_curve: optional list of E(q_t) at each step
        """
        ...

    def compute_energy(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        """
        E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||²

        This is the log-sum-exp energy of the Modern Hopfield Network.
        """
        ...
```

**CRITICAL implementation details for `retrieve()`:**

```python
def retrieve(self, query, memory, return_trajectory=False, return_energy=False):
    # Ensure correct shapes
    # query: (batch, d) — unsqueeze if needed
    # memory: (d, N)

    q = query.clone()
    trajectory = [q.clone()] if return_trajectory else None
    energies = [self.compute_energy(q, memory)] if return_energy else None

    for t in range(self.config.max_iter):
        # Core Hopfield update
        logits = self.config.beta * (memory.T @ q.T).T  # (batch, N)
        alpha = torch.softmax(logits, dim=-1)           # (batch, N)
        q_new = (memory @ alpha.T).T                    # (batch, d)

        # Check convergence
        delta = torch.norm(q_new - q, dim=-1).max()
        q = q_new

        if return_trajectory:
            trajectory.append(q.clone())
        if return_energy:
            energies.append(self.compute_energy(q, memory))

        if delta < self.config.conv_threshold:
            break

    # Final attention weights (recompute to ensure consistency)
    logits = self.config.beta * (memory.T @ q.T).T
    alpha = torch.softmax(logits, dim=-1)

    return HopfieldResult(
        attention_weights=alpha,
        converged_query=q,
        num_steps=t + 1,
        trajectory=trajectory,
        energy_curve=energies,
    )
```

**Pay special attention to:**

- Numerical stability: use `torch.logsumexp` in energy computation
- Batch support: all operations should work for both single query and batched queries
- Memory efficiency: don't store trajectory/energy unless explicitly requested

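As a reference for the stability point, a sketch of `compute_energy` consistent with the formula above (shapes assume a batched `(batch, d)` query; the actual implementation may differ):

```python
def compute_energy(self, query: torch.Tensor, memory: torch.Tensor) -> torch.Tensor:
    """E(q) = -1/β * logsumexp(β * Mᵀq) + 1/2 * ||q||².  query: (batch, d), memory: (d, N)."""
    beta = self.config.beta
    logits = beta * (query @ memory)           # (batch, N): β * qᵀ m_i for every memory
    lse = torch.logsumexp(logits, dim=-1)      # (batch,): numerically stable log Σ exp
    return -lse / beta + 0.5 * (query ** 2).sum(dim=-1)  # (batch,)
```
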
### 3. `hag/memory_bank.py`

```python
class MemoryBank:
    """
    Stores passage embeddings and provides lookup from indices back to text.

    The memory bank is M ∈ ℝ^{d × N} where each column is a passage embedding.
    Also maintains a mapping from column index → passage text for final retrieval.
    """

    def __init__(self, config: MemoryBankConfig):
        self.embeddings: torch.Tensor = None  # (d, N)
        self.passages: List[str] = []         # N passages
        self.config = config

    def build_from_embeddings(self, embeddings: torch.Tensor, passages: List[str]):
        """
        embeddings: (N, d) — note: input is (N, d), internally stored as (d, N)
        passages: list of N passage strings
        """
        if self.config.normalize:
            embeddings = F.normalize(embeddings, dim=-1)
        self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
        self.passages = passages

    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
        """Given top-k indices, return corresponding passage texts."""
        ...

    def save(self, path: str): ...
    def load(self, path: str): ...

    @property
    def size(self) -> int:
        return self.embeddings.shape[1] if self.embeddings is not None else 0

    @property
    def dim(self) -> int:
        return self.embeddings.shape[0] if self.embeddings is not None else 0
```

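The `...` stubs could be filled in roughly as follows; this sketch assumes `torch.save` on a plain dict, since the spec does not pin down a serialization format:

```python
def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
    """Given top-k indices, return corresponding passage texts."""
    return [self.passages[i] for i in indices.tolist()]

def save(self, path: str) -> None:
    torch.save({"embeddings": self.embeddings, "passages": self.passages}, path)

def load(self, path: str) -> None:
    data = torch.load(path)
    self.embeddings = data["embeddings"]  # (d, N), already transposed/normalized
    self.passages = data["passages"]
```
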
### 4. `hag/retriever_faiss.py`

Vanilla RAG baseline retriever using FAISS.

```python
class FAISSRetriever:
    """
    Standard top-k retrieval using FAISS inner product search.
    This is the baseline to compare against.
    """

    def __init__(self, top_k: int = 5):
        self.index = None
        self.passages: List[str] = []
        self.top_k = top_k

    def build_index(self, embeddings: np.ndarray, passages: List[str]):
        """Build FAISS IndexFlatIP from (N, d) embeddings."""
        ...

    def retrieve(self, query: np.ndarray) -> RetrievalResult:
        """
        query: (d,) or (batch, d)
        Returns RetrievalResult with top_k passage texts, scores, and indices.
        """
        ...
```

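A sketch of the two stubs using `faiss.IndexFlatIP` (exact inner-product search, so with L2-normalized embeddings this is cosine similarity). FAISS expects `float32` arrays; the single-query return shown here is a simplification:

```python
import faiss
import numpy as np

def build_index(self, embeddings: np.ndarray, passages: List[str]):
    """Build FAISS IndexFlatIP from (N, d) embeddings."""
    self.index = faiss.IndexFlatIP(embeddings.shape[1])
    self.index.add(embeddings.astype(np.float32))  # (N, d)
    self.passages = passages

def retrieve(self, query: np.ndarray) -> RetrievalResult:
    q = query.reshape(1, -1) if query.ndim == 1 else query      # (batch, d)
    scores, indices = self.index.search(q.astype(np.float32), self.top_k)  # both (batch, top_k)
    # Single-query result shown; batching is left out for brevity.
    return RetrievalResult(
        passages=[self.passages[i] for i in indices[0]],
        scores=torch.from_numpy(scores[0]),
        indices=torch.from_numpy(indices[0]),
    )
```
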
### 5. `hag/retriever_hopfield.py`

```python
class HopfieldRetriever:
    """
    Wraps HopfieldRetrieval + MemoryBank into a retriever interface.

    The bridge between Hopfield's continuous retrieval and the discrete
    passage selection needed for LLM prompting.
    """

    def __init__(self, hopfield: HopfieldRetrieval, memory_bank: MemoryBank, top_k: int = 5):
        ...

    def retrieve(self, query_embedding: torch.Tensor, return_analysis: bool = False) -> RetrievalResult:
        """
        1. Run Hopfield iterative retrieval → get attention weights α_T
        2. Take top_k indices from α_T
        3. Look up corresponding passage texts from memory bank
        4. Optionally return trajectory and energy for analysis

        Returns RetrievalResult with:
          - passages: List[str]
          - scores: attention weights for top-k
          - indices: which memory slots were selected
          - (optional) hopfield_result: full HopfieldResult for analysis
        """
        ...
```

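The continuous-to-discrete bridge might look like this sketch: run the dynamics against the memory bank, then take `torch.topk` over α_T (which also yields scores in descending order):

```python
def retrieve(self, query_embedding: torch.Tensor, return_analysis: bool = False) -> RetrievalResult:
    result = self.hopfield.retrieve(
        query_embedding,
        self.memory_bank.embeddings,        # (d, N)
        return_trajectory=return_analysis,
        return_energy=return_analysis,
    )
    alpha = result.attention_weights.squeeze(0)         # (N,): single-query case for brevity
    scores, indices = torch.topk(alpha, k=self.top_k)   # descending by construction
    return RetrievalResult(
        passages=self.memory_bank.get_passages_by_indices(indices),
        scores=scores,
        indices=indices,
        hopfield_result=result if return_analysis else None,
    )
```
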
### 6. `hag/pipeline.py`

```python
class RAGPipeline:
    """
    End-to-end pipeline: query → encode → retrieve → generate.
    Supports both FAISS (baseline) and Hopfield (ours) retrieval.
    """

    def __init__(self, config: PipelineConfig, memory_bank: MemoryBank):
        # memory_bank is built offline (scripts/build_memory_bank.py) and injected here.
        self.encoder = Encoder(config.encoder)
        self.generator = Generator(config.generator)

        if config.retriever_type == "faiss":
            self.retriever = FAISSRetriever(top_k=config.hopfield.top_k)
        elif config.retriever_type == "hopfield":
            hopfield = HopfieldRetrieval(config.hopfield)
            self.retriever = HopfieldRetriever(hopfield, memory_bank, top_k=config.hopfield.top_k)

    def run(self, question: str) -> PipelineResult:
        """
        Full pipeline:
          1. Encode question → query embedding
          2. Retrieve passages (FAISS or Hopfield)
          3. Format prompt with question + retrieved passages
          4. Generate answer with LLM
        """
        ...

    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
        ...
```

### 7. `hag/metrics.py`

```python
def exact_match(prediction: str, ground_truth: str) -> float:
    """Normalized exact match. Lowercase, strip, remove articles/punctuation."""
    ...

def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between prediction and ground truth."""
    ...

def retrieval_recall_at_k(retrieved_indices: List[int], gold_indices: List[int], k: int) -> float:
    """What fraction of gold passages appear in the retrieved top-k?"""
    ...

def evaluate_dataset(results: List[PipelineResult], gold_answers: List[str]) -> Dict[str, float]:
    """Compute aggregate metrics over a dataset. Returns dict with EM, F1, etc."""
    ...
```

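The EM/F1 definitions above follow the usual SQuAD-style recipe. A sketch (the `_normalize` helper is ours, not part of the spec):

```python
import re
import string
from collections import Counter

def _normalize(text: str) -> str:
    text = text.lower()
    text = re.sub(r"\b(a|an|the)\b", " ", text)                    # drop articles
    text = "".join(ch for ch in text if ch not in string.punctuation)
    return " ".join(text.split())                                  # collapse whitespace

def exact_match(prediction: str, ground_truth: str) -> float:
    return float(_normalize(prediction) == _normalize(ground_truth))

def f1_score(prediction: str, ground_truth: str) -> float:
    pred, gold = _normalize(prediction).split(), _normalize(ground_truth).split()
    common = Counter(pred) & Counter(gold)                         # multiset intersection
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision, recall = overlap / len(pred), overlap / len(gold)
    return 2 * precision * recall / (precision + recall)
```
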
### 8. `hag/energy.py`

```python
def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
    """Extract energy values at each iteration step."""
    ...

def compute_energy_gap(energy_curve: List[float]) -> float:
    """
    ΔE = E(q_0) - E(q_T)
    Larger gap = more refinement happened.
    """
    ...

def verify_monotonic_decrease(energy_curve: List[float]) -> bool:
    """Check that E(q_{t+1}) <= E(q_t) for all t. Should always be True."""
    ...

def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """
    H(α) = -Σ_i α_i log α_i
    Low entropy = sharp retrieval (confident).
    High entropy = diffuse retrieval (uncertain).
    """
    ...
```

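These utilities are one-liners; a sketch, assuming `energy_curve` holds plain floats and allowing a small numerical tolerance in the monotonicity check:

```python
from typing import List

import torch

def compute_energy_gap(energy_curve: List[float]) -> float:
    return energy_curve[0] - energy_curve[-1]        # ΔE = E(q_0) - E(q_T)

def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
    return all(b <= a + tol for a, b in zip(energy_curve, energy_curve[1:]))

def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    alpha = attention_weights.clamp_min(1e-12)       # avoid log(0) for one-hot α
    return float(-(alpha * alpha.log()).sum())
```
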
---

## Data Types

Define these as dataclasses or NamedTuples:

```python
@dataclass
class HopfieldResult:
    attention_weights: torch.Tensor   # (batch, N) or (N,)
    converged_query: torch.Tensor     # (batch, d) or (d,)
    num_steps: int
    trajectory: Optional[List[torch.Tensor]] = None    # list of q_t
    energy_curve: Optional[List[torch.Tensor]] = None  # list of E(q_t)

@dataclass
class RetrievalResult:
    passages: List[str]
    scores: torch.Tensor    # top-k scores
    indices: torch.Tensor   # top-k indices
    hopfield_result: Optional[HopfieldResult] = None

@dataclass
class PipelineResult:
    question: str
    answer: str
    retrieved_passages: List[str]
    retrieval_result: RetrievalResult
```

---

## Test Specifications

**ALL TESTS MUST RUN ON CPU. NO GPU REQUIRED.**

Use small synthetic data (random tensors) for unit tests. Do not download any models or datasets in tests.

### `tests/test_hopfield.py`

```python
class TestHopfieldRetrieval:
    """Test the core Hopfield module with synthetic data on CPU."""

    def setup_method(self):
        """Create small synthetic memory bank and queries."""
        torch.manual_seed(42)
        self.d = 64   # embedding dim
        self.N = 100  # number of memories
        self.memory = F.normalize(torch.randn(self.d, self.N), dim=0)  # (d, N)
        self.query = F.normalize(torch.randn(1, self.d), dim=-1)       # (1, d)
        self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
        self.hopfield = HopfieldRetrieval(self.config)

    def test_output_shapes(self):
        """attention_weights should be (1, N), converged_query should be (1, d)."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert result.attention_weights.shape == (1, self.N)
        assert result.converged_query.shape == (1, self.d)

    def test_attention_weights_sum_to_one(self):
        """softmax output must sum to 1."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert torch.allclose(result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5)

    def test_attention_weights_non_negative(self):
        """All attention weights must be >= 0."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert (result.attention_weights >= 0).all()

    def test_energy_monotonic_decrease(self):
        """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
        result = self.hopfield.retrieve(self.query, self.memory, return_energy=True)
        energies = [e.item() for e in result.energy_curve]
        for i in range(len(energies) - 1):
            assert energies[i+1] <= energies[i] + 1e-6, \
                f"Energy increased at step {i}: {energies[i]} -> {energies[i+1]}"

    def test_convergence(self):
        """With enough iterations, query should converge (delta < threshold)."""
        config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(self.query, self.memory, return_trajectory=True)
        # Final two queries should be very close
        q_last = result.trajectory[-1]
        q_prev = result.trajectory[-2]
        delta = torch.norm(q_last - q_prev)
        assert delta < 1e-4, f"Did not converge: delta={delta}"

    def test_high_beta_sharp_retrieval(self):
        """Higher β should produce sharper (lower entropy) attention."""
        low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
        high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))

        result_low = low_beta.retrieve(self.query, self.memory)
        result_high = high_beta.retrieve(self.query, self.memory)

        entropy_low = -(result_low.attention_weights * result_low.attention_weights.log()).sum()
        entropy_high = -(result_high.attention_weights * result_high.attention_weights.log()).sum()

        assert entropy_high < entropy_low, "Higher β should give lower entropy"

    def test_single_memory_converges_to_it(self):
        """With N=1, retrieval should converge to the single memory."""
        single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
        result = self.hopfield.retrieve(self.query, single_memory)
        assert torch.allclose(result.attention_weights, torch.ones(1, 1), atol=1e-5)

    def test_query_near_memory_retrieves_it(self):
        """If query ≈ memory_i, attention should peak at index i."""
        target_idx = 42
        query = self.memory[:, target_idx].unsqueeze(0)  # (1, d) — exact match
        config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(query, self.memory)
        top_idx = result.attention_weights.argmax(dim=-1).item()
        assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"

    def test_batch_retrieval(self):
        """Should handle batch of queries."""
        batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
        result = self.hopfield.retrieve(batch_query, self.memory)
        assert result.attention_weights.shape == (8, self.N)
        assert result.converged_query.shape == (8, self.d)

    def test_iteration_refines_query(self):
        """
        Multi-hop test: query starts far from target, iteration should bring it closer.

        Setup: memory has two clusters. Query is near cluster A but the "answer"
        is in cluster B, reachable through an intermediate memory that bridges both.
        After iteration, the query should drift toward cluster B.
        """
        torch.manual_seed(0)
        d = 32

        # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
        # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
        # Bridge memory 10: between A and B
        center_a = torch.zeros(d); center_a[0] = 1.0
        center_b = torch.zeros(d); center_b[1] = 1.0
        bridge = F.normalize(center_a + center_b, dim=0)

        memories = []
        for _ in range(5):
            memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
        for _ in range(5):
            memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
        memories.append(bridge)

        M = torch.stack(memories, dim=1)  # (d, 11)
        q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)

        config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(q0, M, return_trajectory=True)

        # After iteration, query should have drifted: its dot product with center_b
        # should be higher than the initial query's dot product with center_b
        initial_sim_b = (q0.squeeze() @ center_b).item()
        final_sim_b = (result.converged_query.squeeze() @ center_b).item()
        assert final_sim_b > initial_sim_b, \
            f"Iteration should pull query toward cluster B: initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
```

### `tests/test_memory_bank.py`

```python
class TestMemoryBank:

    def test_build_and_size(self):
        """Build memory bank and verify size."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(100, 64)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(embeddings, passages)
        assert mb.size == 100
        assert mb.dim == 64

    def test_normalization(self):
        """If normalize=True, stored embeddings should have unit norm."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(50, 64) * 5  # non-unit norm
        mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
        norms = torch.norm(mb.embeddings, dim=0)
        assert torch.allclose(norms, torch.ones(50), atol=1e-5)

    def test_get_passages_by_indices(self):
        """Index → passage text lookup."""
        config = MemoryBankConfig(embedding_dim=64, normalize=False)
        mb = MemoryBank(config)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(torch.randn(100, 64), passages)
        result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
        assert result == ["passage 0", "passage 50", "passage 99"]

    def test_save_and_load(self, tmp_path):
        """Save and reload memory bank, verify contents match."""
        config = MemoryBankConfig(embedding_dim=32, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(20, 32)
        passages = [f"text {i}" for i in range(20)]
        mb.build_from_embeddings(embeddings, passages)

        save_path = str(tmp_path / "mb.pt")
        mb.save(save_path)

        mb2 = MemoryBank(config)
        mb2.load(save_path)
        assert mb2.size == 20
        assert mb2.passages == passages
        assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)
```

### `tests/test_retriever.py`

```python
class TestHopfieldRetriever:

    def test_retrieves_correct_number_of_passages(self):
        """top_k=3 should return exactly 3 passages."""
        ...

    def test_scores_are_sorted_descending(self):
        """Returned scores should be in descending order."""
        ...

    def test_passages_match_indices(self):
        """Returned passage texts should correspond to returned indices."""
        ...

class TestFAISSRetriever:

    def test_retrieves_correct_number(self):
        ...

    def test_nearest_neighbor_is_self(self):
        """If query = passage_i's embedding, top-1 should be passage_i."""
        ...

class TestRetrieverComparison:

    def test_same_top1_for_obvious_query(self):
        """
        When query is very close to one memory, both FAISS and Hopfield
        should agree on the top-1 result.
        """
        ...
```

### `tests/test_metrics.py`

```python
class TestMetrics:

    def test_exact_match_basic(self):
        assert exact_match("Paris", "paris") == 1.0
        assert exact_match("Paris", "London") == 0.0

    def test_exact_match_normalization(self):
        """Should strip whitespace, lowercase, remove articles."""
        assert exact_match(" The Paris ", "paris") == 1.0
        assert exact_match("A dog", "dog") == 1.0

    def test_f1_score_perfect(self):
        assert f1_score("the cat sat", "the cat sat") == 1.0

    def test_f1_score_partial(self):
        score = f1_score("the cat sat on the mat", "the cat sat")
        assert 0.5 < score < 1.0

    def test_f1_score_no_overlap(self):
        assert f1_score("hello world", "foo bar") == 0.0

    def test_retrieval_recall(self):
        assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2/3
```

### `tests/test_energy.py`

```python
class TestEnergy:

    def test_energy_computation_matches_formula(self):
        """
        Manually compute E(q) = -1/β * log(Σ exp(β * qᵀ m_i)) + 1/2 * ||q||²
        and verify it matches the module output.
        """
        torch.manual_seed(42)
        d, N = 16, 10
        memory = torch.randn(d, N)
        query = torch.randn(1, d)
        beta = 2.0

        config = HopfieldConfig(beta=beta)
        hopfield = HopfieldRetrieval(config)
        computed = hopfield.compute_energy(query, memory)

        # Manual computation
        logits = beta * (memory.T @ query.T).T  # (1, N)
        expected = -1.0/beta * torch.logsumexp(logits, dim=-1) + 0.5 * (query ** 2).sum(dim=-1)
        assert torch.allclose(computed, expected, atol=1e-5)

    def test_verify_monotonic(self):
        energies = [10.0, 8.0, 6.5, 6.4, 6.39]
        assert verify_monotonic_decrease(energies) == True

    def test_verify_monotonic_fails(self):
        energies = [10.0, 8.0, 9.0, 6.0]
        assert verify_monotonic_decrease(energies) == False

    def test_attention_entropy(self):
        """Uniform attention should have max entropy, one-hot should have zero."""
        uniform = torch.ones(100) / 100
        one_hot = torch.zeros(100); one_hot[0] = 1.0
        assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot)
        assert compute_attention_entropy(one_hot) < 0.01
```

### `tests/test_pipeline.py`

```python
class TestPipeline:
    """
    Integration tests using MOCK encoder and generator.
    These test the wiring, not the model quality.
    """

    def test_pipeline_runs_end_to_end_hopfield(self):
        """Mock encoder + mock LLM + Hopfield retriever → produces an answer string."""
        ...

    def test_pipeline_runs_end_to_end_faiss(self):
        """Same but with FAISS retriever."""
        ...

    def test_pipeline_retrieval_result_included(self):
        """PipelineResult should contain retrieval metadata."""
        ...
```

**For pipeline tests, use mock/fake encoder and generator:**

```python
import hashlib

class FakeEncoder:
    def encode(self, text: str) -> torch.Tensor:
        # Deterministic hash-based embedding for testing.
        # hashlib rather than builtin hash(), which is salted per process
        # and would violate the "deterministic tests" convention.
        seed = int(hashlib.md5(text.encode()).hexdigest()[:8], 16)
        torch.manual_seed(seed)
        return F.normalize(torch.randn(1, 64), dim=-1)

class FakeGenerator:
    def generate(self, prompt: str) -> str:
        return "mock answer"
```

---

## Dependencies

```toml
[project]
name = "hag"
version = "0.1.0"
requires-python = ">=3.10"

dependencies = [
    "torch>=2.0",
    "numpy>=1.24",
    "faiss-cpu>=1.7",
    "transformers>=4.36",
    "datasets>=2.14",
    "pyyaml>=6.0",
    "tqdm>=4.65",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0",
    "pytest-cov>=4.0",
]
eval = [
    "scikit-learn>=1.3",
    "umap-learn>=0.5",
    "matplotlib>=3.7",
]
```

---

## Implementation Order

**Phase 1: Core module (do this first)**
1. `config.py` — all dataclasses
2. `hopfield.py` — the core Hopfield retrieval
3. `energy.py` — energy utilities
4. `tests/test_hopfield.py` — run and pass all tests
5. `tests/test_energy.py` — run and pass all tests

**Phase 2: Memory and retrievers**
6. `memory_bank.py`
7. `retriever_hopfield.py`
8. `retriever_faiss.py`
9. `tests/test_memory_bank.py`
10. `tests/test_retriever.py`

**Phase 3: Metrics and pipeline**
11. `metrics.py`
12. `pipeline.py` (with mock encoder/generator for now)
13. `tests/test_metrics.py`
14. `tests/test_pipeline.py`

**Phase 4: Scripts (need GPU/models — implement structure only)**
15. `scripts/build_memory_bank.py`
16. `scripts/run_eval.py`
17. `scripts/run_baseline.py`
18. `scripts/run_hag.py`

---

## Coding Conventions

- **Type hints everywhere.** Every function signature must have full type hints.
- **Docstrings on every public method.** Include shapes in docstrings for tensor arguments.
- **Tensor shapes in comments.** Any line with a tensor operation should have a shape comment: `# (batch, N)`
- **No global state.** Everything parameterized through config dataclasses.
- **CPU-first.** Default device is CPU. GPU is opt-in via `.to(device)`.
- **Deterministic tests.** Always set `torch.manual_seed()` in tests.
- **Use `torch.no_grad()` in retrieval** when not training (phase 1).
- Use `logging` module, not print statements.

---

## Key Mathematical References

The Modern Continuous Hopfield Network update rule and energy function come from:

> Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)

Key equations:
- **Update**: `q_{t+1} = M · softmax(β · Mᵀ · q_t)`
- **Energy**: `E(q) = -1/β · log(Σᵢ exp(β · qᵀmᵢ)) + 1/2 · ||q||²`
- **Property**: Energy decreases monotonically: `E(q_{t+1}) ≤ E(q_t)`
- **Property**: Fixed points are local minima of E
- **Capacity**: Exponential in d (unlike classical Hopfield's ~0.14N)

The connection to attention: the Hopfield update is mathematically equivalent to one step of cross-attention where queries attend to stored memories. But the ITERATIVE application (multiple steps) and the energy-based analysis are what differentiate this from standard attention.

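A quick numerical check of the monotone-decrease property on random data, in the spirit of `scripts/analyze_energy.py` (a sanity script, not a proof):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, N, beta = 64, 500, 2.0
M = F.normalize(torch.randn(d, N), dim=0)  # (d, N) memory bank
q = F.normalize(torch.randn(d), dim=0)     # (d,) initial query

def energy(q: torch.Tensor) -> float:
    # E(q) = -1/β · logsumexp(β · Mᵀq) + 1/2 · ||q||²
    return (-torch.logsumexp(beta * (M.T @ q), dim=-1) / beta + 0.5 * q @ q).item()

prev = energy(q)
for t in range(10):
    q = M @ torch.softmax(beta * (M.T @ q), dim=-1)  # one Hopfield update
    cur = energy(q)
    assert cur <= prev + 1e-6, f"energy increased at step {t}"
    prev = cur
print("energy decreased monotonically over 10 steps")
```
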
---

## What NOT To Do

- Do NOT use `torch.cuda` in any test
- Do NOT download models/datasets in tests — use synthetic data and mocks
- Do NOT implement the encoder or generator with real models yet — provide the interface and use mocks in tests
- Do NOT add wandb, tensorboard, or other logging frameworks yet
- Do NOT try to optimize for speed yet — correctness first
- Do NOT implement the training loop (memory bank training) — that's phase 2 of the research, not part of this codebase yet
