# CLAUDE.md — Hopfield-Augmented Generation (HAG)

## Project Overview

We are building **HAG (Hopfield-Augmented Generation)**, a retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**. This is a research project targeting a top ML venue (ICML/NeurIPS).

### Core Idea

Standard RAG does:

```
y = LLM(x, top_k_passages(FAISS(Enc(x))))
```

HAG does:

```
q₀ = Enc(x)
qₜ₊₁ = M · softmax(β · Mᵀ · qₜ)    # iterate T steps
α_T = softmax(β · Mᵀ · q_T)        # final attention weights
retrieved_passages = top_k(α_T)    # use weights to select passages
y = LLM(x, retrieved_passages)
```

Where:

- `M ∈ ℝ^{d × N}` is a memory bank of passage embeddings (same encoder as vanilla RAG)
- `q` is the query embedding
- `β` is an inverse temperature parameter
- The iterative update is the retrieval dynamics of a **Modern Continuous Hopfield Network** (Ramsauer et al., 2020)
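For concreteness, the whole retrieval loop fits in a few lines of PyTorch. This is a minimal sketch of the update above, not the final API (the module spec is below); `hopfield_retrieve` is an illustrative name:

```python
import torch

def hopfield_retrieve(q0: torch.Tensor, M: torch.Tensor, beta: float = 1.0, T: int = 5) -> torch.Tensor:
    """Iterate q ← M · softmax(β · Mᵀ · q) and return the final attention α_T.

    q0: (d,) query embedding; M: (d, N) memory bank of passage embeddings.
    """
    q = q0
    for _ in range(T):
        alpha = torch.softmax(beta * (M.T @ q), dim=-1)  # (N,) attention over memories
        q = M @ alpha                                    # (d,) refined query
    return torch.softmax(beta * (M.T @ q), dim=-1)       # α_T, used for top-k selection
```

Top-k over the returned α_T then yields the passages fed to the LLM.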
### Why This Matters

1. **Iterative refinement**: RAG retrieves once. HAG refines the query in embedding space over T steps, enabling implicit multi-hop reasoning.
2. **Energy-based theory**: Retrieval is energy minimization with convergence guarantees (energy decreases monotonically).
3. **Differentiable**: Unlike FAISS top-k, the entire retrieval is differentiable — enabling future end-to-end training.
4. **No structural preprocessing**: Unlike GraphRAG methods (LinearRAG, HippoRAG), no NER, no graph construction, no PageRank. Just an embedding matrix.

### What We Are NOT Doing (Scope Boundaries)

- NOT modifying the LLM architecture (no cross-attention injection)
- NOT training the memory bank in phase 1 (M is frozen embeddings)
- NOT building a graph structure
- NOT doing anything that requires the LLM during retrieval
- The LLM is frozen and used only for final generation

---

## Repository Structure

```
hag/
├── CLAUDE.md                  # This file
├── pyproject.toml             # Project config, dependencies
├── README.md                  # Project README
├── hag/
│   ├── __init__.py
│   ├── config.py              # All hyperparameters and configs (dataclass-based)
│   ├── encoder.py             # Wrapper for encoding queries/passages
│   ├── memory_bank.py         # Memory bank construction and management
│   ├── hopfield.py            # Core: Hopfield retrieval module
│   ├── retriever_faiss.py     # Baseline: FAISS top-k retriever
│   ├── retriever_hopfield.py  # HAG retriever (wraps hopfield.py + passage selection)
│   ├── generator.py           # LLM generation wrapper
│   ├── pipeline.py            # End-to-end RAG/HAG pipeline
│   ├── metrics.py             # Evaluation metrics (EM, F1, retrieval recall)
│   └── energy.py              # Energy computation and analysis utilities
├── scripts/
│   ├── build_memory_bank.py   # Offline: encode corpus → memory bank
│   ├── run_eval.py            # Run evaluation on a dataset
│   ├── run_baseline.py        # Run vanilla RAG baseline
│   ├── run_hag.py             # Run HAG
│   ├── analyze_energy.py      # Analyze energy curves and convergence
│   └── visualize_trajectory.py  # UMAP visualization of query trajectory
├── tests/
│   ├── test_hopfield.py       # Unit tests for Hopfield module
│   ├── test_memory_bank.py    # Unit tests for memory bank
│   ├── test_retriever.py      # Unit tests for both retrievers
│   ├── test_pipeline.py       # Integration tests for pipeline
│   ├── test_metrics.py        # Unit tests for metrics
│   └── test_energy.py         # Unit tests for energy module
├── configs/
│   ├── default.yaml           # Default experiment config
│   ├── hotpotqa.yaml          # HotpotQA-specific config
│   ├── musique.yaml           # MuSiQue-specific config
│   └── 2wikimhqa.yaml         # 2WikiMultiHopQA-specific config
└── notebooks/
    └── demo.ipynb             # Quick demo notebook
```

---

## Module Specifications

### 1. `hag/config.py`

Use Python dataclasses. All hyperparameters in one place.

```python
@dataclass
class HopfieldConfig:
    beta: float = 1.0             # Inverse temperature. Higher = sharper retrieval
    max_iter: int = 5             # Maximum Hopfield iteration steps
    conv_threshold: float = 1e-4  # Convergence: stop if ||q_{t+1} - q_t|| < threshold
    top_k: int = 5                # Number of passages to retrieve from final attention weights

@dataclass
class MemoryBankConfig:
    embedding_dim: int = 768      # Must match encoder output dim
    normalize: bool = True        # L2-normalize embeddings in memory bank

@dataclass
class EncoderConfig:
    model_name: str = "facebook/contriever-msmarco"  # Default encoder
    max_length: int = 512
    batch_size: int = 64

@dataclass
class GeneratorConfig:
    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    max_new_tokens: int = 128
    temperature: float = 0.0      # Greedy decoding for reproducibility

@dataclass
class PipelineConfig:
    hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
    memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
    encoder: EncoderConfig = field(default_factory=EncoderConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    retriever_type: str = "hopfield"  # "hopfield" or "faiss"
```
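A short usage sketch showing how the nested dataclasses compose (the override values here are arbitrary; only the field names come from the spec):

```python
from hag.config import HopfieldConfig, PipelineConfig

# Override only what an experiment changes; everything else keeps its default.
config = PipelineConfig(
    hopfield=HopfieldConfig(beta=2.0, max_iter=10, top_k=5),
    retriever_type="hopfield",
)
assert config.encoder.model_name == "facebook/contriever-msmarco"  # untouched default
```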
### 2. `hag/hopfield.py` — THE CORE MODULE

This is the most important file. Implement the Modern Continuous Hopfield Network retrieval.

```python
class HopfieldRetrieval:
    """
    Modern Continuous Hopfield Network for memory retrieval.

    Given memory bank M ∈ ℝ^{d × N} and query q ∈ ℝ^d:
      1. Compute attention: α = softmax(β * Mᵀ @ q)
      2. Update query: q_new = M @ α
      3. Repeat until convergence or max_iter

    The energy function is:
      E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||^2

    Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
    """

    def __init__(self, config: HopfieldConfig):
        ...

    def retrieve(
        self,
        query: torch.Tensor,   # (d,) or (batch, d)
        memory: torch.Tensor,  # (d, N) — the memory bank
        return_trajectory: bool = False,
        return_energy: bool = False,
    ) -> HopfieldResult:
        """
        Run iterative Hopfield retrieval.

        Returns HopfieldResult with:
          - attention_weights: (N,) or (batch, N) — final α_T
          - converged_query: (d,) or (batch, d) — final q_T
          - num_steps: int — actual steps taken
          - trajectory: optional list of q_t at each step
          - energy_curve: optional list of E(q_t) at each step
        """
        ...

    def compute_energy(
        self,
        query: torch.Tensor,
        memory: torch.Tensor,
    ) -> torch.Tensor:
        """
        E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||^2

        This is the log-sum-exp energy of the Modern Hopfield Network.
        """
        ...
```

**CRITICAL implementation details for `retrieve()`:**

```python
def retrieve(self, query, memory, return_trajectory=False, return_energy=False):
    # Ensure correct shapes
    # query: (batch, d) — unsqueeze if needed
    # memory: (d, N)
    q = query.clone()
    trajectory = [q.clone()] if return_trajectory else None
    energies = [self.compute_energy(q, memory)] if return_energy else None

    for t in range(self.config.max_iter):
        # Core Hopfield update
        logits = self.config.beta * (memory.T @ q.T).T  # (batch, N)
        alpha = torch.softmax(logits, dim=-1)           # (batch, N)
        q_new = (memory @ alpha.T).T                    # (batch, d)

        # Check convergence
        delta = torch.norm(q_new - q, dim=-1).max()
        q = q_new

        if return_trajectory:
            trajectory.append(q.clone())
        if return_energy:
            energies.append(self.compute_energy(q, memory))

        if delta < self.config.conv_threshold:
            break

    # Final attention weights (recompute to ensure consistency)
    logits = self.config.beta * (memory.T @ q.T).T
    alpha = torch.softmax(logits, dim=-1)

    return HopfieldResult(
        attention_weights=alpha,
        converged_query=q,
        num_steps=t + 1,
        trajectory=trajectory,
        energy_curve=energies,
    )
```

**Pay special attention to:**

- Numerical stability: use `torch.logsumexp` in the energy computation (see the sketch below)
- Batch support: all operations should work for both single query and batched queries
- Memory efficiency: don't store trajectory/energy unless explicitly requested
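For the numerical-stability point, a sketch of what `compute_energy` could look like with `torch.logsumexp` (shapes follow the spec; this is one possible implementation, not the mandated one):

```python
import torch

def compute_energy(query: torch.Tensor, memory: torch.Tensor, beta: float) -> torch.Tensor:
    """E(q) = -1/β · log(Σᵢ exp(β · qᵀmᵢ)) + 1/2 · ||q||².

    query: (batch, d); memory: (d, N); returns (batch,) energies.
    logsumexp avoids overflow when β · qᵀmᵢ is large.
    """
    logits = beta * (query @ memory)       # (batch, N)
    lse = torch.logsumexp(logits, dim=-1)  # (batch,)
    return -lse / beta + 0.5 * query.pow(2).sum(dim=-1)
```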
### 3. `hag/memory_bank.py`

```python
class MemoryBank:
    """
    Stores passage embeddings and provides lookup from indices back to text.

    The memory bank is M ∈ ℝ^{d × N} where each column is a passage embedding.
    Also maintains a mapping from column index → passage text for final retrieval.
    """

    def __init__(self, config: MemoryBankConfig):
        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
        self.passages: List[str] = []                   # N passages
        self.config = config

    def build_from_embeddings(self, embeddings: torch.Tensor, passages: List[str]):
        """
        embeddings: (N, d) — note: input is (N, d), internally stored as (d, N)
        passages: list of N passage strings
        """
        if self.config.normalize:
            embeddings = F.normalize(embeddings, dim=-1)
        self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
        self.passages = passages

    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
        """Given top-k indices, return corresponding passage texts."""
        ...

    def save(self, path: str):
        ...

    def load(self, path: str):
        ...

    @property
    def size(self) -> int:
        return self.embeddings.shape[1] if self.embeddings is not None else 0

    @property
    def dim(self) -> int:
        return self.embeddings.shape[0] if self.embeddings is not None else 0
```

### 4. `hag/retriever_faiss.py`

Vanilla RAG baseline retriever using FAISS.

```python
class FAISSRetriever:
    """
    Standard top-k retrieval using FAISS inner product search.
    This is the baseline to compare against.
    """

    def __init__(self, top_k: int = 5):
        self.index = None
        self.passages: List[str] = []
        self.top_k = top_k

    def build_index(self, embeddings: np.ndarray, passages: List[str]):
        """Build FAISS IndexFlatIP from (N, d) embeddings."""
        ...

    def retrieve(self, query: np.ndarray) -> RetrievalResult:
        """
        query: (d,) or (batch, d)
        Returns RetrievalResult with top_k passage texts, scores, and indices.
        """
        ...
```

### 5. `hag/retriever_hopfield.py`

```python
class HopfieldRetriever:
    """
    Wraps HopfieldRetrieval + MemoryBank into a retriever interface.

    The bridge between Hopfield's continuous retrieval and the discrete
    passage selection needed for LLM prompting.
    """

    def __init__(self, hopfield: HopfieldRetrieval, memory_bank: MemoryBank, top_k: int = 5):
        ...

    def retrieve(self, query_embedding: torch.Tensor, return_analysis: bool = False) -> RetrievalResult:
        """
        1. Run Hopfield iterative retrieval → get attention weights α_T
        2. Take top_k indices from α_T
        3. Look up corresponding passage texts from memory bank
        4. Optionally return trajectory and energy for analysis

        Returns RetrievalResult with:
          - passages: List[str]
          - scores: attention weights for top-k
          - indices: which memory slots were selected
          - (optional) hopfield_result: full HopfieldResult for analysis
        """
        ...
```
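The continuous-to-discrete step described above is essentially one `torch.topk` call. A sketch of the selection logic (`select_passages` is an illustrative helper, not part of the spec):

```python
import torch

def select_passages(attention: torch.Tensor, memory_bank: "MemoryBank", top_k: int):
    """attention: (N,) final Hopfield weights α_T. Returns top-k passages, scores, indices."""
    scores, indices = torch.topk(attention, k=top_k)  # torch.topk returns values sorted descending
    passages = memory_bank.get_passages_by_indices(indices)
    return passages, scores, indices
```

Because `torch.topk` sorts its output, this also satisfies the `test_scores_are_sorted_descending` expectation in `tests/test_retriever.py`.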
### 6. `hag/pipeline.py`

```python
class RAGPipeline:
    """
    End-to-end pipeline: query → encode → retrieve → generate.
    Supports both FAISS (baseline) and Hopfield (ours) retrieval.
    """

    def __init__(self, config: PipelineConfig, memory_bank: MemoryBank):
        self.encoder = Encoder(config.encoder)
        self.generator = Generator(config.generator)
        if config.retriever_type == "faiss":
            self.retriever = FAISSRetriever(top_k=config.hopfield.top_k)
        elif config.retriever_type == "hopfield":
            hopfield = HopfieldRetrieval(config.hopfield)
            self.retriever = HopfieldRetriever(hopfield, memory_bank, top_k=config.hopfield.top_k)
        else:
            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")

    def run(self, question: str) -> PipelineResult:
        """
        Full pipeline:
          1. Encode question → query embedding
          2. Retrieve passages (FAISS or Hopfield)
          3. Format prompt with question + retrieved passages
          4. Generate answer with LLM
        """
        ...

    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
        ...
```

### 7. `hag/metrics.py`

```python
def exact_match(prediction: str, ground_truth: str) -> float:
    """Normalized exact match. Lowercase, strip, remove articles/punctuation."""
    ...

def f1_score(prediction: str, ground_truth: str) -> float:
    """Token-level F1 between prediction and ground truth."""
    ...

def retrieval_recall_at_k(retrieved_indices: List[int], gold_indices: List[int], k: int) -> float:
    """What fraction of gold passages appear in the retrieved top-k?"""
    ...

def evaluate_dataset(results: List[PipelineResult], gold_answers: List[str]) -> Dict[str, float]:
    """Compute aggregate metrics over a dataset. Returns dict with EM, F1, etc."""
    ...
```
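For reference, `f1_score` is expected to behave like the standard SQuAD-style token F1. A sketch under that assumption (`_normalize_tokens` is an illustrative helper; the exact normalization rules live in `metrics.py`):

```python
import re
import string
from collections import Counter

def _normalize_tokens(text: str) -> list[str]:
    text = text.lower()
    text = re.sub(r"\b(a|an|the)\b", " ", text)                        # drop articles
    text = text.translate(str.maketrans("", "", string.punctuation))  # drop punctuation
    return text.split()

def f1_score(prediction: str, ground_truth: str) -> float:
    pred, gold = _normalize_tokens(prediction), _normalize_tokens(ground_truth)
    common = Counter(pred) & Counter(gold)  # multiset intersection of tokens
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred)
    recall = overlap / len(gold)
    return 2 * precision * recall / (precision + recall)
```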
### 8. `hag/energy.py`

```python
def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
    """Extract energy values at each iteration step."""
    ...

def compute_energy_gap(energy_curve: List[float]) -> float:
    """
    ΔE = E(q_0) - E(q_T)
    Larger gap = more refinement happened.
    """
    ...

def verify_monotonic_decrease(energy_curve: List[float]) -> bool:
    """Check that E(q_{t+1}) <= E(q_t) for all t. Should always be True."""
    ...

def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """
    H(α) = -Σ_i α_i log α_i
    Low entropy = sharp retrieval (confident). High entropy = diffuse retrieval (uncertain).
    """
    ...
```
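A minimal sketch of `compute_attention_entropy`; the clamp avoids `0 · log 0 = nan` on one-hot distributions (the epsilon value is an arbitrary choice):

```python
import torch

def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """H(α) = -Σᵢ αᵢ log αᵢ for an (N,) attention vector."""
    alpha = attention_weights.clamp_min(1e-12)  # guard against log(0) for one-hot α
    return -(alpha * alpha.log()).sum().item()
```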
---

## Data Types

Define these as dataclasses or NamedTuples:

```python
@dataclass
class HopfieldResult:
    attention_weights: torch.Tensor  # (batch, N) or (N,)
    converged_query: torch.Tensor    # (batch, d) or (d,)
    num_steps: int
    trajectory: Optional[List[torch.Tensor]] = None    # list of q_t
    energy_curve: Optional[List[torch.Tensor]] = None  # list of E(q_t)

@dataclass
class RetrievalResult:
    passages: List[str]
    scores: torch.Tensor   # top-k scores
    indices: torch.Tensor  # top-k indices
    hopfield_result: Optional[HopfieldResult] = None

@dataclass
class PipelineResult:
    question: str
    answer: str
    retrieved_passages: List[str]
    retrieval_result: RetrievalResult
```

---

## Test Specifications

**ALL TESTS MUST RUN ON CPU. NO GPU REQUIRED.**

Use small synthetic data (random tensors) for unit tests. Do not download any models or datasets in tests.

### `tests/test_hopfield.py`

```python
class TestHopfieldRetrieval:
    """Test the core Hopfield module with synthetic data on CPU."""

    def setup_method(self):
        """Create small synthetic memory bank and queries."""
        torch.manual_seed(42)
        self.d = 64   # embedding dim
        self.N = 100  # number of memories
        self.memory = F.normalize(torch.randn(self.d, self.N), dim=0)  # (d, N)
        self.query = F.normalize(torch.randn(1, self.d), dim=-1)       # (1, d)
        self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6)
        self.hopfield = HopfieldRetrieval(self.config)

    def test_output_shapes(self):
        """attention_weights should be (1, N), converged_query should be (1, d)."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert result.attention_weights.shape == (1, self.N)
        assert result.converged_query.shape == (1, self.d)

    def test_attention_weights_sum_to_one(self):
        """softmax output must sum to 1."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert torch.allclose(result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5)

    def test_attention_weights_non_negative(self):
        """All attention weights must be >= 0."""
        result = self.hopfield.retrieve(self.query, self.memory)
        assert (result.attention_weights >= 0).all()

    def test_energy_monotonic_decrease(self):
        """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property."""
        result = self.hopfield.retrieve(self.query, self.memory, return_energy=True)
        energies = [e.item() for e in result.energy_curve]
        for i in range(len(energies) - 1):
            assert energies[i+1] <= energies[i] + 1e-6, \
                f"Energy increased at step {i}: {energies[i]} -> {energies[i+1]}"

    def test_convergence(self):
        """With enough iterations, query should converge (delta < threshold)."""
        config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(self.query, self.memory, return_trajectory=True)
        # Final two queries should be very close
        q_last = result.trajectory[-1]
        q_prev = result.trajectory[-2]
        delta = torch.norm(q_last - q_prev)
        assert delta < 1e-4, f"Did not converge: delta={delta}"

    def test_high_beta_sharp_retrieval(self):
        """Higher β should produce sharper (lower entropy) attention."""
        low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5))
        high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5))
        result_low = low_beta.retrieve(self.query, self.memory)
        result_high = high_beta.retrieve(self.query, self.memory)
        entropy_low = -(result_low.attention_weights * result_low.attention_weights.log()).sum()
        entropy_high = -(result_high.attention_weights * result_high.attention_weights.log()).sum()
        assert entropy_high < entropy_low, "Higher β should give lower entropy"

    def test_single_memory_converges_to_it(self):
        """With N=1, retrieval should converge to the single memory."""
        single_memory = F.normalize(torch.randn(self.d, 1), dim=0)
        result = self.hopfield.retrieve(self.query, single_memory)
        assert torch.allclose(result.attention_weights, torch.ones(1, 1), atol=1e-5)

    def test_query_near_memory_retrieves_it(self):
        """If query ≈ memory_i, attention should peak at index i."""
        target_idx = 42
        query = self.memory[:, target_idx].unsqueeze(0)  # (1, d) — exact match
        config = HopfieldConfig(beta=10.0, max_iter=5)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(query, self.memory)
        top_idx = result.attention_weights.argmax(dim=-1).item()
        assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}"

    def test_batch_retrieval(self):
        """Should handle batch of queries."""
        batch_query = F.normalize(torch.randn(8, self.d), dim=-1)
        result = self.hopfield.retrieve(batch_query, self.memory)
        assert result.attention_weights.shape == (8, self.N)
        assert result.converged_query.shape == (8, self.d)

    def test_iteration_refines_query(self):
        """
        Multi-hop test: query starts far from target, iteration should bring it closer.

        Setup: memory has two clusters. Query is near cluster A but the "answer"
        is in cluster B, reachable through an intermediate memory that bridges both.
        After iteration, the query should drift toward cluster B.
        """
        torch.manual_seed(0)
        d = 32
        # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...]
        # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...]
        # Bridge memory 10: between A and B
        center_a = torch.zeros(d); center_a[0] = 1.0
        center_b = torch.zeros(d); center_b[1] = 1.0
        bridge = F.normalize(center_a + center_b, dim=0)

        memories = []
        for _ in range(5):
            memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0))
        for _ in range(5):
            memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0))
        memories.append(bridge)
        M = torch.stack(memories, dim=1)  # (d, 11)

        q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0)
        config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8)
        hopfield = HopfieldRetrieval(config)
        result = hopfield.retrieve(q0, M, return_trajectory=True)

        # After iteration, query should have drifted: its dot product with center_b
        # should be higher than the initial query's dot product with center_b
        initial_sim_b = (q0.squeeze() @ center_b).item()
        final_sim_b = (result.converged_query.squeeze() @ center_b).item()
        assert final_sim_b > initial_sim_b, \
            f"Iteration should pull query toward cluster B: initial={initial_sim_b:.4f}, final={final_sim_b:.4f}"
```

### `tests/test_memory_bank.py`

```python
class TestMemoryBank:
    def test_build_and_size(self):
        """Build memory bank and verify size."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(100, 64)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(embeddings, passages)
        assert mb.size == 100
        assert mb.dim == 64

    def test_normalization(self):
        """If normalize=True, stored embeddings should have unit norm."""
        config = MemoryBankConfig(embedding_dim=64, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(50, 64) * 5  # non-unit norm
        mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)])
        norms = torch.norm(mb.embeddings, dim=0)
        assert torch.allclose(norms, torch.ones(50), atol=1e-5)

    def test_get_passages_by_indices(self):
        """Index → passage text lookup."""
        config = MemoryBankConfig(embedding_dim=64, normalize=False)
        mb = MemoryBank(config)
        passages = [f"passage {i}" for i in range(100)]
        mb.build_from_embeddings(torch.randn(100, 64), passages)
        result = mb.get_passages_by_indices(torch.tensor([0, 50, 99]))
        assert result == ["passage 0", "passage 50", "passage 99"]

    def test_save_and_load(self, tmp_path):
        """Save and reload memory bank, verify contents match."""
        config = MemoryBankConfig(embedding_dim=32, normalize=True)
        mb = MemoryBank(config)
        embeddings = torch.randn(20, 32)
        passages = [f"text {i}" for i in range(20)]
        mb.build_from_embeddings(embeddings, passages)
        save_path = str(tmp_path / "mb.pt")
        mb.save(save_path)

        mb2 = MemoryBank(config)
        mb2.load(save_path)
        assert mb2.size == 20
        assert mb2.passages == passages
        assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6)
```

### `tests/test_retriever.py`

```python
class TestHopfieldRetriever:
    def test_retrieves_correct_number_of_passages(self):
        """top_k=3 should return exactly 3 passages."""
        ...

    def test_scores_are_sorted_descending(self):
        """Returned scores should be in descending order."""
        ...

    def test_passages_match_indices(self):
        """Returned passage texts should correspond to returned indices."""
        ...

class TestFAISSRetriever:
    def test_retrieves_correct_number(self):
        ...

    def test_nearest_neighbor_is_self(self):
        """If query = passage_i's embedding, top-1 should be passage_i."""
        ...

class TestRetrieverComparison:
    def test_same_top1_for_obvious_query(self):
        """
        When query is very close to one memory, both FAISS and Hopfield
        should agree on the top-1 result.
        """
        ...
```

### `tests/test_metrics.py`

```python
class TestMetrics:
    def test_exact_match_basic(self):
        assert exact_match("Paris", "paris") == 1.0
        assert exact_match("Paris", "London") == 0.0

    def test_exact_match_normalization(self):
        """Should strip whitespace, lowercase, remove articles."""
        assert exact_match(" The Paris ", "paris") == 1.0
        assert exact_match("A dog", "dog") == 1.0

    def test_f1_score_perfect(self):
        assert f1_score("the cat sat", "the cat sat") == 1.0

    def test_f1_score_partial(self):
        score = f1_score("the cat sat on the mat", "the cat sat")
        assert 0.5 < score < 1.0

    def test_f1_score_no_overlap(self):
        assert f1_score("hello world", "foo bar") == 0.0

    def test_retrieval_recall(self):
        assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2/3
```

### `tests/test_energy.py`

```python
class TestEnergy:
    def test_energy_computation_matches_formula(self):
        """
        Manually compute E(q) = -1/β * log(Σ exp(β * qᵀ m_i)) + 1/2 * ||q||^2
        and verify it matches the module output.
        """
        torch.manual_seed(42)
        d, N = 16, 10
        memory = torch.randn(d, N)
        query = torch.randn(1, d)
        beta = 2.0

        config = HopfieldConfig(beta=beta)
        hopfield = HopfieldRetrieval(config)
        computed = hopfield.compute_energy(query, memory)

        # Manual computation
        logits = beta * (memory.T @ query.T).T  # (1, N)
        expected = -1.0/beta * torch.logsumexp(logits, dim=-1) + 0.5 * (query ** 2).sum(dim=-1)
        assert torch.allclose(computed, expected, atol=1e-5)

    def test_verify_monotonic(self):
        energies = [10.0, 8.0, 6.5, 6.4, 6.39]
        assert verify_monotonic_decrease(energies) == True

    def test_verify_monotonic_fails(self):
        energies = [10.0, 8.0, 9.0, 6.0]
        assert verify_monotonic_decrease(energies) == False

    def test_attention_entropy(self):
        """Uniform attention should have max entropy, one-hot should have zero."""
        uniform = torch.ones(100) / 100
        one_hot = torch.zeros(100); one_hot[0] = 1.0
        assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot)
        assert compute_attention_entropy(one_hot) < 0.01
```

### `tests/test_pipeline.py`

```python
class TestPipeline:
    """
    Integration tests using MOCK encoder and generator.
    These test the wiring, not the model quality.
    """

    def test_pipeline_runs_end_to_end_hopfield(self):
        """Mock encoder + mock LLM + Hopfield retriever → produces an answer string."""
        ...

    def test_pipeline_runs_end_to_end_faiss(self):
        """Same but with FAISS retriever."""
        ...

    def test_pipeline_retrieval_result_included(self):
        """PipelineResult should contain retrieval metadata."""
        ...
```

**For pipeline tests, use mock/fake encoder and generator:**

```python
import zlib

class FakeEncoder:
    def encode(self, text: str) -> torch.Tensor:
        # Deterministic hash-based embedding for testing. zlib.crc32 is stable
        # across runs; the builtin hash() is salted per process and would
        # break test determinism.
        torch.manual_seed(zlib.crc32(text.encode()) % 2**32)
        return F.normalize(torch.randn(1, 64), dim=-1)

class FakeGenerator:
    def generate(self, prompt: str) -> str:
        return "mock answer"
```
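One way the end-to-end test could be wired with these fakes. This sketch assumes the pipeline takes the memory bank in `__init__` (as in the `pipeline.py` spec above) and exposes `.encoder`/`.generator` as swappable attributes; adjust to however injection ends up being implemented:

```python
import torch
from hag.config import MemoryBankConfig, PipelineConfig
from hag.memory_bank import MemoryBank
from hag.pipeline import RAGPipeline

def test_pipeline_runs_end_to_end_hopfield():
    # Small synthetic memory bank; no models or datasets downloaded.
    mb = MemoryBank(MemoryBankConfig(embedding_dim=64))
    mb.build_from_embeddings(torch.randn(10, 64), [f"passage {i}" for i in range(10)])
    pipeline = RAGPipeline(PipelineConfig(retriever_type="hopfield"), memory_bank=mb)
    pipeline.encoder = FakeEncoder()
    pipeline.generator = FakeGenerator()
    result = pipeline.run("what is the capital of france?")
    assert result.answer == "mock answer"
```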
---

## Dependencies

```toml
[project]
name = "hag"
version = "0.1.0"
requires-python = ">=3.10"
dependencies = [
    "torch>=2.0",
    "numpy>=1.24",
    "faiss-cpu>=1.7",
    "transformers>=4.36",
    "datasets>=2.14",
    "pyyaml>=6.0",
    "tqdm>=4.65",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0",
    "pytest-cov>=4.0",
]
eval = [
    "scikit-learn>=1.3",
    "umap-learn>=0.5",
    "matplotlib>=3.7",
]
```

---

## Implementation Order

**Phase 1: Core module (do this first)**

1. `config.py` — all dataclasses
2. `hopfield.py` — the core Hopfield retrieval
3. `energy.py` — energy utilities
4. `tests/test_hopfield.py` — run and pass all tests
5. `tests/test_energy.py` — run and pass all tests

**Phase 2: Memory and retrievers**

6. `memory_bank.py`
7. `retriever_hopfield.py`
8. `retriever_faiss.py`
9. `tests/test_memory_bank.py`
10. `tests/test_retriever.py`

**Phase 3: Metrics and pipeline**

11. `metrics.py`
12. `pipeline.py` (with mock encoder/generator for now)
13. `tests/test_metrics.py`
14. `tests/test_pipeline.py`

**Phase 4: Scripts (need GPU/models — implement structure only)**

15. `scripts/build_memory_bank.py`
16. `scripts/run_eval.py`
17. `scripts/run_baseline.py`
18. `scripts/run_hag.py`

---

## Coding Conventions

- **Type hints everywhere.** Every function signature must have full type hints.
- **Docstrings on every public method.** Include shapes in docstrings for tensor arguments.
- **Tensor shapes in comments.** Any line with a tensor operation should have a shape comment: `# (batch, N)`
- **No global state.** Everything parameterized through config dataclasses.
- **CPU-first.** Default device is CPU. GPU is opt-in via `.to(device)`.
- **Deterministic tests.** Always set `torch.manual_seed()` in tests.
- **Use `torch.no_grad()` in retrieval** when not training (phase 1).
- Use the `logging` module, not print statements.

---

## Key Mathematical References

The Modern Continuous Hopfield Network update rule and energy function come from:

> Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)

Key equations:

- **Update**: `q_{t+1} = M · softmax(β · Mᵀ · q_t)`
- **Energy**: `E(q) = -1/β · log(Σᵢ exp(β · qᵀmᵢ)) + 1/2 · ||q||²`
- **Property**: Energy decreases monotonically: `E(q_{t+1}) ≤ E(q_t)`
- **Property**: Fixed points are local minima of E
- **Capacity**: Exponential in d (unlike classical Hopfield's ~0.14·N)

The connection to attention: the Hopfield update is mathematically equivalent to one step of cross-attention where queries attend to stored memories. But the ITERATIVE application (multiple steps) and the energy-based analysis are what differentiate this from standard attention.

---

## What NOT To Do

- Do NOT use `torch.cuda` in any test
- Do NOT download models/datasets in tests — use synthetic data and mocks
- Do NOT implement the encoder or generator with real models yet — provide the interface and use mocks in tests
- Do NOT add wandb, tensorboard, or other logging frameworks yet
- Do NOT try to optimize for speed yet — correctness first
- Do NOT implement the training loop (memory bank training) — that's phase 2 of the research, not this codebase yet