summaryrefslogtreecommitdiff
path: root/hag/encoder.py
blob: 7e103f3bbb8a6b213c64a7249dfdfe501a190630 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
"""Wrapper for encoding queries and passages into embeddings."""

import hashlib
import logging
from typing import List, Union

import torch

from hag.config import EncoderConfig

logger = logging.getLogger(__name__)


class Encoder:
    """Encodes text queries/passages into dense embeddings.

    Uses a HuggingFace transformer model (e.g., Contriever), loaded lazily on
    the first call to :meth:`encode`. Each sentence embedding is the mean of
    its token embeddings taken over non-padding positions only.
    For testing, use FakeEncoder instead.
    """

    def __init__(self, config: "EncoderConfig") -> None:
        # String annotation: resolved lazily, so defining this class does not
        # evaluate the EncoderConfig name.
        self.config = config
        self._tokenizer = None  # populated by _load_model()
        self._model = None  # populated by _load_model()

    def _load_model(self) -> None:
        """Lazy-load the model and tokenizer named by ``config.model_name``.

        ``transformers`` is imported here, so importing this module (for
        example, to use FakeEncoder) does not import it.
        """
        from transformers import AutoModel, AutoTokenizer

        logger.info("Loading encoder model: %s", self.config.model_name)
        self._tokenizer = AutoTokenizer.from_pretrained(self.config.model_name)
        self._model = AutoModel.from_pretrained(self.config.model_name)
        # Inference mode (e.g. disables dropout).
        self._model.eval()

    @torch.no_grad()
    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Encode text(s) into embedding(s).

        Args:
            texts: single string or list of strings

        Returns:
            (1, d) tensor for single input, (N, d) for list input. Each row is
            the mean of that text's token embeddings over the positions the
            tokenizer's ``attention_mask`` marks as real (non-padding) tokens.
        """
        if self._model is None:
            self._load_model()

        if isinstance(texts, str):
            texts = [texts]

        inputs = self._tokenizer(
            texts,
            max_length=self.config.max_length,
            padding=True,
            truncation=True,
            return_tensors="pt",
        )
        outputs = self._model(**inputs)
        token_embeddings = outputs.last_hidden_state  # (N, T, d)

        # Mean pooling over real tokens only. With padding=True, shorter texts
        # are padded up to the longest one in the batch. A plain mean over
        # dim=1 would average those pad positions in, so a text's embedding
        # would change depending on what it is batched with.
        mask = inputs.get("attention_mask")
        if mask is None:
            # No mask supplied by the tokenizer: fall back to a plain mean.
            return token_embeddings.mean(dim=1)  # (N, d)
        mask = mask.unsqueeze(-1).to(token_embeddings.dtype)  # (N, T, 1)
        summed = (token_embeddings * mask).sum(dim=1)  # (N, d)
        # Guard against division by zero for a row with no real tokens.
        counts = mask.sum(dim=1).clamp(min=1e-9)  # (N, 1)
        return summed / counts


class FakeEncoder:
    """Deterministic hash-based encoder for testing. No model download needed.

    Each text maps to a fixed pseudo-random unit vector. The seed comes from a
    SHA-256 digest of the text's UTF-8 bytes, so results are reproducible
    across Python processes. The built-in ``hash`` is salted per process for
    ``str`` and would not be.
    """

    def __init__(self, dim: int = 64) -> None:
        self.dim = dim

    def encode(self, texts: Union[str, List[str]]) -> torch.Tensor:
        """Produce deterministic embeddings based on text hash.

        Uses a private ``torch.Generator``, so the global torch RNG state is
        left untouched.

        Args:
            texts: single string or list of strings

        Returns:
            (1, d) or (N, d) L2-normalized tensor; (0, d) for an empty list.
        """
        if isinstance(texts, str):
            texts = [texts]

        if not texts:
            # torch.cat([]) would raise; return an empty batch instead.
            return torch.empty(0, self.dim)

        generator = torch.Generator()
        embeddings = []
        for text in texts:
            digest = hashlib.sha256(text.encode("utf-8")).digest()
            generator.manual_seed(int.from_bytes(digest[:8], "little"))
            embeddings.append(torch.randn(1, self.dim, generator=generator))

        result = torch.cat(embeddings, dim=0)  # (N, d)
        return torch.nn.functional.normalize(result, dim=-1)