| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
Diffstat (limited to 'hag')
| -rw-r--r-- | hag/__init__.py | 21 |
| -rw-r--r-- | hag/config.py | 50 |
| -rw-r--r-- | hag/datatypes.py | 37 |
| -rw-r--r-- | hag/encoder.py | 88 |
| -rw-r--r-- | hag/energy.py | 83 |
| -rw-r--r-- | hag/generator.py | 87 |
| -rw-r--r-- | hag/hopfield.py | 124 |
| -rw-r--r-- | hag/memory_bank.py | 93 |
| -rw-r--r-- | hag/metrics.py | 113 |
| -rw-r--r-- | hag/pipeline.py | 107 |
| -rw-r--r-- | hag/retriever_faiss.py | 73 |
| -rw-r--r-- | hag/retriever_hopfield.py | 77 |
12 files changed, 953 insertions, 0 deletions
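For orientation before the per-file diffs: a minimal end-to-end sketch of the pipeline this commit adds, wired with the test doubles (`FakeEncoder`, `FakeGenerator`) so it runs on CPU with no model downloads. The corpus strings below are placeholders, not repo data.

```python
from hag.config import PipelineConfig
from hag.encoder import FakeEncoder
from hag.generator import FakeGenerator
from hag.memory_bank import MemoryBank
from hag.pipeline import RAGPipeline

config = PipelineConfig()            # retriever_type defaults to "hopfield"
config.memory.embedding_dim = 64     # keep the config consistent with FakeEncoder

encoder = FakeEncoder(dim=64)
passages = [f"synthetic passage {i}" for i in range(100)]  # placeholder corpus

bank = MemoryBank(config.memory)
bank.build_from_embeddings(encoder.encode(passages), passages)

pipeline = RAGPipeline(config, encoder, FakeGenerator(), memory_bank=bank)
result = pipeline.run("what is passage 3 about?")
print(result.answer)                 # "mock answer" from FakeGenerator
print(result.retrieved_passages)     # top-5 passages chosen by Hopfield attention
```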
diff --git a/hag/__init__.py b/hag/__init__.py
new file mode 100644
index 0000000..18496e9
--- /dev/null
+++ b/hag/__init__.py
@@ -0,0 +1,21 @@
+"""HAG: Hopfield-Augmented Generation."""
+
+from hag.config import (
+    EncoderConfig,
+    GeneratorConfig,
+    HopfieldConfig,
+    MemoryBankConfig,
+    PipelineConfig,
+)
+from hag.datatypes import HopfieldResult, PipelineResult, RetrievalResult
+
+__all__ = [
+    "HopfieldConfig",
+    "MemoryBankConfig",
+    "EncoderConfig",
+    "GeneratorConfig",
+    "PipelineConfig",
+    "HopfieldResult",
+    "RetrievalResult",
+    "PipelineResult",
+]
diff --git a/hag/config.py b/hag/config.py
new file mode 100644
index 0000000..793e3a6
--- /dev/null
+++ b/hag/config.py
@@ -0,0 +1,50 @@
+"""All hyperparameters and configuration dataclasses for HAG."""
+
+from dataclasses import dataclass, field
+
+
+@dataclass
+class HopfieldConfig:
+    """Configuration for the Hopfield retrieval module."""
+
+    beta: float = 1.0  # Inverse temperature. Higher = sharper retrieval
+    max_iter: int = 5  # Maximum Hopfield iteration steps
+    conv_threshold: float = 1e-4  # Stop if ||q_{t+1} - q_t|| < threshold
+    top_k: int = 5  # Number of passages to retrieve from final attention weights
+
+
+@dataclass
+class MemoryBankConfig:
+    """Configuration for the memory bank."""
+
+    embedding_dim: int = 768  # Must match encoder output dim
+    normalize: bool = True  # L2-normalize embeddings in memory bank
+
+
+@dataclass
+class EncoderConfig:
+    """Configuration for the query/passage encoder."""
+
+    model_name: str = "facebook/contriever-msmarco"
+    max_length: int = 512
+    batch_size: int = 64
+
+
+@dataclass
+class GeneratorConfig:
+    """Configuration for the LLM generator."""
+
+    model_name: str = "meta-llama/Llama-3.1-8B-Instruct"
+    max_new_tokens: int = 128
+    temperature: float = 0.0  # Greedy decoding for reproducibility
+
+
+@dataclass
+class PipelineConfig:
+    """Top-level pipeline configuration."""
+
+    hopfield: HopfieldConfig = field(default_factory=HopfieldConfig)
+    memory: MemoryBankConfig = field(default_factory=MemoryBankConfig)
+    encoder: EncoderConfig = field(default_factory=EncoderConfig)
+    generator: GeneratorConfig = field(default_factory=GeneratorConfig)
+    retriever_type: str = "hopfield"  # "hopfield" or "faiss"
diff --git a/hag/datatypes.py b/hag/datatypes.py
new file mode 100644
index 0000000..0f4254d
--- /dev/null
+++ b/hag/datatypes.py
@@ -0,0 +1,37 @@
+"""Data types used across HAG modules."""
+
+from dataclasses import dataclass
+from typing import List, Optional
+
+import torch
+
+
+@dataclass
+class HopfieldResult:
+    """Result from Hopfield iterative retrieval."""
+
+    attention_weights: torch.Tensor  # (batch, N) or (N,)
+    converged_query: torch.Tensor  # (batch, d) or (d,)
+    num_steps: int
+    trajectory: Optional[List[torch.Tensor]] = None  # list of q_t
+    energy_curve: Optional[List[torch.Tensor]] = None  # list of E(q_t)
+
+
+@dataclass
+class RetrievalResult:
+    """Result from a retriever (FAISS or Hopfield)."""
+
+    passages: List[str]
+    scores: torch.Tensor  # top-k scores
+    indices: torch.Tensor  # top-k indices
+    hopfield_result: Optional[HopfieldResult] = None
+
+
+@dataclass
+class PipelineResult:
+    """Result from the full RAG/HAG pipeline."""
+
+    question: str
+    answer: str
+    retrieved_passages: List[str]
+    retrieval_result: RetrievalResult
diff --git a/hag/encoder.py b/hag/encoder.py
new file mode 100644
index 0000000..7e103f3
--- /dev/null
+++ b/hag/encoder.py
@@ -0,0 +1,88 @@
+"""Wrapper for encoding queries and passages into embeddings."""
+
+import hashlib
+import logging
+from typing import List, Union
+
+import torch
+
+from hag.config import EncoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Encoder:
+    """Encodes text queries/passages into dense embeddings.
+
+    Uses a HuggingFace transformer model (e.g., Contriever).
+    For testing, use FakeEncoder instead.
+    """
+
+    def __init__(self, config: EncoderConfig) -> None:
+        self.config = config
+        self._tokenizer = None
+        self._model = None
+
+    def _load_model(self) -> None:
+        """Lazy-load the model and tokenizer."""
+        from transformers import AutoModel, AutoTokenizer
+
+        logger.info("Loading encoder model: %s", self.config.model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self._model = AutoModel.from_pretrained(self.config.model_name)
+        self._model.eval()
+
+    @torch.no_grad()
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+        """Encode text(s) into embedding(s).
+
+        Args:
+            texts: single string or list of strings
+
+        Returns:
+            (1, d) tensor for single input, (N, d) for list input.
+        """
+        if self._model is None:
+            self._load_model()
+
+        if isinstance(texts, str):
+            texts = [texts]
+
+        inputs = self._tokenizer(
+            texts,
+            max_length=self.config.max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+        outputs = self._model(**inputs)
+        # Masked mean pooling over token embeddings: a plain mean over
+        # last_hidden_state would include padding tokens in batched inputs.
+        mask = inputs["attention_mask"].unsqueeze(-1)  # (N, T, 1)
+        summed = (outputs.last_hidden_state * mask).sum(dim=1)  # (N, d)
+        counts = mask.sum(dim=1).clamp(min=1)  # (N, 1)
+        return summed / counts
+
+
+class FakeEncoder:
+    """Deterministic hash-based encoder for testing. No model download needed."""
+
+    def __init__(self, dim: int = 64) -> None:
+        self.dim = dim
+
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+        """Produce deterministic embeddings based on text hash.
+
+        Args:
+            texts: single string or list of strings
+
+        Returns:
+            (1, d) or (N, d) normalized tensor.
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        embeddings = []
+        for text in texts:
+            # Stable digest: the built-in hash() is salted per process,
+            # which would break determinism across runs.
+            seed = int.from_bytes(hashlib.md5(text.encode()).digest()[:4], "little")
+            gen = torch.Generator().manual_seed(seed)
+            emb = torch.randn(1, self.dim, generator=gen)
+            embeddings.append(emb)
+
+        result = torch.cat(embeddings, dim=0)  # (N, d)
+        return torch.nn.functional.normalize(result, dim=-1)
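A quick check of the encoder contract (shapes, determinism, normalization), runnable with `FakeEncoder` alone:

```python
import torch

from hag.encoder import FakeEncoder

enc = FakeEncoder(dim=64)
single = enc.encode("who wrote Hamlet?")                        # (1, 64)
batch = enc.encode(["who wrote Hamlet?", "capital of France"])  # (2, 64)

assert single.shape == (1, 64) and batch.shape == (2, 64)
assert torch.allclose(single[0], batch[0])                # same text, same embedding
assert torch.allclose(batch.norm(dim=-1), torch.ones(2))  # rows are L2-normalized
```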
diff --git a/hag/energy.py b/hag/energy.py
new file mode 100644
index 0000000..62a39e9
--- /dev/null
+++ b/hag/energy.py
@@ -0,0 +1,83 @@
+"""Energy computation and analysis utilities for Hopfield retrieval."""
+
+import logging
+from typing import List
+
+import torch
+
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+def compute_energy_curve(hopfield_result: HopfieldResult) -> List[float]:
+    """Extract energy values at each iteration step.
+
+    Args:
+        hopfield_result: result from HopfieldRetrieval.retrieve() with return_energy=True
+
+    Returns:
+        List of energy values (floats) at each step.
+    """
+    if hopfield_result.energy_curve is None:
+        return []
+    return [e.item() if e.dim() == 0 else e.mean().item() for e in hopfield_result.energy_curve]
+
+
+def compute_energy_gap(energy_curve: List[float]) -> float:
+    """Compute the energy gap: Delta_E = E(q_0) - E(q_T).
+
+    Larger gap means more refinement happened during iteration.
+
+    Args:
+        energy_curve: list of energy values at each step
+
+    Returns:
+        Energy gap (float). Positive if energy decreased.
+    """
+    if len(energy_curve) < 2:
+        return 0.0
+    return energy_curve[0] - energy_curve[-1]
+
+
+def verify_monotonic_decrease(energy_curve: List[float], tol: float = 1e-6) -> bool:
+    """Check that E(q_{t+1}) <= E(q_t) for all t.
+
+    This should always be True for the Modern Hopfield Network.
+
+    Args:
+        energy_curve: list of energy values at each step
+        tol: numerical tolerance for comparison
+
+    Returns:
+        True if energy decreases monotonically (within tolerance).
+    """
+    for i in range(len(energy_curve) - 1):
+        if energy_curve[i + 1] > energy_curve[i] + tol:
+            return False
+    return True
+
+
+def compute_attention_entropy(attention_weights: torch.Tensor) -> float:
+    """Compute the entropy of attention weights.
+
+    H(alpha) = -sum_i alpha_i * log(alpha_i)
+
+    Low entropy = sharp retrieval (confident).
+    High entropy = diffuse retrieval (uncertain).
+
+    Args:
+        attention_weights: (N,) or (batch, N) — attention distribution
+
+    Returns:
+        Entropy value (float). Averaged over batch if batched.
+    """
+    if attention_weights.dim() == 1:
+        attention_weights = attention_weights.unsqueeze(0)  # (1, N)
+
+    # Clamp to avoid log(0)
+    eps = 1e-12
+    alpha = attention_weights.clamp(min=eps)
+    entropy = -(alpha * alpha.log()).sum(dim=-1)  # (batch,)
+
+    return entropy.mean().item()
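The entropy helper is easiest to read off the two extremes: a uniform attention distribution (maximally diffuse, entropy log N) versus a one-hot one (sharp). A small illustration:

```python
import torch

from hag.energy import compute_attention_entropy

n = 10
uniform = torch.full((n,), 1.0 / n)   # diffuse: every memory weighted equally
sharp = torch.zeros(n)
sharp[0] = 1.0                        # one-hot: all mass on a single memory

print(compute_attention_entropy(uniform))  # log(10) ≈ 2.3026
print(compute_attention_entropy(sharp))    # ≈ 0.0 (tiny positive value from the eps clamp)
```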
diff --git a/hag/generator.py b/hag/generator.py
new file mode 100644
index 0000000..2142e0c
--- /dev/null
+++ b/hag/generator.py
@@ -0,0 +1,87 @@
+"""LLM generation wrapper for producing answers from retrieved context."""
+
+import logging
+from typing import List
+
+from hag.config import GeneratorConfig
+
+logger = logging.getLogger(__name__)
+
+PROMPT_TEMPLATE = """Answer the following question based on the provided context passages.
+
+Context:
+{context}
+
+Question: {question}
+
+Answer:"""
+
+
+class Generator:
+    """LLM-based answer generator.
+
+    Uses a HuggingFace causal LM (e.g., Llama-3.1-8B-Instruct).
+    For testing, use FakeGenerator instead.
+    """
+
+    def __init__(self, config: GeneratorConfig) -> None:
+        self.config = config
+        self._tokenizer = None
+        self._model = None
+
+    def _load_model(self) -> None:
+        """Lazy-load the model and tokenizer."""
+        from transformers import AutoModelForCausalLM, AutoTokenizer
+
+        logger.info("Loading generator model: %s", self.config.model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self._model = AutoModelForCausalLM.from_pretrained(
+            self.config.model_name,
+            torch_dtype="auto",
+        )
+        self._model.eval()
+
+    def generate(self, question: str, passages: List[str]) -> str:
+        """Generate an answer given a question and retrieved passages.
+
+        Args:
+            question: the user question
+            passages: list of retrieved passage texts
+
+        Returns:
+            Generated answer string.
+        """
+        if self._model is None:
+            self._load_model()
+
+        context = "\n\n".join(
+            f"[{i+1}] {p}" for i, p in enumerate(passages)
+        )
+        prompt = PROMPT_TEMPLATE.format(context=context, question=question)
+
+        inputs = self._tokenizer(prompt, return_tensors="pt")
+        outputs = self._model.generate(
+            **inputs,
+            max_new_tokens=self.config.max_new_tokens,
+            temperature=self.config.temperature if self.config.temperature > 0 else None,
+            do_sample=self.config.temperature > 0,
+        )
+        # Decode only the generated tokens (skip the prompt)
+        generated = outputs[0][inputs["input_ids"].shape[1]:]
+        return self._tokenizer.decode(generated, skip_special_tokens=True).strip()


+class FakeGenerator:
+    """Deterministic mock generator for testing. No model download needed."""
+
+    def generate(self, question: str, passages: List[str]) -> str:
+        """Return a mock answer.
+
+        Args:
+            question: the user question
+            passages: list of retrieved passages
+
+        Returns:
+            Mock answer string.
+        """
+        return "mock answer"
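The prompt protocol is plain string splicing: numbered passages into `PROMPT_TEMPLATE`. Rendering it directly shows exactly what the LLM would see, no model required:

```python
from hag.generator import PROMPT_TEMPLATE

passages = ["Paris is the capital of France.", "France is in Europe."]
# Same numbering scheme Generator.generate uses:
context = "\n\n".join(f"[{i+1}] {p}" for i, p in enumerate(passages))
print(PROMPT_TEMPLATE.format(context=context, question="What is the capital of France?"))
# Answer the following question based on the provided context passages.
#
# Context:
# [1] Paris is the capital of France.
#
# [2] France is in Europe.
#
# Question: What is the capital of France?
#
# Answer:
```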
diff --git a/hag/hopfield.py b/hag/hopfield.py
new file mode 100644
index 0000000..287e4af
--- /dev/null
+++ b/hag/hopfield.py
@@ -0,0 +1,124 @@
+"""Core Modern Continuous Hopfield Network retrieval module.
+
+Implements the iterative retrieval dynamics from:
+    Ramsauer et al., "Hopfield Networks is All You Need" (ICLR 2021)
+
+Update rule: q_{t+1} = M * softmax(beta * M^T * q_t)
+Energy: E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+"""
+
+import logging
+from typing import Optional
+
+import torch
+
+from hag.config import HopfieldConfig
+from hag.datatypes import HopfieldResult
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetrieval:
+    """Modern Continuous Hopfield Network for memory retrieval.
+
+    Given memory bank M in R^{d x N} and query q in R^d:
+        1. Compute attention: alpha = softmax(beta * M^T @ q)
+        2. Update query: q_new = M @ alpha
+        3. Repeat until convergence or max_iter
+
+    The energy function is:
+        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+    Key property: E(q_{t+1}) <= E(q_t) (monotonic decrease)
+    """
+
+    def __init__(self, config: HopfieldConfig) -> None:
+        self.config = config
+
+    @torch.no_grad()
+    def retrieve(
+        self,
+        query: torch.Tensor,
+        memory: torch.Tensor,
+        return_trajectory: bool = False,
+        return_energy: bool = False,
+    ) -> HopfieldResult:
+        """Run iterative Hopfield retrieval.
+
+        Args:
+            query: (d,) or (batch, d) — query embedding(s)
+            memory: (d, N) — memory bank of passage embeddings
+            return_trajectory: if True, store q_t at each step
+            return_energy: if True, store E(q_t) at each step
+
+        Returns:
+            HopfieldResult with attention_weights, converged_query, num_steps,
+            and optionally trajectory and energy_curve.
+        """
+        # Ensure query is 2D: (batch, d)
+        if query.dim() == 1:
+            query = query.unsqueeze(0)  # (1, d)
+
+        q = query.clone()  # (batch, d)
+
+        trajectory = [q.clone()] if return_trajectory else None
+        energies = [self.compute_energy(q, memory)] if return_energy else None
+
+        num_steps = 0
+        for t in range(self.config.max_iter):
+            # Core Hopfield update
+            logits = self.config.beta * (q @ memory)  # (batch, N)
+            alpha = torch.softmax(logits, dim=-1)  # (batch, N)
+            q_new = alpha @ memory.T  # (batch, d)
+
+            # Check convergence
+            delta = torch.norm(q_new - q, dim=-1).max()  # scalar
+            q = q_new
+
+            if return_trajectory:
+                trajectory.append(q.clone())
+            if return_energy:
+                energies.append(self.compute_energy(q, memory))
+
+            num_steps = t + 1
+
+            if delta < self.config.conv_threshold:
+                break
+
+        # Final attention weights (recompute to ensure consistency)
+        logits = self.config.beta * (q @ memory)  # (batch, N)
+        alpha = torch.softmax(logits, dim=-1)  # (batch, N)
+
+        return HopfieldResult(
+            attention_weights=alpha,
+            converged_query=q,
+            num_steps=num_steps,
+            trajectory=trajectory,
+            energy_curve=energies,
+        )
+
+    def compute_energy(
+        self,
+        query: torch.Tensor,
+        memory: torch.Tensor,
+    ) -> torch.Tensor:
+        """Compute the Hopfield energy function.
+
+        E(q) = -1/beta * log(sum_i exp(beta * q^T m_i)) + 1/2 * ||q||^2
+
+        Args:
+            query: (batch, d) or (d,) — query embedding(s)
+            memory: (d, N) — memory bank
+
+        Returns:
+            Energy scalar or (batch,) tensor.
+        """
+        if query.dim() == 1:
+            query = query.unsqueeze(0)  # (1, d)
+
+        logits = self.config.beta * (query @ memory)  # (batch, N)
+        lse = torch.logsumexp(logits, dim=-1)  # (batch,)
+        norm_sq = 0.5 * (query**2).sum(dim=-1)  # (batch,)
+        energy = -1.0 / self.config.beta * lse + norm_sq  # (batch,)
+
+        return energy
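A self-contained sanity check of the dynamics on synthetic memories: a noisy copy of a stored pattern should converge back toward it within a few steps, with a non-increasing energy curve. The pattern construction below is arbitrary test data.

```python
import torch

from hag.config import HopfieldConfig
from hag.energy import compute_energy_curve, verify_monotonic_decrease
from hag.hopfield import HopfieldRetrieval

torch.manual_seed(0)
d, n = 32, 50
patterns = torch.nn.functional.normalize(torch.randn(n, d), dim=-1)
memory = patterns.T                          # (d, N), the layout retrieve() expects
query = patterns[7] + 0.3 * torch.randn(d)   # noisy copy of stored pattern 7

hop = HopfieldRetrieval(HopfieldConfig(beta=8.0, max_iter=10))
result = hop.retrieve(query, memory, return_energy=True)

curve = compute_energy_curve(result)
assert verify_monotonic_decrease(curve)          # E(q_{t+1}) <= E(q_t), per the docstring
print(result.num_steps)
print(result.attention_weights.argmax().item())  # expected: 7, the nearest stored pattern
```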
diff --git a/hag/memory_bank.py b/hag/memory_bank.py
new file mode 100644
index 0000000..42dcc73
--- /dev/null
+++ b/hag/memory_bank.py
@@ -0,0 +1,93 @@
+"""Memory bank construction and management for passage embeddings."""
+
+import logging
+from typing import Dict, List, Optional
+
+import torch
+import torch.nn.functional as F
+
+from hag.config import MemoryBankConfig
+
+logger = logging.getLogger(__name__)
+
+
+class MemoryBank:
+    """Stores passage embeddings and provides lookup from indices back to text.
+
+    The memory bank is M in R^{d x N} where each column is a passage embedding.
+    Also maintains a mapping from column index to passage text for final retrieval.
+    """
+
+    def __init__(self, config: MemoryBankConfig) -> None:
+        self.config = config
+        self.embeddings: Optional[torch.Tensor] = None  # (d, N)
+        self.passages: List[str] = []
+
+    def build_from_embeddings(
+        self, embeddings: torch.Tensor, passages: List[str]
+    ) -> None:
+        """Build memory bank from precomputed embeddings.
+
+        Args:
+            embeddings: (N, d) — passage embeddings (note: input is N x d)
+            passages: list of N passage strings
+        """
+        assert embeddings.shape[0] == len(passages), (
+            f"Number of embeddings ({embeddings.shape[0]}) must match "
+            f"number of passages ({len(passages)})"
+        )
+        if self.config.normalize:
+            embeddings = F.normalize(embeddings, dim=-1)
+        self.embeddings = embeddings.T  # Store as (d, N) for efficient matmul
+        self.passages = list(passages)
+        logger.info("Built memory bank with %d passages, dim=%d", self.size, self.dim)
+
+    def get_passages_by_indices(self, indices: torch.Tensor) -> List[str]:
+        """Given top-k indices, return corresponding passage texts.
+
+        Args:
+            indices: (k,) or (batch, k) tensor of integer indices
+
+        Returns:
+            List of passage strings.
+        """
+        flat_indices = indices.flatten().tolist()
+        return [self.passages[i] for i in flat_indices]
+
+    def save(self, path: str) -> None:
+        """Save memory bank to disk.
+
+        Args:
+            path: file path for saving (e.g., 'memory_bank.pt')
+        """
+        data: Dict = {
+            "embeddings": self.embeddings,
+            "passages": self.passages,
+            "config": {
+                "embedding_dim": self.config.embedding_dim,
+                "normalize": self.config.normalize,
+            },
+        }
+        torch.save(data, path)
+        logger.info("Saved memory bank to %s", path)
+
+    def load(self, path: str) -> None:
+        """Load memory bank from disk.
+
+        Args:
+            path: file path to load from
+        """
+        data = torch.load(path, weights_only=False)
+        self.embeddings = data["embeddings"]
+        self.passages = data["passages"]
+        logger.info("Loaded memory bank from %s (%d passages)", path, self.size)
+
+    @property
+    def size(self) -> int:
+        """Number of passages in the memory bank."""
+        return self.embeddings.shape[1] if self.embeddings is not None else 0
+
+    @property
+    def dim(self) -> int:
+        """Embedding dimensionality."""
+        return self.embeddings.shape[0] if self.embeddings is not None else 0
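The (N, d)-in / (d, N)-stored convention and the save/load roundtrip, sketched against a throwaway temp file:

```python
import os
import tempfile

import torch

from hag.config import MemoryBankConfig
from hag.memory_bank import MemoryBank

bank = MemoryBank(MemoryBankConfig(embedding_dim=8))
bank.build_from_embeddings(torch.randn(3, 8), ["p0", "p1", "p2"])  # input is (N, d)
assert bank.embeddings.shape == (8, 3)                             # stored as (d, N)
assert bank.get_passages_by_indices(torch.tensor([2, 0])) == ["p2", "p0"]

with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "memory_bank.pt")
    bank.save(path)
    restored = MemoryBank(MemoryBankConfig(embedding_dim=8))
    restored.load(path)
    assert restored.size == 3 and restored.dim == 8
```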
diff --git a/hag/metrics.py b/hag/metrics.py
new file mode 100644
index 0000000..6a196df
--- /dev/null
+++ b/hag/metrics.py
@@ -0,0 +1,113 @@
+"""Evaluation metrics for HAG: exact match, F1, retrieval recall."""
+
+import logging
+import re
+import string
+from collections import Counter
+from typing import Dict, List
+
+from hag.datatypes import PipelineResult
+
+logger = logging.getLogger(__name__)
+
+
+def _normalize_answer(text: str) -> str:
+    """Normalize answer text: lowercase, strip, remove articles and punctuation."""
+    text = text.lower().strip()
+    # Remove articles
+    text = re.sub(r"\b(a|an|the)\b", " ", text)
+    # Remove punctuation
+    text = text.translate(str.maketrans("", "", string.punctuation))
+    # Collapse whitespace
+    text = " ".join(text.split())
+    return text
+
+
+def exact_match(prediction: str, ground_truth: str) -> float:
+    """Normalized exact match.
+
+    Args:
+        prediction: predicted answer string
+        ground_truth: gold answer string
+
+    Returns:
+        1.0 if normalized strings match, 0.0 otherwise.
+    """
+    return float(_normalize_answer(prediction) == _normalize_answer(ground_truth))
+
+
+def f1_score(prediction: str, ground_truth: str) -> float:
+    """Token-level F1 between prediction and ground truth.
+
+    Args:
+        prediction: predicted answer string
+        ground_truth: gold answer string
+
+    Returns:
+        F1 score between 0.0 and 1.0.
+    """
+    pred_tokens = _normalize_answer(prediction).split()
+    gold_tokens = _normalize_answer(ground_truth).split()
+
+    if not pred_tokens and not gold_tokens:
+        return 1.0
+    if not pred_tokens or not gold_tokens:
+        return 0.0
+
+    common = Counter(pred_tokens) & Counter(gold_tokens)
+    num_same = sum(common.values())
+
+    if num_same == 0:
+        return 0.0
+
+    precision = num_same / len(pred_tokens)
+    recall = num_same / len(gold_tokens)
+    f1 = 2 * precision * recall / (precision + recall)
+    return f1
+
+
+def retrieval_recall_at_k(
+    retrieved_indices: List[int], gold_indices: List[int], k: int
+) -> float:
+    """What fraction of gold passages appear in the retrieved top-k?
+
+    Args:
+        retrieved_indices: list of retrieved passage indices (top-k)
+        gold_indices: list of gold/relevant passage indices
+        k: number of retrieved passages to consider
+
+    Returns:
+        Recall score between 0.0 and 1.0.
+    """
+    if not gold_indices:
+        return 1.0
+    retrieved_set = set(retrieved_indices[:k])
+    gold_set = set(gold_indices)
+    return len(retrieved_set & gold_set) / len(gold_set)
+
+
+def evaluate_dataset(
+    results: List[PipelineResult], gold_answers: List[str]
+) -> Dict[str, float]:
+    """Compute aggregate metrics over a dataset.
+
+    Args:
+        results: list of PipelineResult from the pipeline
+        gold_answers: list of gold answer strings
+
+    Returns:
+        Dict with keys 'em', 'f1' containing averaged scores.
+    """
+    assert len(results) == len(gold_answers)
+
+    em_scores = []
+    f1_scores = []
+
+    for result, gold in zip(results, gold_answers):
+        em_scores.append(exact_match(result.answer, gold))
+        f1_scores.append(f1_score(result.answer, gold))
+
+    return {
+        "em": sum(em_scores) / len(em_scores) if em_scores else 0.0,
+        "f1": sum(f1_scores) / len(f1_scores) if f1_scores else 0.0,
+    }
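Because of `_normalize_answer`, the metrics are forgiving about articles, case, and punctuation. A worked example:

```python
from hag.metrics import exact_match, f1_score, retrieval_recall_at_k

# "the" and "." are stripped by _normalize_answer, so this is an exact match.
print(exact_match("The Eiffel Tower.", "eiffel tower"))  # 1.0

# Tokens: {"eiffel", "tower", "in", "paris"} vs {"eiffel", "tower"}
# precision = 2/4, recall = 2/2 -> F1 = 2 * 0.5 * 1.0 / 1.5 = 2/3
print(f1_score("Eiffel Tower in Paris", "the Eiffel Tower"))  # ~0.667

# 1 of 2 gold passages appears in the retrieved top-3.
print(retrieval_recall_at_k([4, 9, 2], gold_indices=[2, 7], k=3))  # 0.5
```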
diff --git a/hag/pipeline.py b/hag/pipeline.py
new file mode 100644
index 0000000..1fefb84
--- /dev/null
+++ b/hag/pipeline.py
@@ -0,0 +1,107 @@
+"""End-to-end RAG/HAG pipeline: query -> encode -> retrieve -> generate."""
+
+import logging
+from typing import List, Optional, Protocol, Union
+
+import numpy as np
+import torch
+
+from hag.config import PipelineConfig
+from hag.datatypes import PipelineResult, RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+from hag.retriever_faiss import FAISSRetriever
+from hag.retriever_hopfield import HopfieldRetriever
+
+logger = logging.getLogger(__name__)
+
+
+class EncoderProtocol(Protocol):
+    """Protocol for encoder interface."""
+
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor: ...
+
+
+class GeneratorProtocol(Protocol):
+    """Protocol for generator interface."""
+
+    def generate(self, question: str, passages: List[str]) -> str: ...
+
+
+class RAGPipeline:
+    """End-to-end pipeline: query -> encode -> retrieve -> generate.
+
+    Supports both FAISS (baseline) and Hopfield (ours) retrieval.
+    """
+
+    def __init__(
+        self,
+        config: PipelineConfig,
+        encoder: EncoderProtocol,
+        generator: GeneratorProtocol,
+        memory_bank: Optional[MemoryBank] = None,
+        faiss_retriever: Optional[FAISSRetriever] = None,
+    ) -> None:
+        self.config = config
+        self.encoder = encoder
+        self.generator = generator
+
+        if config.retriever_type == "faiss":
+            assert faiss_retriever is not None, "FAISSRetriever required for faiss mode"
+            self.retriever_type = "faiss"
+            self.faiss_retriever = faiss_retriever
+            self.hopfield_retriever: Optional[HopfieldRetriever] = None
+        elif config.retriever_type == "hopfield":
+            assert memory_bank is not None, "MemoryBank required for hopfield mode"
+            hopfield = HopfieldRetrieval(config.hopfield)
+            self.retriever_type = "hopfield"
+            self.hopfield_retriever = HopfieldRetriever(
+                hopfield, memory_bank, top_k=config.hopfield.top_k
+            )
+            self.faiss_retriever = None
+        else:
+            raise ValueError(f"Unknown retriever_type: {config.retriever_type}")
+
+    def run(self, question: str) -> PipelineResult:
+        """Run the full pipeline on a single question.
+
+        1. Encode question -> query embedding
+        2. Retrieve passages (FAISS or Hopfield)
+        3. Generate answer with LLM
+
+        Args:
+            question: input question string
+
+        Returns:
+            PipelineResult with answer and retrieval metadata.
+        """
+        # Encode
+        query_emb = self.encoder.encode(question)  # (1, d)
+
+        # Retrieve
+        if self.retriever_type == "hopfield":
+            retrieval_result = self.hopfield_retriever.retrieve(query_emb)
+        else:
+            query_np = query_emb.detach().numpy().astype(np.float32)
+            retrieval_result = self.faiss_retriever.retrieve(query_np)
+
+        # Generate
+        answer = self.generator.generate(question, retrieval_result.passages)
+
+        return PipelineResult(
+            question=question,
+            answer=answer,
+            retrieved_passages=retrieval_result.passages,
+            retrieval_result=retrieval_result,
+        )
+
+    def run_batch(self, questions: List[str]) -> List[PipelineResult]:
+        """Run pipeline on a batch of questions.
+
+        Args:
+            questions: list of question strings
+
+        Returns:
+            List of PipelineResult, one per question.
+        """
+        return [self.run(q) for q in questions]
diff --git a/hag/retriever_faiss.py b/hag/retriever_faiss.py
new file mode 100644
index 0000000..cd54a85
--- /dev/null
+++ b/hag/retriever_faiss.py
@@ -0,0 +1,73 @@
+"""Baseline FAISS top-k retriever for vanilla RAG."""
+
+import logging
+from typing import List, Optional
+
+import faiss
+import numpy as np
+import torch
+
+from hag.datatypes import RetrievalResult
+
+logger = logging.getLogger(__name__)
+
+
+class FAISSRetriever:
+    """Standard top-k retrieval using FAISS inner product search.
+
+    This is the baseline to compare against Hopfield retrieval.
+    """
+
+    def __init__(self, top_k: int = 5) -> None:
+        self.index: Optional[faiss.IndexFlatIP] = None
+        self.passages: List[str] = []
+        self.top_k = top_k
+
+    def build_index(self, embeddings: np.ndarray, passages: List[str]) -> None:
+        """Build FAISS IndexFlatIP from embeddings.
+
+        Args:
+            embeddings: (N, d) numpy array of passage embeddings
+            passages: list of N passage strings
+        """
+        assert embeddings.shape[0] == len(passages)
+        # FAISS requires contiguous float32
+        embeddings = np.ascontiguousarray(embeddings, dtype=np.float32)
+        d = embeddings.shape[1]
+        self.index = faiss.IndexFlatIP(d)
+        # Normalize for cosine similarity via inner product
+        faiss.normalize_L2(embeddings)
+        self.index.add(embeddings)
+        self.passages = list(passages)
+        logger.info("Built FAISS index with %d passages, dim=%d", len(passages), d)
+
+    def retrieve(self, query: np.ndarray) -> RetrievalResult:
+        """Retrieve top-k passages for a query.
+
+        Args:
+            query: (d,) or (batch, d) numpy array
+
+        Returns:
+            RetrievalResult with passages, scores, and indices.
+        """
+        assert self.index is not None, "Index not built. Call build_index first."
+
+        if query.ndim == 1:
+            query = query.reshape(1, -1)  # (1, d)
+
+        # Normalize query for cosine similarity
+        query_copy = query.copy()
+        faiss.normalize_L2(query_copy)
+
+        scores, indices = self.index.search(query_copy, self.top_k)  # (batch, k)
+
+        # Flatten for single query case
+        if scores.shape[0] == 1:
+            scores = scores[0]  # (k,)
+            indices = indices[0]  # (k,)
+
+        passages = [self.passages[i] for i in indices.flatten().tolist()]
+
+        return RetrievalResult(
+            passages=passages,
+            scores=torch.from_numpy(scores).float(),
+            indices=torch.from_numpy(indices).long(),
+        )
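Baseline usage, mirroring what `RAGPipeline` does in faiss mode, on synthetic float32 embeddings:

```python
import numpy as np

from hag.retriever_faiss import FAISSRetriever

rng = np.random.default_rng(0)
emb = rng.standard_normal((100, 32), dtype=np.float32)  # (N, d) corpus embeddings
passages = [f"passage {i}" for i in range(100)]

retriever = FAISSRetriever(top_k=3)
retriever.build_index(emb, passages)

result = retriever.retrieve(emb[7])  # query with a stored vector
print(result.indices[0].item())      # 7: a vector is its own nearest neighbor
print(result.scores)                 # cosine similarities, descending
```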
diff --git a/hag/retriever_hopfield.py b/hag/retriever_hopfield.py
new file mode 100644
index 0000000..1cb6968
--- /dev/null
+++ b/hag/retriever_hopfield.py
@@ -0,0 +1,77 @@
+"""Hopfield-based retriever wrapping HopfieldRetrieval + MemoryBank."""
+
+import logging
+
+import torch
+
+from hag.datatypes import RetrievalResult
+from hag.hopfield import HopfieldRetrieval
+from hag.memory_bank import MemoryBank
+
+logger = logging.getLogger(__name__)
+
+
+class HopfieldRetriever:
+    """Wraps HopfieldRetrieval + MemoryBank into a retriever interface.
+
+    The bridge between Hopfield's continuous retrieval and the discrete
+    passage selection needed for LLM prompting.
+    """
+
+    def __init__(
+        self,
+        hopfield: HopfieldRetrieval,
+        memory_bank: MemoryBank,
+        top_k: int = 5,
+    ) -> None:
+        self.hopfield = hopfield
+        self.memory_bank = memory_bank
+        self.top_k = top_k
+
+    def retrieve(
+        self,
+        query_embedding: torch.Tensor,
+        return_analysis: bool = False,
+    ) -> RetrievalResult:
+        """Retrieve top-k passages using iterative Hopfield retrieval.
+
+        1. Run Hopfield iterative retrieval -> get attention weights alpha_T
+        2. Take top_k indices from alpha_T
+        3. Look up corresponding passage texts from memory bank
+        4. Optionally return trajectory and energy for analysis
+
+        Args:
+            query_embedding: (d,) or (batch, d) — query embedding
+            return_analysis: if True, include full HopfieldResult
+
+        Returns:
+            RetrievalResult with passages, scores, indices, and optionally
+            the full hopfield_result.
+        """
+        hopfield_result = self.hopfield.retrieve(
+            query_embedding,
+            self.memory_bank.embeddings,
+            return_trajectory=return_analysis,
+            return_energy=return_analysis,
+        )
+
+        alpha = hopfield_result.attention_weights  # (batch, N) or (1, N)
+
+        # Get top-k indices and scores
+        k = min(self.top_k, alpha.shape[-1])
+        scores, indices = torch.topk(alpha, k, dim=-1)  # (batch, k)
+
+        # Flatten for single-query case
+        if scores.shape[0] == 1:
+            scores = scores.squeeze(0)  # (k,)
+            indices = indices.squeeze(0)  # (k,)
+
+        passages = self.memory_bank.get_passages_by_indices(indices)
+
+        return RetrievalResult(
+            passages=passages,
+            scores=scores,
+            indices=indices,
+            hopfield_result=hopfield_result if return_analysis else None,
+        )
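With `return_analysis=True` the retriever surfaces the convergence diagnostics that the energy utilities consume; a sketch on synthetic data:

```python
import torch

from hag.config import HopfieldConfig, MemoryBankConfig
from hag.energy import (
    compute_attention_entropy,
    compute_energy_curve,
    compute_energy_gap,
)
from hag.hopfield import HopfieldRetrieval
from hag.memory_bank import MemoryBank
from hag.retriever_hopfield import HopfieldRetriever

torch.manual_seed(0)
bank = MemoryBank(MemoryBankConfig(embedding_dim=32))
bank.build_from_embeddings(torch.randn(50, 32), [f"p{i}" for i in range(50)])

hop = HopfieldRetrieval(HopfieldConfig(beta=8.0))
retriever = HopfieldRetriever(hop, bank, top_k=3)
result = retriever.retrieve(torch.randn(32), return_analysis=True)

curve = compute_energy_curve(result.hopfield_result)
print(compute_energy_gap(curve))  # > 0 means iteration refined the query
print(compute_attention_entropy(result.hopfield_result.attention_weights))
print(result.passages)            # top-3 passage texts
```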
