35 files changed, 3180 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..0741f54 --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +__pycache__/ +*.pyc +*.egg-info/ +dist/ +build/ +.pytest_cache/ +*.pt +.claude/ diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..c99b846 --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,870 @@ +# CLAUDE.md — Hopfield-Augmented Generation (HAG) + +## Project Overview + +We are building **HAG (Hopfield-Augmented Generation)**, a retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**. This is a research project targeting a top ML venue (ICML/NeurIPS). + +### Core Idea + +Standard RAG does: +``` +y = LLM(x, top_k_passages(FAISS(Enc(x)))) +``` + +HAG does: +``` +q₀ = Enc(x) +qₜ₊₁ = M · softmax(β · Mᵀ · qₜ) # iterate T steps +α_T = softmax(β · Mᵀ · q_T) # final attention weights +retrieved_passages = top_k(α_T) # use weights to select passages +y = LLM(x, retrieved_passages) +``` + +Where: +- `M ∈ ℝ^{d × N}` is a memory bank of passage embeddings (same encoder as vanilla RAG) +- `q` is the query embedding +- `β` is an inverse temperature parameter +- The iterative update is the retrieval dynamics of a **Modern Continuous Hopfield Network** (Ramsauer et al., 2020) + +### Why This Matters + +1. **Iterative refinement**: RAG retrieves once. HAG refines the query in embedding space over T steps, enabling implicit multi-hop reasoning. +2. **Energy-based theory**: Retrieval is energy minimization with convergence guarantees (energy decreases monotonically). +3. **Differentiable**: Unlike FAISS top-k, the entire retrieval is differentiable — enabling future end-to-end training. +4. **No structural preprocessing**: Unlike GraphRAG methods (LinearRAG, HippoRAG), no NER, no graph construction, no PageRank. Just an embedding matrix. 
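For concreteness, the entire retrieval dynamics fit in a few lines of plain PyTorch (a toy-scale sketch with synthetic data; the real module specified below adds batching, convergence checks, and analysis hooks):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, N, beta, T, k = 64, 1000, 2.0, 5, 5
M = F.normalize(torch.randn(d, N), dim=0)   # memory bank: one passage embedding per column
q = F.normalize(torch.randn(d), dim=0)      # q_0 = Enc(x)

for _ in range(T):
    alpha = torch.softmax(beta * (M.T @ q), dim=-1)  # (N,) attention over memories
    q = M @ alpha                                    # (d,) refined query q_{t+1}

alpha = torch.softmax(beta * (M.T @ q), dim=-1)      # final attention weights alpha_T
retrieved = torch.topk(alpha, k).indices             # passage indices fed to the LLM prompt
```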
+ +### What We Are NOT Doing (Scope Boundaries) + +- NOT modifying the LLM architecture (no cross-attention injection) +- NOT training the memory bank in phase 1 (M is frozen embeddings) +- NOT building a graph structure +- NOT doing anything that requires the LLM during retrieval +- The LLM is frozen and used only for final generation + +--- + +## Repository Structure + +``` +hag/ +├── CLAUDE.md # This file +├── pyproject.toml # Project config, dependencies +├── README.md # Project README +├── hag/ +│ ├── __init__.py +│ ├── config.py # All hyperparameters and configs (dataclass-based) +│ ├── encoder.py # Wrapper for encoding queries/passages +│ ├── memory_bank.py # Memory bank construction and management +│ ├── hopfield.py # Core: Hopfield retrieval module +│ ├── retriever_faiss.py # Baseline: FAISS top-k retriever +│ ├── retriever_hopfield.py # HAG retriever (wraps hopfield.py + passage selection) +│ ├── generator.py # LLM generation wrapper +│ ├── pipeline.py # End-to-end RAG/HAG pipeline +│ ├── metrics.py # Evaluation metrics (EM, F1, retrieval recall) +│ └── energy.py # Energy computation and analysis utilities +├── scripts/ +│ ├── build_memory_bank.py # Offline: encode corpus → memory bank +│ ├── run_eval.py # Run evaluation on a dataset +│ ├── run_baseline.py # Run vanilla RAG baseline +│ ├── run_hag.py # Run HAG +│ ├── analyze_energy.py # Analyze energy curves and convergence +│ └── visualize_trajectory.py # UMAP visualization of query trajectory +├── tests/ +│ ├── test_hopfield.py # Unit tests for Hopfield module +│ ├── test_memory_bank.py # Unit tests for memory bank +│ ├── test_retriever.py # Unit tests for both retrievers +│ ├── test_pipeline.py # Integration tests for pipeline +│ ├── test_metrics.py # Unit tests for metrics +│ └── test_energy.py # Unit tests for energy module +├── configs/ +│ ├── default.yaml # Default experiment config +│ ├── hotpotqa.yaml # HotpotQA-specific config +│ ├── musique.yaml # MuSiQue-specific config +│ └── 2wikimhqa.yaml # 2WikiMultiHopQA-specific config +└── notebooks/ + └── demo.ipynb # Quick demo notebook +``` + +--- + +## Module Specifications + +### 1. `hag/config.py` + +Use Python dataclasses. All hyperparameters in one place. + +```python +@dataclass +class HopfieldConfig: + beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval + max_iter: int = 5 # Maximum Hopfield iteration steps + conv_threshold: float = 1e-4 # Convergence: stop if ||q_{t+1} - q_t|| < threshold + top_k: int = 5 # Number of passages to retrieve from final attention weights + +@dataclass +class MemoryBankConfig: + embedding_dim: int = 768 # Must match encoder output dim + normalize: bool = True # L2-normalize embeddings in memory bank + +@dataclass +class EncoderConfig: + model_name: str = "facebook/contriever-msmarco" # Default encoder + max_length: int = 512 + batch_size: int = 64 + +@dataclass +class GeneratorConfig: + model_name: str = "meta-llama/Llama-3.1-8B-Instruct" + max_new_tokens: int = 128 + temperature: float = 0.0 # Greedy decoding for reproducibility + +@dataclass +class PipelineConfig: + hopfield: HopfieldConfig = field(default_factory=HopfieldConfig) + memory: MemoryBankConfig = field(default_factory=MemoryBankConfig) + encoder: EncoderConfig = field(default_factory=EncoderConfig) + generator: GeneratorConfig = field(default_factory=GeneratorConfig) + retriever_type: str = "hopfield" # "hopfield" or "faiss" +``` + +### 2. `hag/hopfield.py` — THE CORE MODULE + +This is the most important file. 
Implement the Modern Continuous Hopfield Network retrieval. + +```python +class HopfieldRetrieval: + """ + Modern Continuous Hopfield Network for memory retrieval. + + Given memory bank M ∈ ℝ^{d × N} and query q ∈ ℝ^d: + 1. Compute attention: α = softmax(β * Mᵀ @ q) + 2. Update query: q_new = M @ α + 3. Repeat until convergence or max_iter + + The energy function is: + E(q) = -1/β * log(Σ_i exp(β * q^T m_i)) + 1/2 * ||q||^2 + + Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease) + """ + + def __init__(self, config: HopfieldConfig): + ... + + def retrieve( + self, + query: torch.Tensor, # (d,) or (batch, d) + memory: torch.Tensor, # (d, N) — the memory bank + return_trajectory: bool = False, + return_energy: bool = False, + ) -> HopfieldResult: + """ + Run iterative Hopfield retrieval. + + Returns HopfieldResult with: + - attention_weights: (N,) or (batch, N) — final α_T + - converged_query: (d,) or (batch, d) — final q_T + - num_steps: int — actual steps taken + - trajectory: optional list of q_t at each step + - energy_curve: optional list of E(q_t) at each step + """ + ... + + def compute_energy( + self, + query: torch.Tensor, + memory: torch.Tensor, + ) -> torch.Tensor: + """ + E(q) = -1/β * log(Σ_i exp(β * qᵀ m_i)) + 1/2 * ||q||^2 + + This is the log-sum-exp energy of the Modern Hopfield Network. + """ + ... +``` + +**CRITICAL implementation details for `retrieve()`:** + +```python +def retrieve(self, query, memory, return_trajectory=False, return_energy=False): + # Ensure correct shapes + # query: (batch, d) — unsqueeze if needed + # memory: (d, N) + + q = query.clone() + trajectory = [q.clone()] if return_trajectory else None + energies = [self.compute_energy(q, memory)] if return_energy else None + + for t in range(self.config.max_iter): + # Core Hopfield update + logits = self.config.beta * (memory.T @ q.T).T # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + q_new = (memory @ alpha.T).T # (batch, d) + + # Check convergence + delta = torch.norm(q_new - q, dim=-1).max() + q = q_new + + if return_trajectory: + trajectory.append(q.clone()) + if return_energy: + energies.append(self.compute_energy(q, memory)) + + if delta < self.config.conv_threshold: + break + + # Final attention weights (recompute to ensure consistency) + logits = self.config.beta * (memory.T @ q.T).T + alpha = torch.softmax(logits, dim=-1) + + return HopfieldResult( + attention_weights=alpha, + converged_query=q, + num_steps=t + 1, + trajectory=trajectory, + energy_curve=energies, + ) +``` + +**Pay special attention to:** +- Numerical stability: use `torch.logsumexp` in energy computation +- Batch support: all operations should work for both single query and batched queries +- Memory efficiency: don't store trajectory/energy unless explicitly requested + +### 3. `hag/memory_bank.py` + +```python +class MemoryBank: + """ + Stores passage embeddings and provides lookup from indices back to text. + + The memory bank is M ∈ ℝ^{d × N} where each column is a passage embedding. + Also maintains a mapping from column index → passage text for final retrieval. 
+ """ + + def __init__(self, config: MemoryBankConfig): + self.embeddings: torch.Tensor = None # (d, N) + self.passages: List[str] = [] # N passages + self.config = config + + def build_from_embeddings(self, embeddings: torch.Tensor, passages: List[str]): + """ + embeddings: (N, d) — note: input is (N, d), internally stored as (d, N) + passages: list of N passage strings + """ + if self.config.normalize: + embeddings = F.normalize(embeddings, dim=-1) + self.embeddings = embeddings.T # Store as (d, N) for efficient matmul + self.passages = passages + + def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]: + """Given top-k indices, return corresponding passage texts.""" + ... + + def save(self, path: str): ... + def load(self, path: str): ... + + @property + def size(self) -> int: + return self.embeddings.shape[1] if self.embeddings is not None else 0 + + @property + def dim(self) -> int: + return self.embeddings.shape[0] if self.embeddings is not None else 0 +``` + +### 4. `hag/retriever_faiss.py` + +Vanilla RAG baseline retriever using FAISS. + +```python +class FAISSRetriever: + """ + Standard top-k retrieval using FAISS inner product search. + This is the baseline to compare against. + """ + + def __init__(self, top_k: int = 5): + self.index = None + self.passages: List[str] = [] + self.top_k = top_k + + def build_index(self, embeddings: np.ndarray, passages: List[str]): + """Build FAISS IndexFlatIP from (N, d) embeddings.""" + ... + + def retrieve(self, query: np.ndarray) -> RetrievalResult: + """ + query: (d,) or (batch, d) + Returns RetrievalResult with top_k passage texts, scores, and indices. + """ + ... +``` + +### 5. `hag/retriever_hopfield.py` + +```python +class HopfieldRetriever: + """ + Wraps HopfieldRetrieval + MemoryBank into a retriever interface. + + The bridge between Hopfield's continuous retrieval and the discrete + passage selection needed for LLM prompting. + """ + + def __init__(self, hopfield: HopfieldRetrieval, memory_bank: MemoryBank, top_k: int = 5): + ... + + def retrieve(self, query_embedding: torch.Tensor, return_analysis: bool = False) -> RetrievalResult: + """ + 1. Run Hopfield iterative retrieval → get attention weights α_T + 2. Take top_k indices from α_T + 3. Look up corresponding passage texts from memory bank + 4. Optionally return trajectory and energy for analysis + + Returns RetrievalResult with: + - passages: List[str] + - scores: attention weights for top-k + - indices: which memory slots were selected + - (optional) hopfield_result: full HopfieldResult for analysis + """ + ... +``` + +### 6. `hag/pipeline.py` + +```python +class RAGPipeline: + """ + End-to-end pipeline: query → encode → retrieve → generate. + Supports both FAISS (baseline) and Hopfield (ours) retrieval. + """ + + def __init__(self, config: PipelineConfig): + self.encoder = Encoder(config.encoder) + self.generator = Generator(config.generator) + + if config.retriever_type == "faiss": + self.retriever = FAISSRetriever(top_k=config.hopfield.top_k) + elif config.retriever_type == "hopfield": + hopfield = HopfieldRetrieval(config.hopfield) + self.retriever = HopfieldRetriever(hopfield, memory_bank, top_k=config.hopfield.top_k) + + def run(self, question: str) -> PipelineResult: + """ + Full pipeline: + 1. Encode question → query embedding + 2. Retrieve passages (FAISS or Hopfield) + 3. Format prompt with question + retrieved passages + 4. Generate answer with LLM + """ + ... + + def run_batch(self, questions: List[str]) -> List[PipelineResult]: + ... +``` + +### 7. 
`hag/metrics.py` + +```python +def exact_match(prediction: str, ground_truth: str) -> float: + """Normalized exact match. Lowercase, strip, remove articles/punctuation.""" + ... + +def f1_score(prediction: str, ground_truth: str) -> float: + """Token-level F1 between prediction and ground truth.""" + ... + +def retrieval_recall_at_k(retrieved_indices: List[int], gold_indices: List[int], k: int) -> float: + """What fraction of gold passages appear in the retrieved top-k?""" + ... + +def evaluate_dataset(results: List[PipelineResult], gold_answers: List[str]) -> Dict[str, float]: + """Compute aggregate metrics over a dataset. Returns dict with EM, F1, etc.""" + ... +``` + +### 8. `hag/energy.py` + +```python +def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]: + """Extract energy values at each iteration step.""" + ... + +def compute_energy_gap(energy_curve: List[float]) -> float: + """ + ΔE = E(q_0) - E(q_T) + Larger gap = more refinement happened. + """ + ... + +def verify_monotonic_decrease(energy_curve: List[float]) -> bool: + """Check that E(q_{t+1}) <= E(q_t) for all t. Should always be True.""" + ... + +def compute_attention_entropy(attention_weights: torch.Tensor) -> float: + """ + H(α) = -Σ_i α_i log α_i + Low entropy = sharp retrieval (confident). + High entropy = diffuse retrieval (uncertain). + """ + ... +``` + +--- + +## Data Types + +Define these as dataclasses or NamedTuples: + +```python +@dataclass +class HopfieldResult: + attention_weights: torch.Tensor # (batch, N) or (N,) + converged_query: torch.Tensor # (batch, d) or (d,) + num_steps: int + trajectory: Optional[List[torch.Tensor]] = None # list of q_t + energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t) + +@dataclass +class RetrievalResult: + passages: List[str] + scores: torch.Tensor # top-k scores + indices: torch.Tensor # top-k indices + hopfield_result: Optional[HopfieldResult] = None + +@dataclass +class PipelineResult: + question: str + answer: str + retrieved_passages: List[str] + retrieval_result: RetrievalResult +``` + +--- + +## Test Specifications + +**ALL TESTS MUST RUN ON CPU. NO GPU REQUIRED.** + +Use small synthetic data (random tensors) for unit tests. Do not download any models or datasets in tests. 
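For example, a typical test in this suite can build everything it needs from random tensors (a sketch assuming the interfaces specified above; the concrete test classes follow):

```python
import torch
import torch.nn.functional as F

from hag.config import HopfieldConfig
from hag.energy import compute_energy_curve, verify_monotonic_decrease
from hag.hopfield import HopfieldRetrieval


def test_energy_decreases_on_synthetic_data():
    torch.manual_seed(0)
    memory = F.normalize(torch.randn(64, 100), dim=0)  # (d, N) synthetic memory bank
    query = F.normalize(torch.randn(1, 64), dim=-1)    # (1, d) synthetic query

    hopfield = HopfieldRetrieval(HopfieldConfig(beta=2.0, max_iter=10))
    result = hopfield.retrieve(query, memory, return_energy=True)

    assert verify_monotonic_decrease(compute_energy_curve(result))
```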
+ +### `tests/test_hopfield.py` + +```python +class TestHopfieldRetrieval: + """Test the core Hopfield module with synthetic data on CPU.""" + + def setup_method(self): + """Create small synthetic memory bank and queries.""" + torch.manual_seed(42) + self.d = 64 # embedding dim + self.N = 100 # number of memories + self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N) + self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d) + self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6) + self.hopfield = HopfieldRetrieval(self.config) + + def test_output_shapes(self): + """attention_weights should be (1, N), converged_query should be (1, d).""" + result = self.hopfield.retrieve(self.query, self.memory) + assert result.attention_weights.shape == (1, self.N) + assert result.converged_query.shape == (1, self.d) + + def test_attention_weights_sum_to_one(self): + """softmax output must sum to 1.""" + result = self.hopfield.retrieve(self.query, self.memory) + assert torch.allclose(result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5) + + def test_attention_weights_non_negative(self): + """All attention weights must be >= 0.""" + result = self.hopfield.retrieve(self.query, self.memory) + assert (result.attention_weights >= 0).all() + + def test_energy_monotonic_decrease(self): + """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property.""" + result = self.hopfield.retrieve(self.query, self.memory, return_energy=True) + energies = [e.item() for e in result.energy_curve] + for i in range(len(energies) - 1): + assert energies[i+1] <= energies[i] + 1e-6, \ + f"Energy increased at step {i}: {energies[i]} -> {energies[i+1]}" + + def test_convergence(self): + """With enough iterations, query should converge (delta < threshold).""" + config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(self.query, self.memory, return_trajectory=True) + # Final two queries should be very close + q_last = result.trajectory[-1] + q_prev = result.trajectory[-2] + delta = torch.norm(q_last - q_prev) + assert delta < 1e-4, f"Did not converge: delta={delta}" + + def test_high_beta_sharp_retrieval(self): + """Higher β should produce sharper (lower entropy) attention.""" + low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5)) + high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5)) + + result_low = low_beta.retrieve(self.query, self.memory) + result_high = high_beta.retrieve(self.query, self.memory) + + entropy_low = -(result_low.attention_weights * result_low.attention_weights.log()).sum() + entropy_high = -(result_high.attention_weights * result_high.attention_weights.log()).sum() + + assert entropy_high < entropy_low, "Higher β should give lower entropy" + + def test_single_memory_converges_to_it(self): + """With N=1, retrieval should converge to the single memory.""" + single_memory = F.normalize(torch.randn(self.d, 1), dim=0) + result = self.hopfield.retrieve(self.query, single_memory) + assert torch.allclose(result.attention_weights, torch.ones(1, 1), atol=1e-5) + + def test_query_near_memory_retrieves_it(self): + """If query ≈ memory_i, attention should peak at index i.""" + target_idx = 42 + query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match + config = HopfieldConfig(beta=10.0, max_iter=5) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(query, self.memory) + top_idx = result.attention_weights.argmax(dim=-1).item() 
+ assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}" + + def test_batch_retrieval(self): + """Should handle batch of queries.""" + batch_query = F.normalize(torch.randn(8, self.d), dim=-1) + result = self.hopfield.retrieve(batch_query, self.memory) + assert result.attention_weights.shape == (8, self.N) + assert result.converged_query.shape == (8, self.d) + + def test_iteration_refines_query(self): + """ + Multi-hop test: query starts far from target, iteration should bring it closer. + + Setup: memory has two clusters. Query is near cluster A but the "answer" + is in cluster B, reachable through an intermediate memory that bridges both. + After iteration, the query should drift toward cluster B. + """ + torch.manual_seed(0) + d = 32 + + # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...] + # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...] + # Bridge memory 10: between A and B + center_a = torch.zeros(d); center_a[0] = 1.0 + center_b = torch.zeros(d); center_b[1] = 1.0 + bridge = F.normalize(center_a + center_b, dim=0) + + memories = [] + for _ in range(5): + memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0)) + for _ in range(5): + memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0)) + memories.append(bridge) + + M = torch.stack(memories, dim=1) # (d, 11) + q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0) + + config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(q0, M, return_trajectory=True) + + # After iteration, query should have drifted: its dot product with center_b + # should be higher than the initial query's dot product with center_b + initial_sim_b = (q0.squeeze() @ center_b).item() + final_sim_b = (result.converged_query.squeeze() @ center_b).item() + assert final_sim_b > initial_sim_b, \ + f"Iteration should pull query toward cluster B: initial={initial_sim_b:.4f}, final={final_sim_b:.4f}" +``` + +### `tests/test_memory_bank.py` + +```python +class TestMemoryBank: + + def test_build_and_size(self): + """Build memory bank and verify size.""" + config = MemoryBankConfig(embedding_dim=64, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(100, 64) + passages = [f"passage {i}" for i in range(100)] + mb.build_from_embeddings(embeddings, passages) + assert mb.size == 100 + assert mb.dim == 64 + + def test_normalization(self): + """If normalize=True, stored embeddings should have unit norm.""" + config = MemoryBankConfig(embedding_dim=64, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(50, 64) * 5 # non-unit norm + mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) + norms = torch.norm(mb.embeddings, dim=0) + assert torch.allclose(norms, torch.ones(50), atol=1e-5) + + def test_get_passages_by_indices(self): + """Index → passage text lookup.""" + config = MemoryBankConfig(embedding_dim=64, normalize=False) + mb = MemoryBank(config) + passages = [f"passage {i}" for i in range(100)] + mb.build_from_embeddings(torch.randn(100, 64), passages) + result = mb.get_passages_by_indices(torch.tensor([0, 50, 99])) + assert result == ["passage 0", "passage 50", "passage 99"] + + def test_save_and_load(self, tmp_path): + """Save and reload memory bank, verify contents match.""" + config = MemoryBankConfig(embedding_dim=32, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(20, 32) + passages = [f"text {i}" for i in range(20)] + 
mb.build_from_embeddings(embeddings, passages) + + save_path = str(tmp_path / "mb.pt") + mb.save(save_path) + + mb2 = MemoryBank(config) + mb2.load(save_path) + assert mb2.size == 20 + assert mb2.passages == passages + assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6) +``` + +### `tests/test_retriever.py` + +```python +class TestHopfieldRetriever: + + def test_retrieves_correct_number_of_passages(self): + """top_k=3 should return exactly 3 passages.""" + ... + + def test_scores_are_sorted_descending(self): + """Returned scores should be in descending order.""" + ... + + def test_passages_match_indices(self): + """Returned passage texts should correspond to returned indices.""" + ... + +class TestFAISSRetriever: + + def test_retrieves_correct_number(self): + ... + + def test_nearest_neighbor_is_self(self): + """If query = passage_i's embedding, top-1 should be passage_i.""" + ... + +class TestRetrieverComparison: + + def test_same_top1_for_obvious_query(self): + """ + When query is very close to one memory, both FAISS and Hopfield + should agree on the top-1 result. + """ + ... +``` + +### `tests/test_metrics.py` + +```python +class TestMetrics: + + def test_exact_match_basic(self): + assert exact_match("Paris", "paris") == 1.0 + assert exact_match("Paris", "London") == 0.0 + + def test_exact_match_normalization(self): + """Should strip whitespace, lowercase, remove articles.""" + assert exact_match(" The Paris ", "paris") == 1.0 + assert exact_match("A dog", "dog") == 1.0 + + def test_f1_score_perfect(self): + assert f1_score("the cat sat", "the cat sat") == 1.0 + + def test_f1_score_partial(self): + score = f1_score("the cat sat on the mat", "the cat sat") + assert 0.5 < score < 1.0 + + def test_f1_score_no_overlap(self): + assert f1_score("hello world", "foo bar") == 0.0 + + def test_retrieval_recall(self): + assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2/3 +``` + +### `tests/test_energy.py` + +```python +class TestEnergy: + + def test_energy_computation_matches_formula(self): + """ + Manually compute E(q) = -1/β * log(Σ exp(β * qᵀ m_i)) + 1/2 * ||q||^2 + and verify it matches the module output. + """ + torch.manual_seed(42) + d, N = 16, 10 + memory = torch.randn(d, N) + query = torch.randn(1, d) + beta = 2.0 + + config = HopfieldConfig(beta=beta) + hopfield = HopfieldRetrieval(config) + computed = hopfield.compute_energy(query, memory) + + # Manual computation + logits = beta * (memory.T @ query.T).T # (1, N) + expected = -1.0/beta * torch.logsumexp(logits, dim=-1) + 0.5 * (query ** 2).sum(dim=-1) + assert torch.allclose(computed, expected, atol=1e-5) + + def test_verify_monotonic(self): + energies = [10.0, 8.0, 6.5, 6.4, 6.39] + assert verify_monotonic_decrease(energies) == True + + def test_verify_monotonic_fails(self): + energies = [10.0, 8.0, 9.0, 6.0] + assert verify_monotonic_decrease(energies) == False + + def test_attention_entropy(self): + """Uniform attention should have max entropy, one-hot should have zero.""" + uniform = torch.ones(100) / 100 + one_hot = torch.zeros(100); one_hot[0] = 1.0 + assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot) + assert compute_attention_entropy(one_hot) < 0.01 +``` + +### `tests/test_pipeline.py` + +```python +class TestPipeline: + """ + Integration tests using MOCK encoder and generator. + These test the wiring, not the model quality. + """ + + def test_pipeline_runs_end_to_end_hopfield(self): + """Mock encoder + mock LLM + Hopfield retriever → produces an answer string.""" + ... 
    def test_pipeline_runs_end_to_end_faiss(self):
        """Same but with FAISS retriever."""
        ...

    def test_pipeline_retrieval_result_included(self):
        """PipelineResult should contain retrieval metadata."""
        ...
```

**For pipeline tests, use mock/fake encoder and generator:**
```python
import zlib

class FakeEncoder:
    def encode(self, text: str) -> torch.Tensor:
        # Deterministic embedding for testing, seeded by a stable hash.
        # (The built-in hash() is salted per process, so it is NOT
        # reproducible across runs.)
        torch.manual_seed(zlib.crc32(text.encode("utf-8")))
        return F.normalize(torch.randn(1, 64), dim=-1)

class FakeGenerator:
    def generate(self, prompt: str) -> str:
        return "mock answer"
```

---

## Dependencies

```toml
[project]
name = "hag"
version = "0.1.0"
requires-python = ">=3.10"

dependencies = [
    "torch>=2.0",
    "numpy>=1.24",
    "faiss-cpu>=1.7",
    "transformers>=4.36",
    "datasets>=2.14",
    "pyyaml>=6.0",
    "tqdm>=4.65",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0",
    "pytest-cov>=4.0",
]
eval = [
    "scikit-learn>=1.3",
    "umap-learn>=0.5",
    "matplotlib>=3.7",
]
```

---

## Implementation Order

**Phase 1: Core module (do this first)**
1. `config.py` — all dataclasses
2. `hopfield.py` — the core Hopfield retrieval
3. `energy.py` — energy utilities
4. `tests/test_hopfield.py` — run and pass all tests
5. `tests/test_energy.py` — run and pass all tests

**Phase 2: Memory and retrievers**
6. `memory_bank.py`
7. `retriever_hopfield.py`
8. `retriever_faiss.py`
9. `tests/test_memory_bank.py`
10. `tests/test_retriever.py`

**Phase 3: Metrics and pipeline**
11. `metrics.py`
12. `pipeline.py` (with mock encoder/generator for now)
13. `tests/test_metrics.py`
14. `tests/test_pipeline.py`

**Phase 4: Scripts (need GPU/models — implement structure only)**
15. `scripts/build_memory_bank.py`
16. `scripts/run_eval.py`
17. `scripts/run_baseline.py`
18. `scripts/run_hag.py`

---

## Coding Conventions

- **Type hints everywhere.** Every function signature must have full type hints.
- **Docstrings on every public method.** Include shapes in docstrings for tensor arguments.
- **Tensor shapes in comments.** Any line with a tensor operation should have a shape comment: `# (batch, N)`
- **No global state.** Everything parameterized through config dataclasses.
- **CPU-first.** Default device is CPU. GPU is opt-in via `.to(device)`.
- **Deterministic tests.** Always set `torch.manual_seed()` in tests.
- **Use `torch.no_grad()` in retrieval** when not training (phase 1).
- Use `logging` module, not print statements.

---

## Key Mathematical References

The Modern Continuous Hopfield Network update rule and energy function come from:

> Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)

Key equations:
- **Update**: `q_{t+1} = M · softmax(β · Mᵀ · qₜ)`
- **Energy**: `E(q) = -1/β · log(Σᵢ exp(β · qᵀmᵢ)) + 1/2 · ||q||²`
- **Property**: Energy decreases monotonically: `E(q_{t+1}) ≤ E(q_t)`
- **Property**: Fixed points are local minima of E
- **Capacity**: Exponential in d (unlike classical Hopfield's ~0.14N)

The connection to attention: the Hopfield update is mathematically equivalent to one step of cross-attention where queries attend to stored memories. But the ITERATIVE application (multiple steps) and the energy-based analysis are what differentiate it from standard attention.
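This equivalence is easy to check numerically: one Hopfield update with inverse temperature β is exactly scaled dot-product attention with Q = q and K = V = Mᵀ (a sketch; the `scale=` argument needs PyTorch ≥ 2.1):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
d, N, beta = 16, 32, 2.0
M = torch.randn(d, N)  # memory bank: keys and values in one matrix
q = torch.randn(1, d)  # single query

# One Hopfield update: q_new = M · softmax(β · Mᵀ · q)
hopfield_step = torch.softmax(beta * (q @ M), dim=-1) @ M.T  # (1, d)

# The same computation as one step of cross-attention
attn_step = F.scaled_dot_product_attention(
    q.unsqueeze(0), M.T.unsqueeze(0), M.T.unsqueeze(0), scale=beta
).squeeze(0)  # (1, d)

assert torch.allclose(hopfield_step, attn_step, atol=1e-6)
```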
---

## What NOT To Do

- Do NOT use `torch.cuda` in any test
- Do NOT download models/datasets in tests — use synthetic data and mocks
- Do NOT implement the encoder or generator with real models yet — provide the interface and use mocks in tests
- Do NOT add wandb, tensorboard, or other logging frameworks yet
- Do NOT try to optimize for speed yet — correctness first
- Do NOT implement the training loop (memory bank training) — that's phase 2 of the research, not this codebase yet

diff --git a/README.md b/README.md
new file mode 100644
index 0000000..7bd7a72
--- /dev/null
+++ b/README.md
@@ -0,0 +1,40 @@
# HAG: Hopfield-Augmented Generation

A retrieval-augmented generation system that replaces standard vector similarity search (FAISS top-k) with **iterative Modern Hopfield Network retrieval**.

## Quick Start

```bash
pip install -e ".[dev]"
pytest
```

## Core Idea

Standard RAG retrieves passages once via FAISS similarity. HAG iteratively refines the query in embedding space using Modern Continuous Hopfield Network dynamics:

```
q_{t+1} = M * softmax(beta * M^T * q_t)
```

This enables implicit multi-hop reasoning through energy minimization with convergence guarantees.

## Project Structure

- `hag/hopfield.py` — Core Hopfield retrieval module
- `hag/memory_bank.py` — Passage embedding storage
- `hag/retriever_hopfield.py` — Hopfield-based retriever
- `hag/retriever_faiss.py` — FAISS baseline retriever
- `hag/pipeline.py` — End-to-end RAG/HAG pipeline
- `hag/metrics.py` — Evaluation metrics (EM, F1, retrieval recall)
- `hag/energy.py` — Energy analysis utilities
- `scripts/` — Evaluation and analysis scripts
- `tests/` — Comprehensive test suite

## Running Tests

```bash
pytest tests/ -v
```

All tests run on CPU with synthetic data — no GPU or model downloads required.

@@ -0,0 +1,23 @@
# TBD — remaining work once GPU-ready

## 1. Add GPU support to the Encoder
- Add a `device` argument to the `Encoder` class in `hag/encoder.py`
- Call `.to(device)` in `_load_model()` and move input tensors to the GPU at inference time
- Move batch-encode results back to host with `.cpu()` before returning

## 2. Data preparation script
- Create `scripts/prepare_hotpotqa.py`
- Download HotpotQA from HuggingFace and extract all context passages
- Deduplicate and store them in a format the memory bank can consume (JSONL, or build `memory_bank.pt` directly)
- Also export questions + gold answers + gold passage indices for evaluation

## 3. Comparison evaluation script
- Create `scripts/run_comparison.py`
- Run the FAISS and Hopfield retrievers against the same memory bank
- Generate answers with the same LLM
- Output a comparison table: EM / F1 / Retrieval Recall@k
- Support sweeps over different `beta` and `max_iter` values

## 4. Add the FAISS path to `run_eval.py`
- Currently `run_eval.py` only accepts a memory_bank; FAISS mode additionally needs a FAISS index built from that memory bank
- Add the automatic conversion logic

diff --git a/configs/2wikimhqa.yaml b/configs/2wikimhqa.yaml
new file mode 100644
index 0000000..0048cf2
--- /dev/null
+++ b/configs/2wikimhqa.yaml
@@ -0,0 +1,22 @@
hopfield:
  beta: 2.0
  max_iter: 8
  conv_threshold: 1.0e-4
  top_k: 5

memory:
  embedding_dim: 768
  normalize: true

encoder:
  model_name: "facebook/contriever-msmarco"
  max_length: 512
  batch_size: 64

generator:
  model_name: "meta-llama/Llama-3.1-8B-Instruct"
  max_new_tokens: 128
  temperature: 0.0

retriever_type: "hopfield"
dataset: "2wikimultihopqa"

diff --git a/configs/default.yaml b/configs/default.yaml
new file mode 100644
index 0000000..cb76d34
--- /dev/null
+++ b/configs/default.yaml
@@ -0,0 +1,21 @@
hopfield:
  beta: 1.0
  max_iter: 5
  conv_threshold: 1.0e-4
  top_k: 5

memory:
  embedding_dim: 768
  normalize: true

encoder:
  model_name: "facebook/contriever-msmarco"
  max_length: 512
  batch_size: 64

generator:
  model_name: "meta-llama/Llama-3.1-8B-Instruct"
  max_new_tokens: 128
  temperature: 0.0

retriever_type: "hopfield"

diff --git a/configs/hotpotqa.yaml b/configs/hotpotqa.yaml
new file mode 100644
index 0000000..99ea9d7
--- /dev/null
+++ b/configs/hotpotqa.yaml
@@ -0,0 +1,22 @@
hopfield:
  beta: 2.0
  max_iter: 8
  conv_threshold: 1.0e-4
  top_k: 5

memory:
  embedding_dim: 768
  normalize: true

encoder:
  model_name: "facebook/contriever-msmarco"
  max_length: 512
  batch_size: 64

generator:
  model_name: "meta-llama/Llama-3.1-8B-Instruct"
  max_new_tokens: 128
  temperature: 0.0

retriever_type: "hopfield"
dataset: "hotpotqa"

diff --git a/configs/musique.yaml b/configs/musique.yaml
new file mode 100644
index 0000000..52dd7c4
--- /dev/null
+++ b/configs/musique.yaml
@@ -0,0 +1,22 @@
hopfield:
  beta: 2.0
  max_iter: 8
  conv_threshold: 1.0e-4
  top_k: 5

memory:
  embedding_dim: 768
  normalize: true

encoder:
  model_name: "facebook/contriever-msmarco"
  max_length: 512
  batch_size: 64

generator:
  model_name: "meta-llama/Llama-3.1-8B-Instruct"
  max_new_tokens: 128
  temperature: 0.0

retriever_type: "hopfield"
dataset: "musique"

diff --git a/hag/__init__.py b/hag/__init__.py
new file mode 100644
index 0000000..18496e9
--- /dev/null
+++ b/hag/__init__.py
@@ -0,0 +1,21 @@
"""HAG: Hopfield-Augmented Generation."""

from hag.config import (
    EncoderConfig,
    GeneratorConfig,
    HopfieldConfig,
    MemoryBankConfig,
    PipelineConfig,
)
from hag.datatypes import HopfieldResult, PipelineResult, RetrievalResult

__all__ = [
    "HopfieldConfig",
    "MemoryBankConfig",
    "EncoderConfig",
    "GeneratorConfig",
    "PipelineConfig",
    "HopfieldResult",
    "RetrievalResult",
    "PipelineResult",
]

diff --git a/hag/config.py b/hag/config.py
new file mode 100644
index 0000000..793e3a6
--- /dev/null
+++ b/hag/config.py
@@ -0,0 +1,50 @@
"""All hyperparameters and configuration dataclasses for HAG."""

from dataclasses import dataclass, field


@dataclass
class HopfieldConfig:
    """Configuration for the Hopfield retrieval module."""

    beta: float = 1.0  # Inverse temperature. Higher = sharper retrieval
    max_iter: int = 5  # Maximum Hopfield iteration steps
    conv_threshold: float = 1e-4  # Stop if ||q_{t+1} - q_t|| < threshold
    top_k: int = 5  # Number of passages to retrieve from final attention weights


@dataclass
class MemoryBankConfig:
    """Configuration for the memory bank."""

    embedding_dim: int = 768  # Must match encoder output dim
    normalize: bool = True  # L2-normalize embeddings in memory bank


@dataclass
class EncoderConfig:
    """Configuration for the query/passage encoder."""

    model_name: str = "facebook/contriever-msmarco"
    max_length: int = 512
    batch_size: int = 64


@dataclass
class GeneratorConfig:
    """Configuration for the LLM generator."""

    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
    max_new_tokens: int = 128
    temperature: float = 0.0  # Greedy decoding for reproducibility


@dataclass
class PipelineConfig:
    """Top-level pipeline configuration."""

    hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
    memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
    encoder: EncoderConfig = field(default_factory=EncoderConfig)
    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
    retriever_type: str = "hopfield"  # "hopfield" or "faiss"

diff --git a/hag/datatypes.py b/hag/datatypes.py
new file mode 100644
index 0000000..0f4254d
--- /dev/null
+++ b/hag/datatypes.py
@@ -0,0 +1,37 @@
"""Data types used across HAG modules."""

from dataclasses import dataclass
from typing import List, Optional

import torch


@dataclass
class HopfieldResult:
    """Result from Hopfield iterative retrieval."""

    attention_weights: torch.Tensor  # (batch, N) or (N,)
    converged_query: torch.Tensor  # (batch, d) or (d,)
    num_steps: int
    trajectory: Optional[List[torch.Tensor]] = None  # list of q_t
    energy_curve: Optional[List[torch.Tensor]] = None  # list of E(q_t)


@dataclass
class RetrievalResult:
    """Result from a retriever (FAISS or Hopfield)."""

    passages: List[str]
    scores: torch.Tensor  # top-k scores
    indices: torch.Tensor  # top-k indices
    hopfield_result: Optional[HopfieldResult] = None


@dataclass
class PipelineResult:
    """Result from the full RAG/HAG pipeline."""

    question: str
    answer: str
    retrieved_passages: List[str]
    retrieval_result: RetrievalResult

diff --git a/hag/encoder.py b/hag/encoder.py
new file mode 100644
index 0000000..7e103f3
--- /dev/null
+++ b/hag/encoder.py
@@ -0,0 +1,88 @@
"""Wrapper for encoding queries and passages into embeddings."""

import logging
import zlib
from typing import List, Union

import torch

from hag.config import EncoderConfig

logger = logging.getLogger(__name__)


class Encoder:
    """Encodes text queries/passages into dense embeddings.

    Uses a HuggingFace transformer model (e.g., Contriever).
    For testing, use FakeEncoder instead.
    """

    def __init__(self, config: EncoderConfig) -> None:
        self.config = config
        self._tokenizer = None
        self._model = None

    def _load_model(self) -> None:
        """Lazy-load the model and tokenizer."""
        from transformers import AutoModel, AutoTokenizer

        logger.info("Loading encoder model: %s", self.config.model_name)
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self._model = AutoModel.from_pretrained(self.config.model_name)
        self._model.eval()

    @torch.no_grad()
    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Encode text(s) into embedding(s).

        Args:
            texts: single string or list of strings

        Returns:
            (1, d) tensor for single input, (N, d) for list input.
        """
        if self._model is None:
            self._load_model()

        if isinstance(texts, str):
            texts = [texts]

        inputs = self._tokenizer(
            texts,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        outputs = self._model(**inputs)
        # Masked mean pooling over token embeddings: averaging over
        # padding tokens would dilute the embeddings of shorter texts
        mask = inputs["attention_mask"].unsqueeze(-1)  # (N, L, 1)
        summed = (outputs.last_hidden_state * mask).sum(dim=1)  # (N, d)
        counts = mask.sum(dim=1).clamp(min=1)  # (N, 1)
        embeddings = summed / counts  # (N, d)
        return embeddings


class FakeEncoder:
    """Deterministic hash-based encoder for testing. No model download needed."""

    def __init__(self, dim: int = 64) -> None:
        self.dim = dim

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Produce deterministic embeddings from a stable hash of each text.

        Args:
            texts: single string or list of strings

        Returns:
            (1, d) or (N, d) normalized tensor.
        """
        if isinstance(texts, str):
            texts = [texts]

        embeddings = []
        for text in texts:
            # zlib.crc32 is stable across processes; the built-in hash()
            # is salted per run and would break test reproducibility
            torch.manual_seed(zlib.crc32(text.encode("utf-8")))
            emb = torch.randn(1, self.dim)
            embeddings.append(emb)

        result = torch.cat(embeddings, dim=0)  # (N, d)
        return torch.nn.functional.normalize(result, dim=-1)

diff --git a/hag/energy.py b/hag/energy.py
new file mode 100644
index 0000000..62a39e9
--- /dev/null
+++ b/hag/energy.py
@@ -0,0 +1,83 @@
"""Energy computation and analysis utilities for Hopfield retrieval."""

import logging
from typing import List

import torch

from hag.datatypes import HopfieldResult

logger = logging.getLogger(__name__)


def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
    """Extract energy values at each iteration step.

    Args:
        hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True

    Returns:
        List of energy values (floats) at each step.
    """
    if hopfield_result.energy_curve is None:
        return []
    return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]


def compute_energy_gap(energy_curve: List[float]) -> float:
    """Compute the energy gap: Delta_E = E(q_0) - E(q_T).

    Larger gap means more refinement happened during iteration.

    Args:
        energy_curve: list of energy values at each step

    Returns:
        Energy gap (float). Positive if energy decreased.
    """
    if len(energy_curve) < 2:
        return 0.0
    return energy_curve[0] - energy_curve[-1]


def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
    """Check that E(q_{t+1}) <= E(q_t) for all t.

    This should always be True for the Modern Hopfield Network.

    Args:
        energy_curve: list of energy values at each step
        tol: numerical tolerance for comparison

    Returns:
        True if energy decreases monotonically (within tolerance).
    """
    for i in range(len(energy_curve) - 1):
        if energy_curve[i + 1] > energy_curve[i] + tol:
            return False
    return True


def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
    """Compute the entropy of attention weights.

    H(alpha) = -sum_i alpha_i * log(alpha_i)

    Low entropy = sharp retrieval (confident).
    High entropy = diffuse retrieval (uncertain).

    Args:
        attention_weights: (N,) or (batch, N) — attention distribution

    Returns:
        Entropy value (float). Averaged over batch if batched.
+ """ + if attention_weights.dim() == 1: + attention_weights = attention_weights.unsqueeze(0) # (1, N) + + # Clamp to avoid log(0) + eps = 1e-12 + alpha = attention_weights.clamp(min=eps) + entropy = -(alpha * alpha.log()).sum(dim=-1) # (batch,) + + return entropy.mean().item() diff --git a/hag/generator.py b/hag/generator.py new file mode 100644 index 0000000..2142e0c --- /dev/null +++ b/hag/generator.py @@ -0,0 +1,87 @@ +"""LLM generation wrapper for producing answers from retrieved context.""" + +import logging +from typing import List + +from hag.config import GeneratorConfig + +logger = logging.getLogger(__name__) + +PROMPT_TEMPLATE = """Answer the following question based on the provided context passages. + +Context: +{context} + +Question: {question} + +Answer:""" + + +class Generator: + """LLM-based answer generator. + + Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct). + For testing, use FakeGenerator instead. + """ + + def __init__(self, config: GeneratorConfig) -> None: + self.config = config + self._tokenizer = None + self._model = None + + def _load_model(self) -> None: + """Lazy-load the model and tokenizer.""" + from transformers import AutoModelForCausalLM, AutoTokenizer + + logger.info("Loading generator model: %s", self.config.model_name) + self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name) + self._model = AutoModelForCausalLM.from_pretrained( + self.config.model_name, + torch_dtype="auto", + ) + self._model.eval() + + def generate(self, question: str, passages: List[str]) -> str: + """Generate an answer given a question and retrieved passages. + + Args: + question: the user question + passages: list of retrieved passage texts + + Returns: + Generated answer string. + """ + if self._model is None: + self._load_model() + + context = "\n\n".join( + f"[{i+1}] {p}" for i, p in enumerate(passages) + ) + prompt = PROMPT_TEMPLATE.format(context=context, question=question) + + inputs = self._tokenizer(prompt, return_tensors="pt") + outputs = self._model.generate( + **inputs, + max_new_tokens=self.config.max_new_tokens, + temperature=self.config.temperature if self.config.temperature > 0 else None, + do_sample=self.config.temperature > 0, + ) + # Decode only the generated tokens (skip the prompt) + generated = outputs[0][inputs["input_ids"].shape[1]:] + return self._tokenizer.decode(generated, skip_special_tokens=True).strip() + + +class FakeGenerator: + """Deterministic mock generator for testing. No model download needed.""" + + def generate(self, question: str, passages: List[str]) -> str: + """Return a mock answer. + + Args: + question: the user question + passages: list of retrieved passages + + Returns: + Mock answer string. + """ + return "mock answer" diff --git a/hag/hopfield.py b/hag/hopfield.py new file mode 100644 index 0000000..287e4af --- /dev/null +++ b/hag/hopfield.py @@ -0,0 +1,124 @@ +"""Core Modern Continuous Hopfield Network retrieval module. + +Implements the iterative retrieval dynamics from: + Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021) + +Update rule: q_{t+1} = M * softmax(beta * M^T * q_t) +Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2 +""" + +import logging +from typing import Optional + +import torch + +from hag.config import HopfieldConfig +from hag.datatypes import HopfieldResult + +logger = logging.getLogger(__name__) + + +class HopfieldRetrieval: + """Modern Continuous Hopfield Network for memory retrieval. + + Given memory bank M in R^{d x N} and query q in R^d: + 1. 
Compute attention: alpha = softmax(beta * M^T @ q) + 2. Update query: q_new = M @ alpha + 3. Repeat until convergence or max_iter + + The energy function is: + E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2 + + Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease) + """ + + def __init__(self, config: HopfieldConfig) -> None: + self.config = config + + @torch.no_grad() + def retrieve( + self, + query: torch.Tensor, + memory: torch.Tensor, + return_trajectory: bool = False, + return_energy: bool = False, + ) -> HopfieldResult: + """Run iterative Hopfield retrieval. + + Args: + query: (d,) or (batch, d) — query embedding(s) + memory: (d, N) — memory bank of passage embeddings + return_trajectory: if True, store q_t at each step + return_energy: if True, store E(q_t) at each step + + Returns: + HopfieldResult with attention_weights, converged_query, num_steps, + and optionally trajectory and energy_curve. + """ + # Ensure query is 2D: (batch, d) + if query.dim() == 1: + query = query.unsqueeze(0) # (1, d) + + q = query.clone() # (batch, d) + + trajectory = [q.clone()] if return_trajectory else None + energies = [self.compute_energy(q, memory)] if return_energy else None + + num_steps = 0 + for t in range(self.config.max_iter): + # Core Hopfield update + logits = self.config.beta * (q @ memory) # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + q_new = alpha @ memory.T # (batch, d) + + # Check convergence + delta = torch.norm(q_new - q, dim=-1).max() # scalar + q = q_new + + if return_trajectory: + trajectory.append(q.clone()) + if return_energy: + energies.append(self.compute_energy(q, memory)) + + num_steps = t + 1 + + if delta < self.config.conv_threshold: + break + + # Final attention weights (recompute to ensure consistency) + logits = self.config.beta * (q @ memory) # (batch, N) + alpha = torch.softmax(logits, dim=-1) # (batch, N) + + return HopfieldResult( + attention_weights=alpha, + converged_query=q, + num_steps=num_steps, + trajectory=trajectory, + energy_curve=energies, + ) + + def compute_energy( + self, + query: torch.Tensor, + memory: torch.Tensor, + ) -> torch.Tensor: + """Compute the Hopfield energy function. + + E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2 + + Args: + query: (batch, d) or (d,) — query embedding(s) + memory: (d, N) — memory bank + + Returns: + Energy scalar or (batch,) tensor. + """ + if query.dim() == 1: + query = query.unsqueeze(0) # (1, d) + + logits = self.config.beta * (query @ memory) # (batch, N) + lse = torch.logsumexp(logits, dim=-1) # (batch,) + norm_sq = 0.5 * (query**2).sum(dim=-1) # (batch,) + energy = -1.0 / self.config.beta * lse + norm_sq # (batch,) + + return energy diff --git a/hag/memory_bank.py b/hag/memory_bank.py new file mode 100644 index 0000000..42dcc73 --- /dev/null +++ b/hag/memory_bank.py @@ -0,0 +1,93 @@ +"""Memory bank construction and management for passage embeddings.""" + +import logging +from typing import Dict, List, Optional + +import torch +import torch.nn.functional as F + +from hag.config import MemoryBankConfig + +logger = logging.getLogger(__name__) + + +class MemoryBank: + """Stores passage embeddings and provides lookup from indices back to text. + + The memory bank is M in R^{d x N} where each column is a passage embedding. + Also maintains a mapping from column index to passage text for final retrieval. 
+ """ + + def __init__(self, config: MemoryBankConfig) -> None: + self.config = config + self.embeddings: Optional[torch.Tensor] = None # (d, N) + self.passages: List[str] = [] + + def build_from_embeddings( + self, embeddings: torch.Tensor, passages: List[str] + ) -> None: + """Build memory bank from precomputed embeddings. + + Args: + embeddings: (N, d) — passage embeddings (note: input is N x d) + passages: list of N passage strings + """ + assert embeddings.shape[0] == len(passages), ( + f"Number of embeddings ({embeddings.shape[0]}) must match " + f"number of passages ({len(passages)})" + ) + if self.config.normalize: + embeddings = F.normalize(embeddings, dim=-1) + self.embeddings = embeddings.T # Store as (d, N) for efficient matmul + self.passages = list(passages) + logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim) + + def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]: + """Given top-k indices, return corresponding passage texts. + + Args: + indices: (k,) or (batch, k) tensor of integer indices + + Returns: + List of passage strings. + """ + flat_indices = indices.flatten().tolist() + return [self.passages[i] for i in flat_indices] + + def save(self, path: str) -> None: + """Save memory bank to disk. + + Args: + path: file path for saving (e.g., 'memory_bank.pt') + """ + data: Dict = { + "embeddings": self.embeddings, + "passages": self.passages, + "config": { + "embedding_dim": self.config.embedding_dim, + "normalize": self.config.normalize, + }, + } + torch.save(data, path) + logger.info("Saved memory bank to %s", path) + + def load(self, path: str) -> None: + """Load memory bank from disk. + + Args: + path: file path to load from + """ + data = torch.load(path, weights_only=False) + self.embeddings = data["embeddings"] + self.passages = data["passages"] + logger.info("Loaded memory bank from %s (%d passages)", path, self.size) + + @property + def size(self) -> int: + """Number of passages in the memory bank.""" + return self.embeddings.shape[1] if self.embeddings is not None else 0 + + @property + def dim(self) -> int: + """Embedding dimensionality.""" + return self.embeddings.shape[0] if self.embeddings is not None else 0 diff --git a/hag/metrics.py b/hag/metrics.py new file mode 100644 index 0000000..6a196df --- /dev/null +++ b/hag/metrics.py @@ -0,0 +1,113 @@ +"""Evaluation metrics for HAG: exact match, F1, retrieval recall.""" + +import logging +import re +import string +from collections import Counter +from typing import Dict, List + +from hag.datatypes import PipelineResult + +logger = logging.getLogger(__name__) + + +def _normalize_answer(text: str) -> str: + """Normalize answer text: lowercase, strip, remove articles and punctuation.""" + text = text.lower().strip() + # Remove articles + text = re.sub(r"\b(a|an|the)\b", " ", text) + # Remove punctuation + text = text.translate(str.maketrans("", "", string.punctuation)) + # Collapse whitespace + text = " ".join(text.split()) + return text + + +def exact_match(prediction: str, ground_truth: str) -> float: + """Normalized exact match. + + Args: + prediction: predicted answer string + ground_truth: gold answer string + + Returns: + 1.0 if normalized strings match, 0.0 otherwise. + """ + return float(_normalize_answer(prediction) == _normalize_answer(ground_truth)) + + +def f1_score(prediction: str, ground_truth: str) -> float: + """Token-level F1 between prediction and ground truth. 
+ + Args: + prediction: predicted answer string + ground_truth: gold answer string + + Returns: + F1 score between 0.0 and 1.0. + """ + pred_tokens = _normalize_answer(prediction).split() + gold_tokens = _normalize_answer(ground_truth).split() + + if not pred_tokens and not gold_tokens: + return 1.0 + if not pred_tokens or not gold_tokens: + return 0.0 + + common = Counter(pred_tokens) & Counter(gold_tokens) + num_same = sum(common.values()) + + if num_same == 0: + return 0.0 + + precision = num_same / len(pred_tokens) + recall = num_same / len(gold_tokens) + f1 = 2 * precision * recall / (precision + recall) + return f1 + + +def retrieval_recall_at_k( + retrieved_indices: List[int], gold_indices: List[int], k: int +) -> float: + """What fraction of gold passages appear in the retrieved top-k? + + Args: + retrieved_indices: list of retrieved passage indices (top-k) + gold_indices: list of gold/relevant passage indices + k: number of retrieved passages to consider + + Returns: + Recall score between 0.0 and 1.0. + """ + if not gold_indices: + return 1.0 + retrieved_set = set(retrieved_indices[:k]) + gold_set = set(gold_indices) + return len(retrieved_set & gold_set) / len(gold_set) + + +def evaluate_dataset( + results: List[PipelineResult], gold_answers: List[str] +) -> Dict[str, float]: + """Compute aggregate metrics over a dataset. + + Args: + results: list of PipelineResult from the pipeline + gold_answers: list of gold answer strings + + Returns: + Dict with keys 'em', 'f1' containing averaged scores. + """ + assert len(results) == len(gold_answers) + + em_scores = [] + f1_scores = [] + + for result, gold in zip(results, gold_answers): + em_scores.append(exact_match(result.answer, gold)) + f1_scores.append(f1_score(result.answer, gold)) + + return { + "em": sum(em_scores) / len(em_scores) if em_scores else 0.0, + "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0, + } diff --git a/hag/pipeline.py b/hag/pipeline.py new file mode 100644 index 0000000..1fefb84 --- /dev/null +++ b/hag/pipeline.py @@ -0,0 +1,107 @@ +"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate.""" + +import logging +from typing import List, Optional, Protocol, Union + +import numpy as np +import torch + +from hag.config import PipelineConfig +from hag.datatypes import PipelineResult, RetrievalResult +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank +from hag.retriever_faiss import FAISSRetriever +from hag.retriever_hopfield import HopfieldRetriever + +logger = logging.getLogger(__name__) + + +class EncoderProtocol(Protocol): + """Protocol for encoder interface.""" + + def encode(self, texts: Union[str, List[str]]) -> torch.Tensor: ... + + +class GeneratorProtocol(Protocol): + """Protocol for generator interface.""" + + def generate(self, question: str, passages: List[str]) -> str: ... + + +class RAGPipeline: + """End-to-end pipeline: query -> encode -> retrieve -> generate. + + Supports both FAISS (baseline) and Hopfield (ours) retrieval. 
+ """ + + def __init__( + self, + config: PipelineConfig, + encoder: EncoderProtocol, + generator: GeneratorProtocol, + memory_bank: Optional[MemoryBank] = None, + faiss_retriever: Optional[FAISSRetriever] = None, + ) -> None: + self.config = config + self.encoder = encoder + self.generator = generator + + if config.retriever_type == "faiss": + assert faiss_retriever is not None, "FAISSRetriever required for faiss mode" + self.retriever_type = "faiss" + self.faiss_retriever = faiss_retriever + self.hopfield_retriever: Optional[HopfieldRetriever] = None + elif config.retriever_type == "hopfield": + assert memory_bank is not None, "MemoryBank required for hopfield mode" + hopfield = HopfieldRetrieval(config.hopfield) + self.retriever_type = "hopfield" + self.hopfield_retriever = HopfieldRetriever( + hopfield, memory_bank, top_k=config.hopfield.top_k + ) + self.faiss_retriever = None + else: + raise ValueError(f"Unknown retriever_type: {config.retriever_type}") + + def run(self, question: str) -> PipelineResult: + """Run the full pipeline on a single question. + + 1. Encode question -> query embedding + 2. Retrieve passages (FAISS or Hopfield) + 3. Generate answer with LLM + + Args: + question: input question string + + Returns: + PipelineResult with answer and retrieval metadata. + """ + # Encode + query_emb = self.encoder.encode(question) # (1, d) + + # Retrieve + if self.retriever_type == "hopfield": + retrieval_result = self.hopfield_retriever.retrieve(query_emb) + else: + query_np = query_emb.detach().numpy().astype(np.float32) + retrieval_result = self.faiss_retriever.retrieve(query_np) + + # Generate + answer = self.generator.generate(question, retrieval_result.passages) + + return PipelineResult( + question=question, + answer=answer, + retrieved_passages=retrieval_result.passages, + retrieval_result=retrieval_result, + ) + + def run_batch(self, questions: List[str]) -> List[PipelineResult]: + """Run pipeline on a batch of questions. + + Args: + questions: list of question strings + + Returns: + List of PipelineResult, one per question. + """ + return [self.run(q) for q in questions] diff --git a/hag/retriever_faiss.py b/hag/retriever_faiss.py new file mode 100644 index 0000000..cd54a85 --- /dev/null +++ b/hag/retriever_faiss.py @@ -0,0 +1,73 @@ +"""Baseline FAISS top-k retriever for vanilla RAG.""" + +import logging +from typing import List, Optional + +import faiss +import numpy as np +import torch + +from hag.datatypes import RetrievalResult + +logger = logging.getLogger(__name__) + + +class FAISSRetriever: + """Standard top-k retrieval using FAISS inner product search. + + This is the baseline to compare against Hopfield retrieval. + """ + + def __init__(self, top_k: int = 5) -> None: + self.index: Optional[faiss.IndexFlatIP] = None + self.passages: List[str] = [] + self.top_k = top_k + + def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None: + """Build FAISS IndexFlatIP from embeddings. + + Args: + embeddings: (N, d) numpy array of passage embeddings + passages: list of N passage strings + """ + assert embeddings.shape[0] == len(passages) + d = embeddings.shape[1] + self.index = faiss.IndexFlatIP(d) + # Normalize for cosine similarity via inner product + faiss.normalize_L2(embeddings) + self.index.add(embeddings) + self.passages = list(passages) + logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d) + + def retrieve(self, query: np.ndarray) -> RetrievalResult: + """Retrieve top-k passages for a query. 
+ + Args: + query: (d,) or (batch, d) numpy array + + Returns: + RetrievalResult with passages, scores, and indices. + """ + assert self.index is not None, "Index not built. Call build_index first." + + if query.ndim == 1: + query = query.reshape(1, -1) # (1, d) + + # Normalize query for cosine similarity + query_copy = query.copy() + faiss.normalize_L2(query_copy) + + scores, indices = self.index.search(query_copy, self.top_k) # (batch, k) + + # Flatten for single query case + if scores.shape[0] == 1: + scores = scores[0] # (k,) + indices = indices[0] # (k,) + + passages = [self.passages[i] for i in indices.flatten().tolist()] + + return RetrievalResult( + passages=passages, + scores=torch.from_numpy(scores).float(), + indices=torch.from_numpy(indices).long(), + ) diff --git a/hag/retriever_hopfield.py b/hag/retriever_hopfield.py new file mode 100644 index 0000000..1cb6968 --- /dev/null +++ b/hag/retriever_hopfield.py @@ -0,0 +1,77 @@ +"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank.""" + +import logging +from typing import List + +import torch + +from hag.datatypes import RetrievalResult +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank + +logger = logging.getLogger(__name__) + + +class HopfieldRetriever: + """Wraps HopfieldRetrieval + MemoryBank into a retriever interface. + + The bridge between Hopfield's continuous retrieval and the discrete + passage selection needed for LLM prompting. + """ + + def __init__( + self, + hopfield: HopfieldRetrieval, + memory_bank: MemoryBank, + top_k: int = 5, + ) -> None: + self.hopfield = hopfield + self.memory_bank = memory_bank + self.top_k = top_k + + def retrieve( + self, + query_embedding: torch.Tensor, + return_analysis: bool = False, + ) -> RetrievalResult: + """Retrieve top-k passages using iterative Hopfield retrieval. + + 1. Run Hopfield iterative retrieval -> get attention weights alpha_T + 2. Take top_k indices from alpha_T + 3. Look up corresponding passage texts from memory bank + 4. Optionally return trajectory and energy for analysis + + Args: + query_embedding: (d,) or (batch, d) — query embedding + return_analysis: if True, include full HopfieldResult + + Returns: + RetrievalResult with passages, scores, indices, and optionally + the full hopfield_result. + """ + hopfield_result = self.hopfield.retrieve( + query_embedding, + self.memory_bank.embeddings, + return_trajectory=return_analysis, + return_energy=return_analysis, + ) + + alpha = hopfield_result.attention_weights # (batch, N) or (1, N) + + # Get top-k indices and scores + k = min(self.top_k, alpha.shape[-1]) + scores, indices = torch.topk(alpha, k, dim=-1) # (batch, k) + + # Flatten for single-query case + if scores.shape[0] == 1: + scores = scores.squeeze(0) # (k,) + indices = indices.squeeze(0) # (k,) + + passages = self.memory_bank.get_passages_by_indices(indices) + + return RetrievalResult( + passages=passages, + scores=scores, + indices=indices, + hopfield_result=hopfield_result if return_analysis else None, + ) diff --git a/notebooks/demo.ipynb b/notebooks/demo.ipynb new file mode 100644 index 0000000..caace5b --- /dev/null +++ b/notebooks/demo.ipynb @@ -0,0 +1,62 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# HAG: Hopfield-Augmented Generation Demo\n", + "\n", + "This notebook demonstrates the core Hopfield retrieval mechanism with synthetic data." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn.functional as F\n", + "\n", + "from hag.config import HopfieldConfig\n", + "from hag.hopfield import HopfieldRetrieval\n", + "from hag.energy import compute_energy_curve, verify_monotonic_decrease, compute_attention_entropy" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create synthetic memory bank and query\n", + "torch.manual_seed(42)\n", + "d, N = 64, 200\n", + "memory = F.normalize(torch.randn(d, N), dim=0)\n", + "query = F.normalize(torch.randn(1, d), dim=-1)\n", + "\n", + "# Run Hopfield retrieval with different beta values\n", + "for beta in [0.5, 1.0, 2.0, 5.0]:\n", + " config = HopfieldConfig(beta=beta, max_iter=20, conv_threshold=1e-6)\n", + " hopfield = HopfieldRetrieval(config)\n", + " result = hopfield.retrieve(query, memory, return_energy=True)\n", + " curve = compute_energy_curve(result)\n", + " entropy = compute_attention_entropy(result.attention_weights)\n", + " print(f'beta={beta}: steps={result.num_steps}, entropy={entropy:.4f}, monotonic={verify_monotonic_decrease(curve)}')" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "name": "python", + "version": "3.10.0" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..24f7806 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,36 @@ +[build-system] +requires = ["setuptools>=68.0", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "hag" +version = "0.1.0" +requires-python = ">=3.10" +description = "Hopfield-Augmented Generation: RAG with iterative Modern Hopfield Network retrieval" + +dependencies = [ + "torch>=2.0", + "numpy>=1.24", + "faiss-cpu>=1.7", + "transformers>=4.36", + "datasets>=2.14", + "pyyaml>=6.0", + "tqdm>=4.65", +] + +[project.optional-dependencies] +dev = [ + "pytest>=7.0", + "pytest-cov>=4.0", +] +eval = [ + "scikit-learn>=1.3", + "umap-learn>=0.5", + "matplotlib>=3.7", +] + +[tool.setuptools.packages.find] +include = ["hag*"] + +[tool.pytest.ini_options] +testpaths = ["tests"] diff --git a/scripts/analyze_energy.py b/scripts/analyze_energy.py new file mode 100644 index 0000000..fd044a4 --- /dev/null +++ b/scripts/analyze_energy.py @@ -0,0 +1,78 @@ +"""Analyze energy curves and convergence properties of Hopfield retrieval. 
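+
+For each question, records as JSON: the per-step energy curve, the total
+energy gap, a monotonic-decrease check, the number of iterations taken,
+and the entropy of the final attention weights.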
+
+Usage:
+    python scripts/analyze_energy.py --config configs/default.yaml --memory-bank data/memory_bank.pt --questions data/questions.jsonl --output energy_analysis.json
+"""
+
+import argparse
+import json
+import logging
+
+import yaml
+
+from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.energy import (
+    compute_attention_entropy,
+    compute_energy_curve,
+    compute_energy_gap,
+    verify_monotonic_decrease,
+)
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Analyze Hopfield energy curves")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--questions", type=str, required=True)
+    parser.add_argument("--output", type=str, default="energy_analysis.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    hopfield_config = HopfieldConfig(**cfg.get("hopfield", {}))
+    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+
+    # Load memory bank
+    mb = MemoryBank(memory_config)
+    mb.load(args.memory_bank)
+
+    # Load questions
+    with open(args.questions) as f:
+        questions = [json.loads(line)["question"] for line in f]
+
+    encoder = Encoder(encoder_config)
+    hopfield = HopfieldRetrieval(hopfield_config)
+
+    analyses = []
+    for q in questions:
+        query_emb = encoder.encode(q)  # (1, d)
+        result = hopfield.retrieve(
+            query_emb, mb.embeddings, return_energy=True, return_trajectory=True
+        )
+
+        curve = compute_energy_curve(result)
+        analyses.append({
+            "question": q,
+            "energy_curve": curve,
+            "energy_gap": compute_energy_gap(curve),
+            "monotonic": verify_monotonic_decrease(curve),
+            "num_steps": result.num_steps,
+            # float() guards JSON serialization in case a tensor is returned
+            "attention_entropy": float(compute_attention_entropy(result.attention_weights)),
+        })
+
+    with open(args.output, "w") as f:
+        json.dump(analyses, f, indent=2)
+    logger.info("Energy analysis saved to %s (%d questions)", args.output, len(analyses))
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/build_memory_bank.py b/scripts/build_memory_bank.py
new file mode 100644
index 0000000..2aff828
--- /dev/null
+++ b/scripts/build_memory_bank.py
@@ -0,0 +1,72 @@
+"""Offline script: encode corpus passages into a memory bank.
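+
+Expects a JSONL corpus with one object per line containing a "text" field;
+encodes it in batches and writes embeddings plus passage texts to --output
+via MemoryBank.save.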
+
+Usage:
+    python scripts/build_memory_bank.py --config configs/default.yaml --corpus data/corpus.jsonl --output data/memory_bank.pt
+"""
+
+import argparse
+import json
+import logging
+
+import torch
+import yaml
+from tqdm import tqdm
+
+from hag.config import EncoderConfig, MemoryBankConfig
+from hag.encoder import Encoder
+from hag.memory_bank import MemoryBank
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def load_corpus(path: str) -> list[str]:
+    """Load passages from a JSONL file (one JSON object per line with 'text' field)."""
+    passages = []
+    with open(path) as f:
+        for line in f:
+            obj = json.loads(line)
+            passages.append(obj["text"])
+    return passages
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Build memory bank from corpus")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--corpus", type=str, required=True)
+    parser.add_argument("--output", type=str, required=True)
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    encoder_config = EncoderConfig(**cfg.get("encoder", {}))
+    memory_config = MemoryBankConfig(**cfg.get("memory", {}))
+
+    # Load corpus
+    logger.info("Loading corpus from %s", args.corpus)
+    passages = load_corpus(args.corpus)
+    logger.info("Loaded %d passages", len(passages))
+
+    # Encode passages in batches
+    encoder = Encoder(encoder_config)
+    all_embeddings = []
+
+    for i in tqdm(range(0, len(passages), encoder_config.batch_size), desc="Encoding"):
+        batch = passages[i : i + encoder_config.batch_size]
+        emb = encoder.encode(batch)  # (batch_size, d)
+        all_embeddings.append(emb.cpu())
+
+    embeddings = torch.cat(all_embeddings, dim=0)  # (N, d)
+    logger.info("Encoded %d passages -> embeddings shape: %s", len(passages), embeddings.shape)
+
+    # Build and save memory bank
+    mb = MemoryBank(memory_config)
+    mb.build_from_embeddings(embeddings, passages)
+    mb.save(args.output)
+    logger.info("Memory bank saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/run_baseline.py b/scripts/run_baseline.py
new file mode 100644
index 0000000..74c4710
--- /dev/null
+++ b/scripts/run_baseline.py
@@ -0,0 +1,75 @@
+"""Run vanilla RAG baseline with FAISS retrieval.
+
+Usage:
+    python scripts/run_baseline.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
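+
+Prints the generated answer and the top-k retrieved passages (truncated to
+200 characters each).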
+"""
+
+import argparse
+import logging
+
+import numpy as np
+import yaml
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run vanilla RAG baseline")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--question", type=str, required=True)
+    parser.add_argument("--top-k", type=int, default=5)
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    # Override retriever type to faiss
+    pipeline_config = PipelineConfig(
+        hopfield=HopfieldConfig(**{**cfg.get("hopfield", {}), "top_k": args.top_k}),
+        encoder=EncoderConfig(**cfg.get("encoder", {})),
+        generator=GeneratorConfig(**cfg.get("generator", {})),
+        retriever_type="faiss",
+    )
+
+    # Load memory bank to get embeddings for FAISS
+    mb = MemoryBank(MemoryBankConfig(**cfg.get("memory", {})))
+    mb.load(args.memory_bank)
+
+    # Build FAISS index from memory bank embeddings ((d, N) -> (N, d) float32)
+    embeddings_np = mb.embeddings.T.numpy().astype(np.float32)
+    faiss_ret = FAISSRetriever(top_k=args.top_k)
+    faiss_ret.build_index(embeddings_np, mb.passages)
+
+    encoder = Encoder(pipeline_config.encoder)
+    generator = Generator(pipeline_config.generator)
+
+    pipeline = RAGPipeline(
+        config=pipeline_config,
+        encoder=encoder,
+        generator=generator,
+        faiss_retriever=faiss_ret,
+    )
+
+    result = pipeline.run(args.question)
+    print(f"\nQuestion: {result.question}")
+    print(f"Answer: {result.answer}")
+    print("\nRetrieved passages:")
+    for i, p in enumerate(result.retrieved_passages):
+        print(f"  [{i+1}] {p[:200]}...")
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/run_eval.py b/scripts/run_eval.py
new file mode 100644
index 0000000..713b3c2
--- /dev/null
+++ b/scripts/run_eval.py
@@ -0,0 +1,90 @@
+"""Run evaluation on a dataset with either FAISS or Hopfield retrieval.
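+
+The retriever backend is chosen by the config's retriever_type field
+("hopfield" or "faiss"); aggregate EM/F1 scores are written to --output.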
+
+Usage:
+    python scripts/run_eval.py --config configs/hotpotqa.yaml --memory-bank data/memory_bank.pt --dataset hotpotqa --split validation --max-samples 500
+"""
+
+import argparse
+import json
+import logging
+
+import numpy as np
+import yaml
+from datasets import load_dataset
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.encoder import Encoder
+from hag.generator import Generator
+from hag.memory_bank import MemoryBank
+from hag.metrics import evaluate_dataset
+from hag.pipeline import RAGPipeline
+from hag.retriever_faiss import FAISSRetriever
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Run HAG/RAG evaluation")
+    parser.add_argument("--config", type=str, default="configs/default.yaml")
+    parser.add_argument("--memory-bank", type=str, required=True)
+    parser.add_argument("--dataset", type=str, default="hotpotqa")
+    parser.add_argument("--split", type=str, default="validation")
+    parser.add_argument("--max-samples", type=int, default=500)
+    parser.add_argument("--output", type=str, default="results.json")
+    args = parser.parse_args()
+
+    with open(args.config) as f:
+        cfg = yaml.safe_load(f)
+
+    pipeline_config = PipelineConfig(
+        hopfield=HopfieldConfig(**cfg.get("hopfield", {})),
+        memory=MemoryBankConfig(**cfg.get("memory", {})),
+        encoder=EncoderConfig(**cfg.get("encoder", {})),
+        generator=GeneratorConfig(**cfg.get("generator", {})),
+        retriever_type=cfg.get("retriever_type", "hopfield"),
+    )
+
+    # Load memory bank
+    mb = MemoryBank(pipeline_config.memory)
+    mb.load(args.memory_bank)
+
+    # Build the retriever backend the config asks for. The hopfield path
+    # needs only the memory bank; the faiss path needs an index built from
+    # the memory bank's embeddings ((d, N) -> (N, d) float32).
+    faiss_ret = None
+    if pipeline_config.retriever_type == "faiss":
+        embeddings_np = mb.embeddings.T.numpy().astype(np.float32)
+        faiss_ret = FAISSRetriever(top_k=pipeline_config.hopfield.top_k)
+        faiss_ret.build_index(embeddings_np, mb.passages)
+
+    # Build pipeline
+    encoder = Encoder(pipeline_config.encoder)
+    generator = Generator(pipeline_config.generator)
+    pipeline = RAGPipeline(
+        config=pipeline_config,
+        encoder=encoder,
+        generator=generator,
+        memory_bank=mb,
+        faiss_retriever=faiss_ret,
+    )
+
+    # Load dataset
+    logger.info("Loading dataset: %s / %s", args.dataset, args.split)
+    ds = load_dataset(args.dataset, split=args.split)
+    if args.max_samples and len(ds) > args.max_samples:
+        ds = ds.select(range(args.max_samples))
+
+    questions = [ex["question"] for ex in ds]
+    gold_answers = [ex["answer"] for ex in ds]
+
+    # Run evaluation
+    logger.info("Running evaluation on %d questions", len(questions))
+    results = pipeline.run_batch(questions)
+    metrics = evaluate_dataset(results, gold_answers)
+
+    logger.info("Results: %s", metrics)
+
+    with open(args.output, "w") as f:
+        json.dump(metrics, f, indent=2)
+    logger.info("Results saved to %s", args.output)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/scripts/run_hag.py b/scripts/run_hag.py
new file mode 100644
index 0000000..4cacd1a
--- /dev/null
+++ b/scripts/run_hag.py
@@ -0,0 +1,79 @@
+"""Run HAG (Hopfield-Augmented Generation) on a question.
+
+Usage:
+    python scripts/run_hag.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?"
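+
+Hopfield hyperparameters can be overridden from the CLI, e.g. append
+--beta 2.0 --max-iter 10 --top-k 5 to the command above.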
+""" + +import argparse +import logging + +import yaml + +from hag.config import ( + EncoderConfig, + GeneratorConfig, + HopfieldConfig, + MemoryBankConfig, + PipelineConfig, +) +from hag.encoder import Encoder +from hag.generator import Generator +from hag.memory_bank import MemoryBank +from hag.pipeline import RAGPipeline + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Run HAG") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--beta", type=float, default=None) + parser.add_argument("--max-iter", type=int, default=None) + parser.add_argument("--top-k", type=int, default=None) + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + hopfield_cfg = cfg.get("hopfield", {}) + if args.beta is not None: + hopfield_cfg["beta"] = args.beta + if args.max_iter is not None: + hopfield_cfg["max_iter"] = args.max_iter + if args.top_k is not None: + hopfield_cfg["top_k"] = args.top_k + + pipeline_config = PipelineConfig( + hopfield=HopfieldConfig(**hopfield_cfg), + memory=MemoryBankConfig(**cfg.get("memory", {})), + encoder=EncoderConfig(**cfg.get("encoder", {})), + generator=GeneratorConfig(**cfg.get("generator", {})), + retriever_type="hopfield", + ) + + mb = MemoryBank(pipeline_config.memory) + mb.load(args.memory_bank) + + encoder = Encoder(pipeline_config.encoder) + generator = Generator(pipeline_config.generator) + + pipeline = RAGPipeline( + config=pipeline_config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + + result = pipeline.run(args.question) + print(f"\nQuestion: {result.question}") + print(f"Answer: {result.answer}") + print(f"\nRetrieved passages:") + for i, p in enumerate(result.retrieved_passages): + print(f" [{i+1}] {p[:200]}...") + + +if __name__ == "__main__": + main() diff --git a/scripts/visualize_trajectory.py b/scripts/visualize_trajectory.py new file mode 100644 index 0000000..e4ba902 --- /dev/null +++ b/scripts/visualize_trajectory.py @@ -0,0 +1,80 @@ +"""UMAP visualization of query trajectory in Hopfield retrieval. + +Usage: + python scripts/visualize_trajectory.py --config configs/default.yaml --memory-bank data/memory_bank.pt --question "Who wrote Hamlet?" 
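+
+Writes a 2-D UMAP projection of the memory bank with the query trajectory
+overlaid (q_0 in green, q_T in red) to --output.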
+""" + +import argparse +import logging + +import numpy as np +import torch +import yaml + +from hag.config import EncoderConfig, HopfieldConfig, MemoryBankConfig +from hag.encoder import Encoder +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +def main() -> None: + parser = argparse.ArgumentParser(description="Visualize Hopfield query trajectory") + parser.add_argument("--config", type=str, default="configs/default.yaml") + parser.add_argument("--memory-bank", type=str, required=True) + parser.add_argument("--question", type=str, required=True) + parser.add_argument("--output", type=str, default="trajectory.png") + args = parser.parse_args() + + with open(args.config) as f: + cfg = yaml.safe_load(f) + + hopfield_config = HopfieldConfig(**cfg.get("hopfield", {})) + memory_config = MemoryBankConfig(**cfg.get("memory", {})) + encoder_config = EncoderConfig(**cfg.get("encoder", {})) + + mb = MemoryBank(memory_config) + mb.load(args.memory_bank) + + encoder = Encoder(encoder_config) + hopfield = HopfieldRetrieval(hopfield_config) + + query_emb = encoder.encode(args.question) # (1, d) + result = hopfield.retrieve( + query_emb, mb.embeddings, return_trajectory=True + ) + + # Gather all points for UMAP: memories + trajectory + memories_np = mb.embeddings.T.numpy() # (N, d) + trajectory_np = np.stack([q.squeeze().numpy() for q in result.trajectory]) # (T+1, d) + all_points = np.concatenate([memories_np, trajectory_np], axis=0) + + # UMAP projection + import umap + + reducer = umap.UMAP(n_components=2, random_state=42) + projected = reducer.fit_transform(all_points) + + mem_proj = projected[: len(memories_np)] + traj_proj = projected[len(memories_np) :] + + # Plot + import matplotlib.pyplot as plt + + fig, ax = plt.subplots(figsize=(10, 8)) + ax.scatter(mem_proj[:, 0], mem_proj[:, 1], c="lightgray", s=10, alpha=0.5, label="Memories") + ax.plot(traj_proj[:, 0], traj_proj[:, 1], "b-o", markersize=6, label="Query trajectory") + ax.scatter(traj_proj[0, 0], traj_proj[0, 1], c="green", s=100, zorder=5, label="q_0 (start)") + ax.scatter(traj_proj[-1, 0], traj_proj[-1, 1], c="red", s=100, zorder=5, label="q_T (final)") + + ax.set_title(f"Hopfield Query Trajectory ({result.num_steps} steps)") + ax.legend() + plt.tight_layout() + plt.savefig(args.output, dpi=150) + logger.info("Trajectory visualization saved to %s", args.output) + + +if __name__ == "__main__": + main() diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/tests/__init__.py diff --git a/tests/test_energy.py b/tests/test_energy.py new file mode 100644 index 0000000..256f2d8 --- /dev/null +++ b/tests/test_energy.py @@ -0,0 +1,80 @@ +"""Unit tests for the energy computation module.""" + +import torch + +from hag.config import HopfieldConfig +from hag.energy import ( + compute_attention_entropy, + compute_energy_curve, + compute_energy_gap, + verify_monotonic_decrease, +) +from hag.hopfield import HopfieldRetrieval + + +class TestEnergy: + """Tests for energy computation and analysis utilities.""" + + def test_energy_computation_matches_formula(self) -> None: + """Manually compute E(q) and verify it matches the module output.""" + torch.manual_seed(42) + d, N = 16, 10 + memory = torch.randn(d, N) + query = torch.randn(1, d) + beta = 2.0 + + config = HopfieldConfig(beta=beta) + hopfield = HopfieldRetrieval(config) + computed = hopfield.compute_energy(query, memory) + + 
# Manual computation: + # E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2 + logits = beta * (query @ memory) # (1, N) + expected = ( + -1.0 / beta * torch.logsumexp(logits, dim=-1) + + 0.5 * (query**2).sum(dim=-1) + ) + assert torch.allclose(computed, expected, atol=1e-5) + + def test_verify_monotonic(self) -> None: + """Monotonically decreasing energy curve should pass.""" + energies = [10.0, 8.0, 6.5, 6.4, 6.39] + assert verify_monotonic_decrease(energies) is True + + def test_verify_monotonic_fails(self) -> None: + """Non-monotonic energy curve should fail.""" + energies = [10.0, 8.0, 9.0, 6.0] + assert verify_monotonic_decrease(energies) is False + + def test_attention_entropy(self) -> None: + """Uniform attention should have max entropy, one-hot should have zero.""" + uniform = torch.ones(100) / 100 + one_hot = torch.zeros(100) + one_hot[0] = 1.0 + assert compute_attention_entropy(uniform) > compute_attention_entropy(one_hot) + assert compute_attention_entropy(one_hot) < 0.01 + + def test_energy_curve_extraction(self) -> None: + """Energy curve should be extractable from a HopfieldResult.""" + torch.manual_seed(42) + d, N = 32, 50 + memory = torch.randn(d, N) + query = torch.randn(1, d) + config = HopfieldConfig(beta=1.0, max_iter=5) + hopfield = HopfieldRetrieval(config) + + result = hopfield.retrieve(query, memory, return_energy=True) + curve = compute_energy_curve(result) + assert len(curve) > 1 + assert all(isinstance(v, float) for v in curve) + + def test_energy_gap(self) -> None: + """Energy gap should be positive when energy decreases.""" + energies = [10.0, 8.0, 6.0] + gap = compute_energy_gap(energies) + assert gap == 4.0 + + def test_energy_gap_single_point(self) -> None: + """Energy gap with fewer than 2 points should be 0.""" + assert compute_energy_gap([5.0]) == 0.0 + assert compute_energy_gap([]) == 0.0 diff --git a/tests/test_hopfield.py b/tests/test_hopfield.py new file mode 100644 index 0000000..246b120 --- /dev/null +++ b/tests/test_hopfield.py @@ -0,0 +1,147 @@ +"""Unit tests for the core Hopfield retrieval module.""" + +import torch +import torch.nn.functional as F + +from hag.config import HopfieldConfig +from hag.hopfield import HopfieldRetrieval + + +class TestHopfieldRetrieval: + """Test the core Hopfield module with synthetic data on CPU.""" + + def setup_method(self) -> None: + """Create small synthetic memory bank and queries.""" + torch.manual_seed(42) + self.d = 64 # embedding dim + self.N = 100 # number of memories + self.memory = F.normalize(torch.randn(self.d, self.N), dim=0) # (d, N) + self.query = F.normalize(torch.randn(1, self.d), dim=-1) # (1, d) + self.config = HopfieldConfig(beta=1.0, max_iter=10, conv_threshold=1e-6) + self.hopfield = HopfieldRetrieval(self.config) + + def test_output_shapes(self) -> None: + """attention_weights should be (1, N), converged_query should be (1, d).""" + result = self.hopfield.retrieve(self.query, self.memory) + assert result.attention_weights.shape == (1, self.N) + assert result.converged_query.shape == (1, self.d) + + def test_attention_weights_sum_to_one(self) -> None: + """softmax output must sum to 1.""" + result = self.hopfield.retrieve(self.query, self.memory) + assert torch.allclose( + result.attention_weights.sum(dim=-1), torch.ones(1), atol=1e-5 + ) + + def test_attention_weights_non_negative(self) -> None: + """All attention weights must be >= 0.""" + result = self.hopfield.retrieve(self.query, self.memory) + assert (result.attention_weights >= 0).all() + + def 
test_energy_monotonic_decrease(self) -> None: + """E(q_{t+1}) <= E(q_t) for all t. This is THE key theoretical property.""" + result = self.hopfield.retrieve( + self.query, self.memory, return_energy=True + ) + energies = [e.item() for e in result.energy_curve] + for i in range(len(energies) - 1): + assert energies[i + 1] <= energies[i] + 1e-6, ( + f"Energy increased at step {i}: {energies[i]} -> {energies[i + 1]}" + ) + + def test_convergence(self) -> None: + """With enough iterations, query should converge (delta < threshold).""" + config = HopfieldConfig(beta=2.0, max_iter=50, conv_threshold=1e-6) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve( + self.query, self.memory, return_trajectory=True + ) + # Final two queries should be very close + q_last = result.trajectory[-1] + q_prev = result.trajectory[-2] + delta = torch.norm(q_last - q_prev) + assert delta < 1e-4, f"Did not converge: delta={delta}" + + def test_high_beta_sharp_retrieval(self) -> None: + """Higher beta should produce sharper (lower entropy) attention.""" + low_beta = HopfieldRetrieval(HopfieldConfig(beta=0.5, max_iter=5)) + high_beta = HopfieldRetrieval(HopfieldConfig(beta=5.0, max_iter=5)) + + result_low = low_beta.retrieve(self.query, self.memory) + result_high = high_beta.retrieve(self.query, self.memory) + + entropy_low = -( + result_low.attention_weights * result_low.attention_weights.log() + ).sum() + entropy_high = -( + result_high.attention_weights * result_high.attention_weights.log() + ).sum() + + assert entropy_high < entropy_low, "Higher beta should give lower entropy" + + def test_single_memory_converges_to_it(self) -> None: + """With N=1, retrieval should converge to the single memory.""" + single_memory = F.normalize(torch.randn(self.d, 1), dim=0) + result = self.hopfield.retrieve(self.query, single_memory) + assert torch.allclose( + result.attention_weights, torch.ones(1, 1), atol=1e-5 + ) + + def test_query_near_memory_retrieves_it(self) -> None: + """If query ~= memory_i, attention should peak at index i.""" + target_idx = 42 + query = self.memory[:, target_idx].unsqueeze(0) # (1, d) — exact match + config = HopfieldConfig(beta=10.0, max_iter=5) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(query, self.memory) + top_idx = result.attention_weights.argmax(dim=-1).item() + assert top_idx == target_idx, f"Expected {target_idx}, got {top_idx}" + + def test_batch_retrieval(self) -> None: + """Should handle batch of queries.""" + batch_query = F.normalize(torch.randn(8, self.d), dim=-1) + result = self.hopfield.retrieve(batch_query, self.memory) + assert result.attention_weights.shape == (8, self.N) + assert result.converged_query.shape == (8, self.d) + + def test_iteration_refines_query(self) -> None: + """Multi-hop test: query starts far from target, iteration should bring it closer. + + Setup: memory has two clusters. Query is near cluster A but the "answer" + is in cluster B, reachable through an intermediate memory that bridges both. + After iteration, the query should drift toward cluster B. + """ + torch.manual_seed(0) + d = 32 + + # Cluster A: memories 0-4 centered around direction [1, 0, 0, ...] + # Cluster B: memories 5-9 centered around direction [0, 1, 0, ...] 
+ # Bridge memory 10: between A and B + center_a = torch.zeros(d) + center_a[0] = 1.0 + center_b = torch.zeros(d) + center_b[1] = 1.0 + bridge = F.normalize(center_a + center_b, dim=0) + + memories = [] + for _ in range(5): + memories.append(F.normalize(center_a + 0.1 * torch.randn(d), dim=0)) + for _ in range(5): + memories.append(F.normalize(center_b + 0.1 * torch.randn(d), dim=0)) + memories.append(bridge) + + M = torch.stack(memories, dim=1) # (d, 11) + q0 = F.normalize(center_a + 0.05 * torch.randn(d), dim=0).unsqueeze(0) + + config = HopfieldConfig(beta=3.0, max_iter=10, conv_threshold=1e-8) + hopfield = HopfieldRetrieval(config) + result = hopfield.retrieve(q0, M, return_trajectory=True) + + # After iteration, query should have drifted: its dot product with center_b + # should be higher than the initial query's dot product with center_b + initial_sim_b = (q0.squeeze() @ center_b).item() + final_sim_b = (result.converged_query.squeeze() @ center_b).item() + assert final_sim_b > initial_sim_b, ( + f"Iteration should pull query toward cluster B: " + f"initial={initial_sim_b:.4f}, final={final_sim_b:.4f}" + ) diff --git a/tests/test_memory_bank.py b/tests/test_memory_bank.py new file mode 100644 index 0000000..0087bbd --- /dev/null +++ b/tests/test_memory_bank.py @@ -0,0 +1,65 @@ +"""Unit tests for the memory bank module.""" + +import torch + +from hag.config import MemoryBankConfig +from hag.memory_bank import MemoryBank + + +class TestMemoryBank: + """Tests for MemoryBank construction, lookup, and persistence.""" + + def test_build_and_size(self) -> None: + """Build memory bank and verify size.""" + config = MemoryBankConfig(embedding_dim=64, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(100, 64) + passages = [f"passage {i}" for i in range(100)] + mb.build_from_embeddings(embeddings, passages) + assert mb.size == 100 + assert mb.dim == 64 + + def test_normalization(self) -> None: + """If normalize=True, stored embeddings should have unit norm.""" + config = MemoryBankConfig(embedding_dim=64, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(50, 64) * 5 # non-unit norm + mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) + norms = torch.norm(mb.embeddings, dim=0) + assert torch.allclose(norms, torch.ones(50), atol=1e-5) + + def test_no_normalization(self) -> None: + """If normalize=False, stored embeddings keep original norms.""" + config = MemoryBankConfig(embedding_dim=64, normalize=False) + mb = MemoryBank(config) + embeddings = torch.randn(50, 64) * 5 + original_norms = torch.norm(embeddings, dim=-1) + mb.build_from_embeddings(embeddings, [f"p{i}" for i in range(50)]) + stored_norms = torch.norm(mb.embeddings, dim=0) + assert torch.allclose(stored_norms, original_norms, atol=1e-5) + + def test_get_passages_by_indices(self) -> None: + """Index -> passage text lookup.""" + config = MemoryBankConfig(embedding_dim=64, normalize=False) + mb = MemoryBank(config) + passages = [f"passage {i}" for i in range(100)] + mb.build_from_embeddings(torch.randn(100, 64), passages) + result = mb.get_passages_by_indices(torch.tensor([0, 50, 99])) + assert result == ["passage 0", "passage 50", "passage 99"] + + def test_save_and_load(self, tmp_path) -> None: # type: ignore[no-untyped-def] + """Save and reload memory bank, verify contents match.""" + config = MemoryBankConfig(embedding_dim=32, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(20, 32) + passages = [f"text {i}" for i in range(20)] + 
mb.build_from_embeddings(embeddings, passages) + + save_path = str(tmp_path / "mb.pt") + mb.save(save_path) + + mb2 = MemoryBank(config) + mb2.load(save_path) + assert mb2.size == 20 + assert mb2.passages == passages + assert torch.allclose(mb.embeddings, mb2.embeddings, atol=1e-6) diff --git a/tests/test_metrics.py b/tests/test_metrics.py new file mode 100644 index 0000000..4a18bec --- /dev/null +++ b/tests/test_metrics.py @@ -0,0 +1,54 @@ +"""Unit tests for evaluation metrics.""" + +from hag.metrics import exact_match, f1_score, retrieval_recall_at_k + + +class TestMetrics: + """Tests for EM, F1, and retrieval recall metrics.""" + + def test_exact_match_basic(self) -> None: + """Basic exact match: case insensitive.""" + assert exact_match("Paris", "paris") == 1.0 + assert exact_match("Paris", "London") == 0.0 + + def test_exact_match_normalization(self) -> None: + """Should strip whitespace, lowercase, remove articles.""" + assert exact_match(" The Paris ", "paris") == 1.0 + assert exact_match("A dog", "dog") == 1.0 + + def test_exact_match_empty(self) -> None: + """Empty strings should match.""" + assert exact_match("", "") == 1.0 + + def test_f1_score_perfect(self) -> None: + """Identical strings should have F1 = 1.0.""" + assert f1_score("the cat sat", "the cat sat") == 1.0 + + def test_f1_score_partial(self) -> None: + """Partial overlap should give 0 < F1 < 1.""" + score = f1_score("the cat sat on the mat", "the cat sat") + assert 0.5 < score < 1.0 + + def test_f1_score_no_overlap(self) -> None: + """No common tokens should give F1 = 0.""" + assert f1_score("hello world", "foo bar") == 0.0 + + def test_f1_score_empty(self) -> None: + """Two empty strings should have F1 = 1.0.""" + assert f1_score("", "") == 1.0 + + def test_retrieval_recall(self) -> None: + """Standard retrieval recall computation.""" + assert retrieval_recall_at_k([1, 3, 5], [1, 5, 7], k=3) == 2 / 3 + + def test_retrieval_recall_perfect(self) -> None: + """All gold passages retrieved.""" + assert retrieval_recall_at_k([1, 2, 3], [1, 2, 3], k=3) == 1.0 + + def test_retrieval_recall_none(self) -> None: + """No gold passages retrieved.""" + assert retrieval_recall_at_k([4, 5, 6], [1, 2, 3], k=3) == 0.0 + + def test_retrieval_recall_empty_gold(self) -> None: + """No gold passages means perfect recall by convention.""" + assert retrieval_recall_at_k([1, 2, 3], [], k=3) == 1.0 diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py new file mode 100644 index 0000000..3ac81ce --- /dev/null +++ b/tests/test_pipeline.py @@ -0,0 +1,132 @@ +"""Integration tests for the RAG/HAG pipeline using mock encoder and generator.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from hag.config import ( + HopfieldConfig, + MemoryBankConfig, + PipelineConfig, +) +from hag.encoder import FakeEncoder +from hag.generator import FakeGenerator +from hag.memory_bank import MemoryBank +from hag.pipeline import RAGPipeline +from hag.retriever_faiss import FAISSRetriever + + +def _make_memory_bank(n: int = 20, d: int = 64) -> MemoryBank: + """Create a small memory bank with random embeddings.""" + torch.manual_seed(123) + config = MemoryBankConfig(embedding_dim=d, normalize=True) + mb = MemoryBank(config) + embeddings = torch.randn(n, d) + passages = [f"passage {i} content" for i in range(n)] + mb.build_from_embeddings(embeddings, passages) + return mb + + +def _make_faiss_retriever(n: int = 20, d: int = 64) -> FAISSRetriever: + """Create a small FAISS retriever with random embeddings.""" + np.random.seed(123) + 
embeddings = np.random.randn(n, d).astype(np.float32) + passages = [f"passage {i} content" for i in range(n)] + retriever = FAISSRetriever(top_k=3) + retriever.build_index(embeddings, passages) + return retriever + + +class TestPipeline: + """Integration tests using mock encoder and generator.""" + + def test_pipeline_runs_end_to_end_hopfield(self) -> None: + """Mock encoder + mock LLM + Hopfield retriever -> produces an answer.""" + config = PipelineConfig( + hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3), + memory=MemoryBankConfig(embedding_dim=64, normalize=True), + retriever_type="hopfield", + ) + mb = _make_memory_bank(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + result = pipeline.run("What is the capital of France?") + + assert isinstance(result.answer, str) + assert len(result.answer) > 0 + assert len(result.retrieved_passages) == 3 + + def test_pipeline_runs_end_to_end_faiss(self) -> None: + """Mock encoder + mock LLM + FAISS retriever -> produces an answer.""" + config = PipelineConfig( + hopfield=HopfieldConfig(top_k=3), + retriever_type="faiss", + ) + faiss_ret = _make_faiss_retriever(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + faiss_retriever=faiss_ret, + ) + result = pipeline.run("What is the capital of France?") + + assert isinstance(result.answer, str) + assert len(result.answer) > 0 + assert len(result.retrieved_passages) == 3 + + def test_pipeline_retrieval_result_included(self) -> None: + """PipelineResult should contain retrieval metadata.""" + config = PipelineConfig( + hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3), + memory=MemoryBankConfig(embedding_dim=64, normalize=True), + retriever_type="hopfield", + ) + mb = _make_memory_bank(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + result = pipeline.run("Test question?") + + assert result.retrieval_result is not None + assert result.retrieval_result.scores is not None + assert result.retrieval_result.indices is not None + assert result.question == "Test question?" 
+ + def test_pipeline_batch(self) -> None: + """run_batch should return results for each question.""" + config = PipelineConfig( + hopfield=HopfieldConfig(beta=1.0, max_iter=3, top_k=3), + memory=MemoryBankConfig(embedding_dim=64, normalize=True), + retriever_type="hopfield", + ) + mb = _make_memory_bank(n=20, d=64) + encoder = FakeEncoder(dim=64) + generator = FakeGenerator() + + pipeline = RAGPipeline( + config=config, + encoder=encoder, + generator=generator, + memory_bank=mb, + ) + questions = ["Q1?", "Q2?", "Q3?"] + results = pipeline.run_batch(questions) + assert len(results) == 3 + assert all(r.answer == "mock answer" for r in results) diff --git a/tests/test_retriever.py b/tests/test_retriever.py new file mode 100644 index 0000000..fa96dca --- /dev/null +++ b/tests/test_retriever.py @@ -0,0 +1,149 @@ +"""Unit tests for both FAISS and Hopfield retrievers.""" + +import numpy as np +import torch +import torch.nn.functional as F + +from hag.config import HopfieldConfig, MemoryBankConfig +from hag.hopfield import HopfieldRetrieval +from hag.memory_bank import MemoryBank +from hag.retriever_faiss import FAISSRetriever +from hag.retriever_hopfield import HopfieldRetriever + + +class TestHopfieldRetriever: + """Tests for the Hopfield-based retriever.""" + + def setup_method(self) -> None: + """Set up a small memory bank and Hopfield retriever.""" + torch.manual_seed(42) + self.d = 64 + self.N = 50 + self.top_k = 3 + + config = MemoryBankConfig(embedding_dim=self.d, normalize=True) + self.mb = MemoryBank(config) + embeddings = torch.randn(self.N, self.d) + self.passages = [f"passage {i}" for i in range(self.N)] + self.mb.build_from_embeddings(embeddings, self.passages) + + hopfield_config = HopfieldConfig(beta=2.0, max_iter=5) + hopfield = HopfieldRetrieval(hopfield_config) + self.retriever = HopfieldRetriever(hopfield, self.mb, top_k=self.top_k) + + def test_retrieves_correct_number_of_passages(self) -> None: + """top_k=3 should return exactly 3 passages.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query) + assert len(result.passages) == self.top_k + assert result.scores.shape == (self.top_k,) + assert result.indices.shape == (self.top_k,) + + def test_scores_are_sorted_descending(self) -> None: + """Returned scores should be in descending order.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query) + scores = result.scores.tolist() + assert scores == sorted(scores, reverse=True) + + def test_passages_match_indices(self) -> None: + """Returned passage texts should correspond to returned indices.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query) + expected_passages = [self.passages[i] for i in result.indices.tolist()] + assert result.passages == expected_passages + + def test_analysis_mode_returns_hopfield_result(self) -> None: + """When return_analysis=True, hopfield_result should be populated.""" + query = F.normalize(torch.randn(1, self.d), dim=-1) + result = self.retriever.retrieve(query, return_analysis=True) + assert result.hopfield_result is not None + assert result.hopfield_result.energy_curve is not None + assert result.hopfield_result.trajectory is not None + + +class TestFAISSRetriever: + """Tests for the FAISS baseline retriever.""" + + def setup_method(self) -> None: + """Set up a small FAISS index.""" + np.random.seed(42) + self.d = 64 + self.N = 50 + self.top_k = 3 + + self.embeddings = np.random.randn(self.N, self.d).astype(np.float32) 
+
+        self.passages = [f"passage {i}" for i in range(self.N)]
+
+        self.retriever = FAISSRetriever(top_k=self.top_k)
+        self.retriever.build_index(self.embeddings.copy(), self.passages)
+
+    def test_retrieves_correct_number(self) -> None:
+        """top_k=3 should return exactly 3 passages."""
+        query = np.random.randn(self.d).astype(np.float32)
+        result = self.retriever.retrieve(query)
+        assert len(result.passages) == self.top_k
+
+    def test_nearest_neighbor_is_self(self) -> None:
+        """If query = passage_i's embedding, top-1 should be passage_i."""
+        # Rebuild on a fresh copy so self.embeddings stays unnormalized
+        # regardless of what build_index does internally
+        retriever = FAISSRetriever(top_k=1)
+        retriever.build_index(self.embeddings.copy(), self.passages)
+
+        # Query with the raw (unnormalized) embedding; retrieve() normalizes
+        # the query itself, so the top hit must be the same passage
+        target_idx = 10
+        query = self.embeddings[target_idx].copy()
+        result = retriever.retrieve(query)
+        assert result.indices[0].item() == target_idx
+
+    def test_scores_sorted_descending(self) -> None:
+        """Returned scores should be in descending order."""
+        query = np.random.randn(self.d).astype(np.float32)
+        result = self.retriever.retrieve(query)
+        scores = result.scores.tolist()
+        assert scores == sorted(scores, reverse=True)
+
+
+class TestRetrieverComparison:
+    """Compare FAISS and Hopfield retrievers."""
+
+    def test_same_top1_for_obvious_query(self) -> None:
+        """When query is very close to one memory, both should agree on top-1."""
+        torch.manual_seed(42)
+        np.random.seed(42)
+        d = 64
+        N = 50
+        target_idx = 25
+
+        # Create unit-norm embeddings
+        embeddings_np = np.random.randn(N, d).astype(np.float32)
+        norms = np.linalg.norm(embeddings_np, axis=1, keepdims=True)
+        embeddings_np = embeddings_np / norms
+
+        # Query is exactly the target embedding
+        query_np = embeddings_np[target_idx].copy()
+        query_torch = torch.from_numpy(query_np).unsqueeze(0)  # (1, d)
+
+        passages = [f"passage {i}" for i in range(N)]
+
+        # FAISS retriever
+        faiss_ret = FAISSRetriever(top_k=1)
+        faiss_ret.build_index(embeddings_np.copy(), passages)
+        faiss_result = faiss_ret.retrieve(query_np.copy())
+
+        # Hopfield retriever
+        mb_config = MemoryBankConfig(embedding_dim=d, normalize=False)
+        mb = MemoryBank(mb_config)
+        mb.build_from_embeddings(torch.from_numpy(embeddings_np), passages)
+        hop_config = HopfieldConfig(beta=10.0, max_iter=5)
+        hopfield = HopfieldRetrieval(hop_config)
+        hop_ret = HopfieldRetriever(hopfield, mb, top_k=1)
+        hop_result = hop_ret.retrieve(query_torch)
+
+        assert faiss_result.indices[0].item() == target_idx
+        assert hop_result.indices[0].item() == target_idx
