diff options
| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-02-15 18:19:50 +0000 |
| commit | c90b48e3f8da9dd0f8d2ae82ddf977436bb0cfc3 (patch) | |
| tree | 43edac8013fec4e65a0b9cddec5314489b4aafc2 /hag/encoder.py | |
Core Hopfield retrieval module with energy-based convergence guarantees,
memory bank, FAISS baseline retriever, evaluation metrics, and end-to-end
pipeline. All 45 tests passing on CPU with synthetic data.
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
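This diff covers only `hag/encoder.py`; the Hopfield retrieval mentioned in the commit message lives elsewhere in the commit. For background, below is a minimal sketch of the modern (dense) Hopfield update that such a retriever typically iterates over a memory bank of passage embeddings. The function name, the `beta` parameter, and the stopping rule here are illustrative assumptions, not this module's actual API.

```python
import torch

def hopfield_retrieve(query: torch.Tensor, memory: torch.Tensor,
                      beta: float = 8.0, max_iters: int = 10,
                      tol: float = 1e-4) -> torch.Tensor:
    """Illustrative modern-Hopfield update: xi <- M^T softmax(beta * M @ xi).

    query:  (d,) state vector, e.g. a query embedding from the encoder
    memory: (N, d) stored patterns, e.g. passage embeddings in the memory bank

    Each step does not increase the associated energy, which is the basis for
    convergence arguments; in practice it settles in a few iterations.
    """
    xi = query
    for _ in range(max_iters):
        weights = torch.softmax(beta * (memory @ xi), dim=0)  # (N,) retrieval weights
        xi_next = memory.t() @ weights                        # (d,) updated state
        if torch.norm(xi_next - xi) < tol:                    # converged
            return xi_next
        xi = xi_next
    return xi
```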
Diffstat (limited to 'hag/encoder.py')
| -rw-r--r-- | hag/encoder.py | 88 |
1 file changed, 88 insertions, 0 deletions
```diff
diff --git a/hag/encoder.py b/hag/encoder.py
new file mode 100644
index 0000000..7e103f3
--- /dev/null
+++ b/hag/encoder.py
@@ -0,0 +1,88 @@
+"""Wrapper for encoding queries and passages into embeddings."""
+
+import logging
+from typing import List, Union
+
+import torch
+
+from hag.config import EncoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Encoder:
+    """Encodes text queries/passages into dense embeddings.
+
+    Uses a HuggingFace transformer model (e.g., Contriever).
+    For testing, use FakeEncoder instead.
+    """
+
+    def __init__(self, config: EncoderConfig) -> None:
+        self.config = config
+        self._tokenizer = None
+        self._model = None
+
+    def _load_model(self) -> None:
+        """Lazy-load the model and tokenizer."""
+        from transformers import AutoModel, AutoTokenizer
+
+        logger.info("Loading encoder model: %s", self.config.model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self._model = AutoModel.from_pretrained(self.config.model_name)
+        self._model.eval()
+
+    @torch.no_grad()
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+        """Encode text(s) into embedding(s).
+
+        Args:
+            texts: single string or list of strings
+
+        Returns:
+            (1, d) tensor for single input, (N, d) for list input.
+        """
+        if self._model is None:
+            self._load_model()
+
+        if isinstance(texts, str):
+            texts = [texts]
+
+        inputs = self._tokenizer(
+            texts,
+            max_length=self.config.max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+        outputs = self._model(**inputs)
+        # Mean pooling over token embeddings
+        embeddings = outputs.last_hidden_state.mean(dim=1)  # (N, d)
+        return embeddings
+
+
+class FakeEncoder:
+    """Deterministic hash-based encoder for testing. No model download needed."""
+
+    def __init__(self, dim: int = 64) -> None:
+        self.dim = dim
+
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+        """Produce deterministic embeddings based on text hash.
+
+        Args:
+            texts: single string or list of strings
+
+        Returns:
+            (1, d) or (N, d) normalized tensor.
+        """
+        if isinstance(texts, str):
+            texts = [texts]
+
+        embeddings = []
+        for text in texts:
+            torch.manual_seed(hash(text) % 2**32)
+            emb = torch.randn(1, self.dim)
+            embeddings.append(emb)
+
+        result = torch.cat(embeddings, dim=0)  # (N, d)
+        return torch.nn.functional.normalize(result, dim=-1)
```
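For context, a minimal usage sketch of the two classes added above. It assumes `hag/` is importable and that `hag.config.EncoderConfig` can be constructed with the `model_name` and `max_length` fields that `encoder.py` reads; the constructor signature and the Contriever model id shown are assumptions, not part of this commit.

```python
import torch

from hag.config import EncoderConfig  # fields assumed: model_name, max_length
from hag.encoder import Encoder, FakeEncoder

# Test path: deterministic within a process, no model download required.
fake = FakeEncoder(dim=64)
embs = fake.encode(["what is associative memory?", "hopfield networks store patterns"])
print(embs.shape)                        # torch.Size([2, 64])
print(torch.linalg.norm(embs, dim=-1))   # rows are L2-normalized (~1.0)

# Real path: lazily downloads a HuggingFace model on the first encode() call.
config = EncoderConfig(model_name="facebook/contriever", max_length=256)  # hypothetical values
encoder = Encoder(config)
query_emb = encoder.encode("what is associative memory?")  # shape (1, hidden_dim)
```

Two caveats visible in the code itself: `FakeEncoder` seeds `torch.manual_seed` with Python's built-in `hash()`, which is salted per interpreter run unless `PYTHONHASHSEED` is fixed, so its determinism holds within a run rather than across runs; and `Encoder.encode` mean-pools over all token positions, including padding, whereas a mask-weighted mean (as in the standard Contriever pooling) averages only non-padding tokens.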
