1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
|
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Iterable, List, Sequence
import torch
class EmbeddingModel(ABC):
@abstractmethod
def encode(
self,
texts: Sequence[str],
batch_size: int = 8,
max_length: int = 512,
normalize: bool = True,
return_tensor: bool = False,
) -> List[List[float]] | torch.Tensor:
"""Encode a batch of texts into dense embeddings."""
raise NotImplementedError
def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
# last_hidden_state: [batch, seq_len, hidden]
# attention_mask: [batch, seq_len]
mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1]
summed = (last_hidden_state * mask).sum(dim=1)
counts = mask.sum(dim=1).clamp_min(1e-6)
return summed / counts
def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
if not normalize:
return x
return torch.nn.functional.normalize(x, p=2, dim=-1)
|