Diffstat (limited to 'hag/encoder.py')
-rw-r--r--  hag/encoder.py  88
1 files changed, 88 insertions, 0 deletions
diff --git a/hag/encoder.py b/hag/encoder.py
new file mode 100644
index 0000000..7e103f3
--- /dev/null
+++ b/hag/encoder.py
@@ -0,0 +1,88 @@
+"""Wrapper for encoding queries and passages into embeddings."""
+
+import hashlib
+import logging
+from typing import List, Union
+
+import torch
+
+from hag.config import EncoderConfig
+
+logger = logging.getLogger(__name__)
+
+
+class Encoder:
+ """Encodes text queries/passages into dense embeddings.
+
+ Uses a HuggingFace transformer model (e.g., Contriever).
+ For testing, use FakeEncoder instead.
+ """
+
+    def __init__(self, config: EncoderConfig) -> None:
+        self.config = config
+        self._tokenizer = None
+        self._model = None
+
+    def _load_model(self) -> None:
+        """Lazy-load the model and tokenizer."""
+        # Imported here rather than at module level so that transformers is
+        # only required when a real model is used (tests can use FakeEncoder).
+        from transformers import AutoModel, AutoTokenizer
+
+        logger.info("Loading encoder model: %s", self.config.model_name)
+        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
+        self._model = AutoModel.from_pretrained(self.config.model_name)
+        self._model.eval()
+
+    @torch.no_grad()
+    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+        """Encode text(s) into embedding(s).
+
+        Args:
+            texts: single string or list of strings
+
+        Returns:
+            (1, d) tensor for single input, (N, d) for list input.
+        """
+        if self._model is None:
+            self._load_model()
+
+        if isinstance(texts, str):
+            texts = [texts]
+
+        inputs = self._tokenizer(
+            texts,
+            max_length=self.config.max_length,
+            padding=True,
+            truncation=True,
+            return_tensors="pt",
+        )
+        outputs = self._model(**inputs)
+        # Masked mean pooling: exclude padding tokens from the average,
+        # otherwise batches with mixed lengths get skewed embeddings.
+        mask = inputs["attention_mask"].unsqueeze(-1).float()  # (N, L, 1)
+        summed = (outputs.last_hidden_state * mask).sum(dim=1)  # (N, d)
+        embeddings = summed / mask.sum(dim=1).clamp(min=1e-9)
+        return embeddings
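+
+# Minimal usage sketch (assumes a Contriever-style checkpoint such as
+# "facebook/contriever" and that EncoderConfig accepts model_name/max_length
+# as keyword arguments; dot-product scoring here is illustrative only):
+#
+#     config = EncoderConfig(model_name="facebook/contriever", max_length=256)
+#     encoder = Encoder(config)
+#     query = encoder.encode("what is dense retrieval?")         # (1, d)
+#     passages = encoder.encode(["passage one", "passage two"])  # (2, d)
+#     scores = query @ passages.T                                # (1, 2)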
+
+
+class FakeEncoder:
+ """Deterministic hash-based encoder for testing. No model download needed."""
+
+ def __init__(self, dim: int = 64) -> None:
+ self.dim = dim
+
+ def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
+ """Produce deterministic embeddings based on text hash.
+
+ Args:
+ texts: single string or list of strings
+
+ Returns:
+ (1, d) or (N, d) normalized tensor.
+ """
+ if isinstance(texts, str):
+ texts = [texts]
+
+        embeddings = []
+        for text in texts:
+            # Derive a stable seed from the text. Python's built-in hash() is
+            # salted per process (PYTHONHASHSEED), so it is not reproducible
+            # across runs; sha256 is. A local Generator also avoids clobbering
+            # the global RNG state.
+            digest = hashlib.sha256(text.encode("utf-8")).digest()
+            seed = int.from_bytes(digest[:4], "big")
+            generator = torch.Generator().manual_seed(seed)
+            embeddings.append(torch.randn(1, self.dim, generator=generator))
+
+        result = torch.cat(embeddings, dim=0)  # (N, d)
+        return torch.nn.functional.normalize(result, dim=-1)
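+
+# Illustrative test usage (a sketch, not an actual test in this repo):
+#
+#     enc = FakeEncoder(dim=64)
+#     a = enc.encode("hello")
+#     b = enc.encode("hello")
+#     assert torch.allclose(a, b)              # same text -> same embedding
+#     batch = enc.encode(["a", "b"])           # (2, 64), rows unit-norm
+#     assert torch.allclose(batch.norm(dim=-1), torch.ones(2))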