summaryrefslogtreecommitdiff
path: root/src/personalization/models/embedding/base.py
blob: 9f9d4d11c596c2e562c403f21935ababd9769359 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
from __future__ import annotations

from abc import ABC, abstractmethod
from typing import Iterable, List, Sequence

import torch


class EmbeddingModel(ABC):
    """Abstract interface for text-embedding backends.

    Concrete subclasses implement :meth:`encode`; the pooling and
    normalization helpers in this module (`_mean_pool`,
    `_maybe_normalize`) are available for implementations to use.
    """

    @abstractmethod
    def encode(
        self,
        texts: Sequence[str],
        batch_size: int = 8,
        max_length: int = 512,
        normalize: bool = True,
        return_tensor: bool = False,
    ) -> List[List[float]] | torch.Tensor:
        """Encode a batch of texts into dense embeddings.

        Args:
            texts: Input strings to embed, one embedding per string.
            batch_size: Number of texts per forward pass
                (presumably bounds peak memory — implementation-defined).
            max_length: Maximum token length per text; inputs beyond this
                are presumably truncated by the implementation — TODO confirm.
            normalize: If True, L2-normalize each embedding vector.
            return_tensor: If True, return a ``torch.Tensor``; otherwise a
                list of per-text float lists.

        Returns:
            The embeddings, in the same order as ``texts``.

        Raises:
            NotImplementedError: Always, on this abstract base.
        """
        raise NotImplementedError


def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    # last_hidden_state: [batch, seq_len, hidden]
    # attention_mask: [batch, seq_len]
    mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state)  # [b, s, 1]
    summed = (last_hidden_state * mask).sum(dim=1)
    counts = mask.sum(dim=1).clamp_min(1e-6)
    return summed / counts


def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
    if not normalize:
        return x
    return torch.nn.functional.normalize(x, p=2, dim=-1)