Diffstat (limited to 'hag')
-rw-r--r--  hag/__init__.py             21
-rw-r--r--  hag/config.py               50
-rw-r--r--  hag/datatypes.py            37
-rw-r--r--  hag/encoder.py              88
-rw-r--r--  hag/energy.py               83
-rw-r--r--  hag/generator.py            87
-rw-r--r--  hag/hopfield.py            124
-rw-r--r--  hag/memory_bank.py          93
-rw-r--r--  hag/metrics.py             113
-rw-r--r--  hag/pipeline.py            107
-rw-r--r--  hag/retriever_faiss.py      73
-rw-r--r--  hag/retriever_hopfield.py   77
12 files changed, 953 insertions, 0 deletions
diff --git a/hag/__init__.py b/hag/__init__.py
new file mode 100644
index 0000000..18496e9
--- /dev/null
+++ b/hag/__init__.py
@@ -0,0 +1,21 @@
+"""HAG: Hopfield-Augmented Generation."""
+
+from hag.config import (
+ EncoderConfig,
+ GeneratorConfig,
+ HopfieldConfig,
+ MemoryBankConfig,
+ PipelineConfig,
+)
+from hag.datatypes import HopfieldResult, PipelineResult, RetrievalResult
+
+__all__ = [
+ "HopfieldConfig",
+ "MemoryBankConfig",
+ "EncoderConfig",
+ "GeneratorConfig",
+ "PipelineConfig",
+ "HopfieldResult",
+ "RetrievalResult",
+ "PipelineResult",
+]
diff --git a/hag/config.py b/hag/config.py
new file mode 100644
index 0000000..793e3a6
--- /dev/null
+++ b/hag/config.py
@@ -0,0 +1,50 @@
+"""All hyperparameters and configuration dataclasses for HAG."""
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class HopfieldConfig:
+ """Configuration for the Hopfield retrieval module."""
+
+ beta: float = 1.0 # Inverse temperature. Higher = sharper retrieval
+ max_iter: int = 5 # Maximum Hopfield iteration steps
+ conv_threshold: float = 1e-4 # Stop if ||q_{t+1} - q_t|| < threshold
+ top_k: int = 5 # Number of passages to retrieve from final attention weights
+
+
+@dataclass
+class MemoryBankConfig:
+ """Configuration for the memory bank."""
+
+ embedding_dim: int = 768 # Must match encoder output dim
+ normalize: bool = True # L2-normalize embeddings in memory bank
+
+
+@dataclass
+class EncoderConfig:
+ """Configuration for the query/passage encoder."""
+
+ model_name: str = "facebook/contriever-msmarco"
+ max_length: int = 512
+ batch_size: int = 64
+
+
+@dataclass
+class GeneratorConfig:
+ """Configuration for the LLM generator."""
+
+ model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+ max_new_tokens: int = 128
+ temperature: float = 0.0 # Greedy decoding for reproducibility
+
+
+@dataclass
+class PipelineConfig:
+ """Top-level pipeline configuration."""
+
+ hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
+ memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
+ encoder: EncoderConfig = field(default_factory=EncoderConfig)
+ generator: GeneratorConfig = field(default_factory=GeneratorConfig)
+ retriever_type: str = "hopfield" # "hopfield" or "faiss"
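A minimal usage sketch for these dataclasses (illustrative, not part of the diff); all field names are taken from the code above:

    from hag import HopfieldConfig, PipelineConfig

    # Sharper attention (higher beta), more iterations, fewer passages in the prompt.
    config = PipelineConfig(
        hopfield=HopfieldConfig(beta=4.0, max_iter=10, top_k=3),
        retriever_type="hopfield",
    )
    print(config.encoder.model_name)  # untouched default: facebook/contriever-msmarco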
diff --git a/hag/datatypes.py b/hag/datatypes.py
new file mode 100644
index 0000000..0f4254d
--- /dev/null
+++ b/hag/datatypes.py
@@ -0,0 +1,37 @@
+"""Data types used across HAG modules."""
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class HopfieldResult:
+ """Result from Hopfield iterative retrieval."""
+
+ attention_weights: torch.Tensor # (batch, N) or (N,)
+ converged_query: torch.Tensor # (batch, d) or (d,)
+ num_steps: int
+ trajectory: Optional[List[torch.Tensor]] = None # list of q_t
+ energy_curve: Optional[List[torch.Tensor]] = None # list of E(q_t)
+
+
+@dataclass
+class RetrievalResult:
+ """Result from a retriever (FAISS or Hopfield)."""
+
+ passages: List[str]
+ scores: torch.Tensor # top-k scores
+ indices: torch.Tensor # top-k indices
+ hopfield_result: Optional[HopfieldResult] = None
+
+
+@dataclass
+class PipelineResult:
+ """Result from the full RAG/HAG pipeline."""
+
+ question: str
+ answer: str
+ retrieved_passages: List[str]
+ retrieval_result: RetrievalResult
diff --git a/hag/encoder.py b/hag/encoder.py
new file mode 100644
index 0000000..7e103f3
--- /dev/null
+++ b/hag/encoder.py
@@ -0,0 +1,88 @@
+"""Wrapper for encoding queries and passages into embeddings."""
+
+import logging
+import zlib
+from typing import List, Union
+
+import torch
+
+from hag.config import EncoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Encoder:
+ """Encodes text queries/passages into dense embeddings.
+
+ Uses a HuggingFace transformer model (e.g., Contriever).
+ For testing, use FakeEncoder instead.
+ """
+
+ def __init__(self, config: EncoderConfig) -> None:
+ self.config = config
+ self._tokenizer = None
+ self._model = None
+
+ def _load_model(self) -> None:
+ """Lazy-load the model and tokenizer."""
+ from transformers import AutoModel, AutoTokenizer
+
+ logger.info("Loading encoder model: %s", self.config.model_name)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+ self._model = AutoModel.from_pretrained(self.config.model_name)
+ self._model.eval()
+
+ @torch.no_grad()
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+ """Encode text(s) into embedding(s).
+
+ Args:
+ texts: single string or list of strings
+
+ Returns:
+ (1, d) tensor for single input, (N, d) for list input.
+ """
+ if self._model is None:
+ self._load_model()
+
+ if isinstance(texts, str):
+ texts = [texts]
+
+ inputs = self._tokenizer(
+ texts,
+ max_length=self.config.max_length,
+ padding=True,
+ truncation=True,
+ return_tensors="pt",
+ )
+        outputs = self._model(**inputs)
+        # Mean pooling over token embeddings, ignoring padding positions
+        mask = inputs["attention_mask"].unsqueeze(-1).float()  # (N, T, 1)
+        embeddings = (outputs.last_hidden_state * mask).sum(dim=1) / mask.sum(dim=1).clamp(min=1e-9)  # (N, d)
+        return embeddings
+
+
+class FakeEncoder:
+ """Deterministic hash-based encoder for testing. No model download needed."""
+
+ def __init__(self, dim: int = 64) -> None:
+ self.dim = dim
+
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+ """Produce deterministic embeddings based on text hash.
+
+ Args:
+ texts: single string or list of strings
+
+ Returns:
+ (1, d) or (N, d) normalized tensor.
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+        embeddings = []
+        for text in texts:
+            # Builtin hash() is salted per interpreter run for strings, so use a
+            # stable CRC32 seed and a local generator to stay truly deterministic.
+            seed = zlib.crc32(text.encode("utf-8"))
+            emb = torch.randn(1, self.dim, generator=torch.Generator().manual_seed(seed))
+            embeddings.append(emb)
+
+ result = torch.cat(embeddings, dim=0) # (N, d)
+ return torch.nn.functional.normalize(result, dim=-1)
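A quick sketch of the FakeEncoder for test settings (illustrative, not part of the diff):

    from hag.encoder import FakeEncoder

    encoder = FakeEncoder(dim=64)
    embs = encoder.encode(["who wrote hamlet?", "capital of france"])
    print(embs.shape)         # torch.Size([2, 64])
    print(embs.norm(dim=-1))  # ~1.0 per row: embeddings are L2-normalized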
diff --git a/hag/energy.py b/hag/energy.py
new file mode 100644
index 0000000..62a39e9
--- /dev/null
+++ b/hag/energy.py
@@ -0,0 +1,83 @@
+"""Energy computation and analysis utilities for Hopfield retrieval."""
+
+import logging
+from typing import List
+
+import torch
+
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+ """Extract energy values at each iteration step.
+
+ Args:
+ hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True
+
+ Returns:
+ List of energy values (floats) at each step.
+ """
+ if hopfield_result.energy_curve is None:
+ return []
+ return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]
+
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+ """Compute the energy gap: Delta_E = E(q_0) - E(q_T).
+
+ Larger gap means more refinement happened during iteration.
+
+ Args:
+ energy_curve: list of energy values at each step
+
+ Returns:
+ Energy gap (float). Positive if energy decreased.
+ """
+ if len(energy_curve) < 2:
+ return 0.0
+ return energy_curve[0] - energy_curve[-1]
+
+
+def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
+ """Check that E(q_{t+1}) <= E(q_t) for all t.
+
+ This should always be True for the Modern Hopfield Network.
+
+ Args:
+ energy_curve: list of energy values at each step
+ tol: numerical tolerance for comparison
+
+ Returns:
+ True if energy decreases monotonically (within tolerance).
+ """
+ for i in range(len(energy_curve) - 1):
+ if energy_curve[i + 1] > energy_curve[i] + tol:
+ return False
+ return True
+
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+ """Compute the entropy of attention weights.
+
+ H(alpha) = -sum_i alpha_i * log(alpha_i)
+
+ Low entropy = sharp retrieval (confident).
+ High entropy = diffuse retrieval (uncertain).
+
+ Args:
+ attention_weights: (N,) or (batch, N) — attention distribution
+
+ Returns:
+ Entropy value (float). Averaged over batch if batched.
+ """
+ if attention_weights.dim() == 1:
+ attention_weights = attention_weights.unsqueeze(0) # (1, N)
+
+ # Clamp to avoid log(0)
+ eps = 1e-12
+ alpha = attention_weights.clamp(min=eps)
+ entropy = -(alpha * alpha.log()).sum(dim=-1) # (batch,)
+
+ return entropy.mean().item()
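A sketch of how these utilities combine with HopfieldRetrieval (added later in this commit); illustrative, using random unit-norm memories:

    import torch
    import torch.nn.functional as F

    from hag.config import HopfieldConfig
    from hag.energy import (
        compute_attention_entropy,
        compute_energy_curve,
        compute_energy_gap,
        verify_monotonic_decrease,
    )
    from hag.hopfield import HopfieldRetrieval

    memory = F.normalize(torch.randn(64, 100), dim=0)  # (d, N), unit-norm columns
    query = F.normalize(torch.randn(64), dim=-1)        # (d,)

    result = HopfieldRetrieval(HopfieldConfig(beta=2.0)).retrieve(
        query, memory, return_energy=True
    )
    curve = compute_energy_curve(result)
    print(verify_monotonic_decrease(curve))          # expected True: E(q_{t+1}) <= E(q_t)
    print("energy gap:", compute_energy_gap(curve))  # E(q_0) - E(q_T)
    print("entropy:", compute_attention_entropy(result.attention_weights))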
diff --git a/hag/generator.py b/hag/generator.py
new file mode 100644
index 0000000..2142e0c
--- /dev/null
+++ b/hag/generator.py
@@ -0,0 +1,87 @@
+"""LLM generation wrapper for producing answers from retrieved context."""
+
+import logging
+from typing import List
+
+from hag.config import GeneratorConfig
+
+logger = logging.getLogger(__name__)
+
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+
+Context:
+{context}
+
+Question: {question}
+
+Answer:"""
+
+
+class Generator:
+ """LLM-based answer generator.
+
+ Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct).
+ For testing, use FakeGenerator instead.
+ """
+
+ def __init__(self, config: GeneratorConfig) -> None:
+ self.config = config
+ self._tokenizer = None
+ self._model = None
+
+ def _load_model(self) -> None:
+ """Lazy-load the model and tokenizer."""
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+
+ logger.info("Loading generator model: %s", self.config.model_name)
+ self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+ self._model = AutoModelForCausalLM.from_pretrained(
+ self.config.model_name,
+ torch_dtype="auto",
+ )
+ self._model.eval()
+
+ def generate(self, question: str, passages: List[str]) -> str:
+ """Generate an answer given a question and retrieved passages.
+
+ Args:
+ question: the user question
+ passages: list of retrieved passage texts
+
+ Returns:
+ Generated answer string.
+ """
+ if self._model is None:
+ self._load_model()
+
+ context = "\n\n".join(
+ f"[{i+1}] {p}" for i, p in enumerate(passages)
+ )
+ prompt = PROMPT_TEMPLATE.format(context=context, question=question)
+
+ inputs = self._tokenizer(prompt, return_tensors="pt")
+ outputs = self._model.generate(
+ **inputs,
+ max_new_tokens=self.config.max_new_tokens,
+ temperature=self.config.temperature if self.config.temperature > 0 else None,
+ do_sample=self.config.temperature > 0,
+ )
+ # Decode only the generated tokens (skip the prompt)
+ generated = outputs[0][inputs["input_ids"].shape[1]:]
+ return self._tokenizer.decode(generated, skip_special_tokens=True).strip()
+
+
+class FakeGenerator:
+ """Deterministic mock generator for testing. No model download needed."""
+
+ def generate(self, question: str, passages: List[str]) -> str:
+ """Return a mock answer.
+
+ Args:
+ question: the user question
+ passages: list of retrieved passages
+
+ Returns:
+ Mock answer string.
+ """
+ return "mock answer"
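For reference, a sketch of the prompt the generator builds from two retrieved passages (illustrative, not part of the diff):

    from hag.generator import PROMPT_TEMPLATE

    passages = ["Paris is the capital of France.", "The Eiffel Tower is in Paris."]
    context = "\n\n".join(f"[{i+1}] {p}" for i, p in enumerate(passages))
    print(PROMPT_TEMPLATE.format(context=context, question="Where is the Eiffel Tower?"))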
diff --git a/hag/hopfield.py b/hag/hopfield.py
new file mode 100644
index 0000000..287e4af
--- /dev/null
+++ b/hag/hopfield.py
@@ -0,0 +1,124 @@
+"""Core Modern Continuous Hopfield Network retrieval module.
+
+Implements the iterative retrieval dynamics from:
+ Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
+Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+"""
+
+import logging
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetrieval:
+ """Modern Continuous Hopfield Network for memory retrieval.
+
+ Given memory bank M in R^{d x N} and query q in R^d:
+ 1. Compute attention: alpha = softmax(beta * M^T @ q)
+ 2. Update query: q_new = M @ alpha
+ 3. Repeat until convergence or max_iter
+
+ The energy function is:
+ E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+ Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
+ """
+
+ def __init__(self, config: HopfieldConfig) -> None:
+ self.config = config
+
+ @torch.no_grad()
+ def retrieve(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ return_trajectory: bool = False,
+ return_energy: bool = False,
+ ) -> HopfieldResult:
+ """Run iterative Hopfield retrieval.
+
+ Args:
+ query: (d,) or (batch, d) — query embedding(s)
+ memory: (d, N) — memory bank of passage embeddings
+ return_trajectory: if True, store q_t at each step
+ return_energy: if True, store E(q_t) at each step
+
+ Returns:
+ HopfieldResult with attention_weights, converged_query, num_steps,
+ and optionally trajectory and energy_curve.
+ """
+ # Ensure query is 2D: (batch, d)
+ if query.dim() == 1:
+ query = query.unsqueeze(0) # (1, d)
+
+ q = query.clone() # (batch, d)
+
+ trajectory = [q.clone()] if return_trajectory else None
+ energies = [self.compute_energy(q, memory)] if return_energy else None
+
+ num_steps = 0
+ for t in range(self.config.max_iter):
+ # Core Hopfield update
+ logits = self.config.beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+ q_new = alpha @ memory.T # (batch, d)
+
+ # Check convergence
+ delta = torch.norm(q_new - q, dim=-1).max() # scalar
+ q = q_new
+
+ if return_trajectory:
+ trajectory.append(q.clone())
+ if return_energy:
+ energies.append(self.compute_energy(q, memory))
+
+ num_steps = t + 1
+
+ if delta < self.config.conv_threshold:
+ break
+
+ # Final attention weights (recompute to ensure consistency)
+ logits = self.config.beta * (q @ memory) # (batch, N)
+ alpha = torch.softmax(logits, dim=-1) # (batch, N)
+
+ return HopfieldResult(
+ attention_weights=alpha,
+ converged_query=q,
+ num_steps=num_steps,
+ trajectory=trajectory,
+ energy_curve=energies,
+ )
+
+ def compute_energy(
+ self,
+ query: torch.Tensor,
+ memory: torch.Tensor,
+ ) -> torch.Tensor:
+ """Compute the Hopfield energy function.
+
+ E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+ Args:
+ query: (batch, d) or (d,) — query embedding(s)
+ memory: (d, N) — memory bank
+
+ Returns:
+ Energy scalar or (batch,) tensor.
+ """
+ if query.dim() == 1:
+ query = query.unsqueeze(0) # (1, d)
+
+ logits = self.config.beta * (query @ memory) # (batch, N)
+ lse = torch.logsumexp(logits, dim=-1) # (batch,)
+ norm_sq = 0.5 * (query**2).sum(dim=-1) # (batch,)
+ energy = -1.0 / self.config.beta * lse + norm_sq # (batch,)
+
+ return energy
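A small sketch of the Hopfield update on random memories (illustrative; with unit-norm patterns and a moderate beta the iteration usually converges within a handful of steps):

    import torch
    import torch.nn.functional as F

    from hag.config import HopfieldConfig
    from hag.hopfield import HopfieldRetrieval

    d, N = 64, 500
    memory = F.normalize(torch.randn(d, N), dim=0)  # (d, N), one unit-norm pattern per column
    query = F.normalize(torch.randn(d), dim=-1)      # (d,)

    hopfield = HopfieldRetrieval(HopfieldConfig(beta=8.0, max_iter=5, conv_threshold=1e-4))
    result = hopfield.retrieve(query, memory, return_trajectory=True)

    print(result.num_steps)                # steps taken before convergence or max_iter
    print(result.attention_weights.shape)  # (1, N)
    top = result.attention_weights.topk(5, dim=-1)
    print(top.indices, top.values)         # what a downstream retriever would keep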
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
new file mode 100644
index 0000000..42dcc73
--- /dev/null
+++ b/hag/memory_bank.py
@@ -0,0 +1,93 @@
+"""Memory bank construction and management for passage embeddings."""
+
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import MemoryBankConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryBank:
+ """Stores passage embeddings and provides lookup from indices back to text.
+
+ The memory bank is M in R^{d x N} where each column is a passage embedding.
+ Also maintains a mapping from column index to passage text for final retrieval.
+ """
+
+ def __init__(self, config: MemoryBankConfig) -> None:
+ self.config = config
+ self.embeddings: Optional[torch.Tensor] = None # (d, N)
+ self.passages: List[str] = []
+
+ def build_from_embeddings(
+ self, embeddings: torch.Tensor, passages: List[str]
+ ) -> None:
+ """Build memory bank from precomputed embeddings.
+
+ Args:
+ embeddings: (N, d) — passage embeddings (note: input is N x d)
+ passages: list of N passage strings
+ """
+ assert embeddings.shape[0] == len(passages), (
+ f"Number of embeddings ({embeddings.shape[0]}) must match "
+ f"number of passages ({len(passages)})"
+ )
+ if self.config.normalize:
+ embeddings = F.normalize(embeddings, dim=-1)
+ self.embeddings = embeddings.T # Store as (d, N) for efficient matmul
+ self.passages = list(passages)
+ logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+
+ def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+ """Given top-k indices, return corresponding passage texts.
+
+ Args:
+ indices: (k,) or (batch, k) tensor of integer indices
+
+ Returns:
+ List of passage strings.
+ """
+ flat_indices = indices.flatten().tolist()
+ return [self.passages[i] for i in flat_indices]
+
+ def save(self, path: str) -> None:
+ """Save memory bank to disk.
+
+ Args:
+ path: file path for saving (e.g., 'memory_bank.pt')
+ """
+ data: Dict = {
+ "embeddings": self.embeddings,
+ "passages": self.passages,
+ "config": {
+ "embedding_dim": self.config.embedding_dim,
+ "normalize": self.config.normalize,
+ },
+ }
+ torch.save(data, path)
+ logger.info("Saved memory bank to %s", path)
+
+ def load(self, path: str) -> None:
+ """Load memory bank from disk.
+
+ Args:
+ path: file path to load from
+ """
+ data = torch.load(path, weights_only=False)
+ self.embeddings = data["embeddings"]
+ self.passages = data["passages"]
+ logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+
+ @property
+ def size(self) -> int:
+ """Number of passages in the memory bank."""
+ return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+ @property
+ def dim(self) -> int:
+ """Embedding dimensionality."""
+ return self.embeddings.shape[0] if self.embeddings is not None else 0
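A sketch of building and querying a memory bank from fake embeddings (illustrative, not part of the diff):

    import torch

    from hag.config import MemoryBankConfig
    from hag.encoder import FakeEncoder
    from hag.memory_bank import MemoryBank

    passages = [
        "Paris is the capital of France.",
        "The Eiffel Tower is in Paris.",
        "Berlin is in Germany.",
    ]
    embeddings = FakeEncoder(dim=64).encode(passages)  # (N, d)

    bank = MemoryBank(MemoryBankConfig(embedding_dim=64))
    bank.build_from_embeddings(embeddings, passages)

    print(bank.size, bank.dim)  # 3 64
    print(bank.get_passages_by_indices(torch.tensor([2, 0])))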
diff --git a/hag/metrics.py b/hag/metrics.py
new file mode 100644
index 0000000..6a196df
--- /dev/null
+++ b/hag/metrics.py
@@ -0,0 +1,113 @@
+"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""
+
+import logging
+import re
+import string
+from collections import Counter
+from typing import Dict, List
+
+from hag.datatypes import PipelineResult
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_answer(text: str) -> str:
+ """Normalize answer text: lowercase, strip, remove articles and punctuation."""
+ text = text.lower().strip()
+ # Remove articles
+ text = re.sub(r"\b(a|an|the)\b", " ", text)
+ # Remove punctuation
+ text = text.translate(str.maketrans("", "", string.punctuation))
+ # Collapse whitespace
+ text = " ".join(text.split())
+ return text
+
+
+def exact_match(prediction: str, ground_truth: str) -> float:
+ """Normalized exact match.
+
+ Args:
+ prediction: predicted answer string
+ ground_truth: gold answer string
+
+ Returns:
+ 1.0 if normalized strings match, 0.0 otherwise.
+ """
+ return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))
+
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+ """Token-level F1 between prediction and ground truth.
+
+ Args:
+ prediction: predicted answer string
+ ground_truth: gold answer string
+
+ Returns:
+ F1 score between 0.0 and 1.0.
+ """
+ pred_tokens = _normalize_answer(prediction).split()
+ gold_tokens = _normalize_answer(ground_truth).split()
+
+ if not pred_tokens and not gold_tokens:
+ return 1.0
+ if not pred_tokens or not gold_tokens:
+ return 0.0
+
+ common = Counter(pred_tokens) & Counter(gold_tokens)
+ num_same = sum(common.values())
+
+ if num_same == 0:
+ return 0.0
+
+ precision = num_same / len(pred_tokens)
+ recall = num_same / len(gold_tokens)
+ f1 = 2 * precision * recall / (precision + recall)
+ return f1
+
+
+def retrieval_recall_at_k(
+ retrieved_indices: List[int], gold_indices: List[int], k: int
+) -> float:
+ """What fraction of gold passages appear in the retrieved top-k?
+
+ Args:
+ retrieved_indices: list of retrieved passage indices (top-k)
+ gold_indices: list of gold/relevant passage indices
+ k: number of retrieved passages to consider
+
+ Returns:
+ Recall score between 0.0 and 1.0.
+ """
+ if not gold_indices:
+ return 1.0
+ retrieved_set = set(retrieved_indices[:k])
+ gold_set = set(gold_indices)
+ return len(retrieved_set & gold_set) / len(gold_set)
+
+
+def evaluate_dataset(
+ results: List[PipelineResult], gold_answers: List[str]
+) -> Dict[str, float]:
+ """Compute aggregate metrics over a dataset.
+
+ Args:
+ results: list of PipelineResult from the pipeline
+ gold_answers: list of gold answer strings
+
+ Returns:
+ Dict with keys 'em', 'f1' containing averaged scores.
+ """
+ assert len(results) == len(gold_answers)
+
+ em_scores = []
+ f1_scores = []
+
+ for result, gold in zip(results, gold_answers):
+ em_scores.append(exact_match(result.answer, gold))
+ f1_scores.append(f1_score(result.answer, gold))
+
+ return {
+ "em": sum(em_scores) / len(em_scores) if em_scores else 0.0,
+ "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
+ }
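A quick check of the answer and retrieval metrics (illustrative, not part of the diff):

    from hag.metrics import exact_match, f1_score, retrieval_recall_at_k

    print(exact_match("The Eiffel Tower", "eiffel tower"))             # 1.0 after normalization
    print(round(f1_score("the tower in paris", "eiffel tower"), 2))    # 0.4 (one shared token)
    print(retrieval_recall_at_k([4, 7, 1], gold_indices=[1, 9], k=3))  # 0.5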
diff --git a/hag/pipeline.py b/hag/pipeline.py
new file mode 100644
index 0000000..1fefb84
--- /dev/null
+++ b/hag/pipeline.py
@@ -0,0 +1,107 @@
+"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""
+
+import logging
+from typing import List, Optional, Protocol, Union
+
+import numpy as np
+import torch
+
+from hag.config import PipelineConfig
+from hag.datatypes import PipelineResult, RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class EncoderProtocol(Protocol):
+ """Protocol for encoder interface."""
+
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor: ...
+
+
+class GeneratorProtocol(Protocol):
+ """Protocol for generator interface."""
+
+ def generate(self, question: str, passages: List[str]) -> str: ...
+
+
+class RAGPipeline:
+ """End-to-end pipeline: query -> encode -> retrieve -> generate.
+
+ Supports both FAISS (baseline) and Hopfield (ours) retrieval.
+ """
+
+ def __init__(
+ self,
+ config: PipelineConfig,
+ encoder: EncoderProtocol,
+ generator: GeneratorProtocol,
+ memory_bank: Optional[MemoryBank] = None,
+ faiss_retriever: Optional[FAISSRetriever] = None,
+ ) -> None:
+ self.config = config
+ self.encoder = encoder
+ self.generator = generator
+
+ if config.retriever_type == "faiss":
+ assert faiss_retriever is not None, "FAISSRetriever required for faiss mode"
+ self.retriever_type = "faiss"
+ self.faiss_retriever = faiss_retriever
+ self.hopfield_retriever: Optional[HopfieldRetriever] = None
+ elif config.retriever_type == "hopfield":
+ assert memory_bank is not None, "MemoryBank required for hopfield mode"
+ hopfield = HopfieldRetrieval(config.hopfield)
+ self.retriever_type = "hopfield"
+ self.hopfield_retriever = HopfieldRetriever(
+ hopfield, memory_bank, top_k=config.hopfield.top_k
+ )
+ self.faiss_retriever = None
+ else:
+ raise ValueError(f"Unknown retriever_type: {config.retriever_type}")
+
+ def run(self, question: str) -> PipelineResult:
+ """Run the full pipeline on a single question.
+
+ 1. Encode question -> query embedding
+ 2. Retrieve passages (FAISS or Hopfield)
+ 3. Generate answer with LLM
+
+ Args:
+ question: input question string
+
+ Returns:
+ PipelineResult with answer and retrieval metadata.
+ """
+ # Encode
+ query_emb = self.encoder.encode(question) # (1, d)
+
+ # Retrieve
+ if self.retriever_type == "hopfield":
+ retrieval_result = self.hopfield_retriever.retrieve(query_emb)
+ else:
+            query_np = query_emb.detach().cpu().numpy().astype(np.float32)
+ retrieval_result = self.faiss_retriever.retrieve(query_np)
+
+ # Generate
+ answer = self.generator.generate(question, retrieval_result.passages)
+
+ return PipelineResult(
+ question=question,
+ answer=answer,
+ retrieved_passages=retrieval_result.passages,
+ retrieval_result=retrieval_result,
+ )
+
+ def run_batch(self, questions: List[str]) -> List[PipelineResult]:
+ """Run pipeline on a batch of questions.
+
+ Args:
+ questions: list of question strings
+
+ Returns:
+ List of PipelineResult, one per question.
+ """
+ return [self.run(q) for q in questions]
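A minimal end-to-end run with the fake components, so the pipeline can be exercised without downloading any models (illustrative; the faiss package must still be installed because pipeline.py imports the FAISS retriever at module scope):

    from hag.config import MemoryBankConfig, PipelineConfig
    from hag.encoder import FakeEncoder
    from hag.generator import FakeGenerator
    from hag.memory_bank import MemoryBank
    from hag.pipeline import RAGPipeline

    passages = ["Paris is the capital of France.", "Berlin is the capital of Germany."]
    encoder = FakeEncoder(dim=64)

    bank = MemoryBank(MemoryBankConfig(embedding_dim=64))
    bank.build_from_embeddings(encoder.encode(passages), passages)

    pipeline = RAGPipeline(
        PipelineConfig(retriever_type="hopfield"),
        encoder=encoder,
        generator=FakeGenerator(),
        memory_bank=bank,
    )
    result = pipeline.run("What is the capital of France?")
    print(result.answer)              # "mock answer" from FakeGenerator
    print(result.retrieved_passages)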
diff --git a/hag/retriever_faiss.py b/hag/retriever_faiss.py
new file mode 100644
index 0000000..cd54a85
--- /dev/null
+++ b/hag/retriever_faiss.py
@@ -0,0 +1,73 @@
+"""Baseline FAISS top-k retriever for vanilla RAG."""
+
+import logging
+from typing import List, Optional
+
+import faiss
+import numpy as np
+import torch
+
+from hag.datatypes import RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+class FAISSRetriever:
+ """Standard top-k retrieval using FAISS inner product search.
+
+ This is the baseline to compare against Hopfield retrieval.
+ """
+
+ def __init__(self, top_k: int = 5) -> None:
+ self.index: Optional[faiss.IndexFlatIP] = None
+ self.passages: List[str] = []
+ self.top_k = top_k
+
+ def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
+ """Build FAISS IndexFlatIP from embeddings.
+
+ Args:
+ embeddings: (N, d) numpy array of passage embeddings
+ passages: list of N passage strings
+ """
+        assert embeddings.shape[0] == len(passages)
+        d = embeddings.shape[1]
+        self.index = faiss.IndexFlatIP(d)
+        # FAISS expects contiguous float32; copy so the in-place L2 normalization
+        # below does not mutate the caller's array.
+        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
+        # Normalize for cosine similarity via inner product
+        faiss.normalize_L2(embeddings)
+        self.index.add(embeddings)
+ self.passages = list(passages)
+ logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)
+
+ def retrieve(self, query: np.ndarray) -> RetrievalResult:
+ """Retrieve top-k passages for a query.
+
+ Args:
+ query: (d,) or (batch, d) numpy array
+
+ Returns:
+ RetrievalResult with passages, scores, and indices.
+ """
+ assert self.index is not None, "Index not built. Call build_index first."
+
+ if query.ndim == 1:
+ query = query.reshape(1, -1) # (1, d)
+
+ # Normalize query for cosine similarity
+ query_copy = query.copy()
+ faiss.normalize_L2(query_copy)
+
+ scores, indices = self.index.search(query_copy, self.top_k) # (batch, k)
+
+ # Flatten for single query case
+ if scores.shape[0] == 1:
+ scores = scores[0] # (k,)
+ indices = indices[0] # (k,)
+
+ passages = [self.passages[i] for i in indices.flatten().tolist()]
+
+ return RetrievalResult(
+ passages=passages,
+ scores=torch.from_numpy(scores).float(),
+ indices=torch.from_numpy(indices).long(),
+ )
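A sketch of the FAISS baseline on the same fake embeddings (illustrative; requires faiss):

    import numpy as np

    from hag.encoder import FakeEncoder
    from hag.retriever_faiss import FAISSRetriever

    passages = [
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "Rome is in Italy.",
    ]
    encoder = FakeEncoder(dim=64)
    embs = encoder.encode(passages).numpy().astype(np.float32)  # (N, d)

    retriever = FAISSRetriever(top_k=2)
    retriever.build_index(embs, passages)

    query = encoder.encode("capital of France").numpy().astype(np.float32)[0]  # (d,)
    result = retriever.retrieve(query)
    print(result.indices, result.scores)
    print(result.passages)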
diff --git a/hag/retriever_hopfield.py b/hag/retriever_hopfield.py
new file mode 100644
index 0000000..1cb6968
--- /dev/null
+++ b/hag/retriever_hopfield.py
@@ -0,0 +1,77 @@
+"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""
+
+import logging
+
+import torch
+
+from hag.datatypes import RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetriever:
+ """Wraps HopfieldRetrieval + MemoryBank into a retriever interface.
+
+ The bridge between Hopfield's continuous retrieval and the discrete
+ passage selection needed for LLM prompting.
+ """
+
+ def __init__(
+ self,
+ hopfield: HopfieldRetrieval,
+ memory_bank: MemoryBank,
+ top_k: int = 5,
+ ) -> None:
+ self.hopfield = hopfield
+ self.memory_bank = memory_bank
+ self.top_k = top_k
+
+ def retrieve(
+ self,
+ query_embedding: torch.Tensor,
+ return_analysis: bool = False,
+ ) -> RetrievalResult:
+ """Retrieve top-k passages using iterative Hopfield retrieval.
+
+ 1. Run Hopfield iterative retrieval -> get attention weights alpha_T
+ 2. Take top_k indices from alpha_T
+ 3. Look up corresponding passage texts from memory bank
+ 4. Optionally return trajectory and energy for analysis
+
+ Args:
+ query_embedding: (d,) or (batch, d) — query embedding
+ return_analysis: if True, include full HopfieldResult
+
+ Returns:
+ RetrievalResult with passages, scores, indices, and optionally
+ the full hopfield_result.
+ """
+ hopfield_result = self.hopfield.retrieve(
+ query_embedding,
+ self.memory_bank.embeddings,
+ return_trajectory=return_analysis,
+ return_energy=return_analysis,
+ )
+
+ alpha = hopfield_result.attention_weights # (batch, N) or (1, N)
+
+ # Get top-k indices and scores
+ k = min(self.top_k, alpha.shape[-1])
+ scores, indices = torch.topk(alpha, k, dim=-1) # (batch, k)
+
+ # Flatten for single-query case
+ if scores.shape[0] == 1:
+ scores = scores.squeeze(0) # (k,)
+ indices = indices.squeeze(0) # (k,)
+
+ passages = self.memory_bank.get_passages_by_indices(indices)
+
+ return RetrievalResult(
+ passages=passages,
+ scores=scores,
+ indices=indices,
+ hopfield_result=hopfield_result if return_analysis else None,
+ )
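Finally, a sketch of the HopfieldRetriever on its own, including the analysis path (illustrative, not part of the diff):

    from hag.config import HopfieldConfig, MemoryBankConfig
    from hag.encoder import FakeEncoder
    from hag.hopfield import HopfieldRetrieval
    from hag.memory_bank import MemoryBank
    from hag.retriever_hopfield import HopfieldRetriever

    passages = [
        "Paris is the capital of France.",
        "Berlin is the capital of Germany.",
        "Rome is in Italy.",
    ]
    encoder = FakeEncoder(dim=64)

    bank = MemoryBank(MemoryBankConfig(embedding_dim=64))
    bank.build_from_embeddings(encoder.encode(passages), passages)

    retriever = HopfieldRetriever(HopfieldRetrieval(HopfieldConfig(beta=4.0)), bank, top_k=2)
    result = retriever.retrieve(encoder.encode("capital of France"), return_analysis=True)

    print(result.passages)
    print(result.hopfield_result.num_steps, len(result.hopfield_result.energy_curve))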