"""Wrapper for encoding queries and passages into embeddings."""

import logging
import zlib
from typing import List, Union

import torch

from hag.config import EncoderConfig

logger = logging.getLogger(__name__)


class Encoder:
    """Encodes text queries/passages into dense embeddings.

    Uses a HuggingFace transformer model (e.g., Contriever).
    For testing, use FakeEncoder instead.
    """

    def __init__(self, config: EncoderConfig) -> None:
        self.config = config
        # Lazily initialized on first encode() call so constructing an
        # Encoder never triggers a model download/load.
        self._tokenizer = None
        self._model = None

    def _load_model(self) -> None:
        """Lazy-load the model and tokenizer."""
        # Imported here (not at module top) so importing this module does
        # not require transformers unless a real Encoder is actually used.
        from transformers import AutoModel, AutoTokenizer

        logger.info("Loading encoder model: %s", self.config.model_name)
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self._model = AutoModel.from_pretrained(self.config.model_name)
        self._model.eval()

    @torch.no_grad()
    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Encode text(s) into embedding(s).

        Args:
            texts: single string or list of strings

        Returns:
            (1, d) tensor for single input, (N, d) for list input.
        """
        if self._model is None:
            self._load_model()
        if isinstance(texts, str):
            texts = [texts]
        inputs = self._tokenizer(
            texts,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        outputs = self._model(**inputs)
        # Masked mean pooling: padding tokens must not contribute to the
        # average. A plain mean(dim=1) dilutes short texts in a padded
        # batch toward the pad-token embedding; this masked form matches
        # the pooling recipe published for Contriever.
        hidden = outputs.last_hidden_state  # (N, T, d)
        mask = inputs["attention_mask"].unsqueeze(-1).type_as(hidden)  # (N, T, 1)
        summed = (hidden * mask).sum(dim=1)  # (N, d)
        counts = mask.sum(dim=1).clamp(min=1e-9)  # guard against all-pad rows
        return summed / counts


class FakeEncoder:
    """Deterministic hash-based encoder for testing. No model download needed."""

    def __init__(self, dim: int = 64) -> None:
        self.dim = dim

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Produce deterministic embeddings based on text hash.

        Args:
            texts: single string or list of strings

        Returns:
            (1, d) or (N, d) normalized tensor.
        """
        if isinstance(texts, str):
            texts = [texts]
        # A private Generator keeps this side-effect free: the global
        # torch RNG state used by callers is never touched.
        gen = torch.Generator()
        embeddings = []
        for text in texts:
            # zlib.crc32 is stable across interpreter runs, unlike the
            # builtin hash(), which is salted by PYTHONHASHSEED — seeding
            # from hash() made "deterministic" embeddings change between
            # processes.
            gen.manual_seed(zlib.crc32(text.encode("utf-8")))
            embeddings.append(torch.randn(1, self.dim, generator=gen))
        result = torch.cat(embeddings, dim=0)  # (N, d)
        return torch.nn.functional.normalize(result, dim=-1)