diff options
Diffstat (limited to 'src/personalization/models/embedding')
| -rw-r--r-- | src/personalization/models/embedding/__init__.py | 11 | ||||
| -rw-r--r-- | src/personalization/models/embedding/base.py | 37 | ||||
| -rw-r--r-- | src/personalization/models/embedding/nemotron_8b.py | 63 | ||||
| -rw-r--r-- | src/personalization/models/embedding/qwen3_8b.py | 89 |
4 files changed, 200 insertions, 0 deletions
diff --git a/src/personalization/models/embedding/__init__.py b/src/personalization/models/embedding/__init__.py new file mode 100644 index 0000000..05221aa --- /dev/null +++ b/src/personalization/models/embedding/__init__.py @@ -0,0 +1,11 @@ +from .base import EmbeddingModel +from .qwen3_8b import Qwen3Embedding8B +from .nemotron_8b import LlamaEmbedNemotron8B + +__all__ = [ + "EmbeddingModel", + "Qwen3Embedding8B", + "LlamaEmbedNemotron8B", +] + + diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py new file mode 100644 index 0000000..9f9d4d1 --- /dev/null +++ b/src/personalization/models/embedding/base.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from typing import Iterable, List, Sequence + +import torch + + +class EmbeddingModel(ABC): + @abstractmethod + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + """Encode a batch of texts into dense embeddings.""" + raise NotImplementedError + + +def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor: + # last_hidden_state: [batch, seq_len, hidden] + # attention_mask: [batch, seq_len] + mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1] + summed = (last_hidden_state * mask).sum(dim=1) + counts = mask.sum(dim=1).clamp_min(1e-6) + return summed / counts + + +def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor: + if not normalize: + return x + return torch.nn.functional.normalize(x, p=2, dim=-1) + + diff --git a/src/personalization/models/embedding/nemotron_8b.py b/src/personalization/models/embedding/nemotron_8b.py new file mode 100644 index 0000000..6348aee --- /dev/null +++ b/src/personalization/models/embedding/nemotron_8b.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import List, Sequence + +import torch +from transformers import AutoModel, AutoTokenizer + +from personalization.config.registry import choose_dtype, choose_device_map +from personalization.config.settings import LocalModelsConfig +from .base import EmbeddingModel, _mean_pool, _maybe_normalize + + +class LlamaEmbedNemotron8B(EmbeddingModel): + def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=True + ) + self.model = AutoModel.from_pretrained( + model_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "LlamaEmbedNemotron8B": + if not cfg.embedding or not cfg.embedding.nemotron: + raise ValueError("Embedding config for nemotron is missing") + spec = cfg.embedding.nemotron + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls(spec.local_path, dtype=dtype, device_map=device_map) + + @torch.inference_mode() + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + device = next(self.model.parameters()).device + outputs: List[torch.Tensor] = [] + for i in range(0, len(texts), batch_size): + batch = list(texts[i : i + batch_size]) + enc = self.tokenizer( + batch, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ).to(device) + model_out = self.model(**enc, output_hidden_states=False, return_dict=True) + pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined] + pooled = _maybe_normalize(pooled, normalize) + outputs.append(pooled) + emb = torch.cat(outputs, dim=0) + if return_tensor: + return emb + return emb.cpu().to(torch.float32).tolist() + + diff --git a/src/personalization/models/embedding/qwen3_8b.py b/src/personalization/models/embedding/qwen3_8b.py new file mode 100644 index 0000000..fb02e67 --- /dev/null +++ b/src/personalization/models/embedding/qwen3_8b.py @@ -0,0 +1,89 @@ +from __future__ import annotations + +from typing import List, Sequence + +import torch +from transformers import AutoModel, AutoTokenizer + +from personalization.config.registry import choose_dtype, choose_device_map +from personalization.config.settings import LocalModelsConfig +from .base import EmbeddingModel, _mean_pool, _maybe_normalize + + +class Qwen3Embedding8B(EmbeddingModel): + def __init__( + self, + model_path: str, + dtype: torch.dtype, + device_map: str = "auto", + trust_remote_code: bool = True, + ) -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=trust_remote_code + ) + + # Handle specific device assignment (e.g., "cuda:0", "cuda:1") + if device_map and device_map.startswith("cuda:"): + # Load to CPU first, then move to specific GPU + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=None, # Don't use accelerate's device_map + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(device_map) + else: + # Use accelerate's auto device mapping + self.model = AutoModel.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=device_map, + trust_remote_code=trust_remote_code, + low_cpu_mem_usage=True, + ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B": + if not cfg.embedding or not cfg.embedding.qwen3: + raise ValueError("Embedding config for qwen3 is missing") + spec = cfg.embedding.qwen3 + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls( + spec.local_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + + @torch.inference_mode() + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + device = next(self.model.parameters()).device + outputs: List[torch.Tensor] = [] + for i in range(0, len(texts), batch_size): + batch = list(texts[i : i + batch_size]) + enc = self.tokenizer( + batch, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ).to(device) + model_out = self.model(**enc, output_hidden_states=False, return_dict=True) + pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined] + pooled = _maybe_normalize(pooled, normalize) + outputs.append(pooled) + emb = torch.cat(outputs, dim=0) + if return_tensor: + return emb + return emb.cpu().to(torch.float32).tolist() + + |
