diff options
Diffstat (limited to 'src/personalization/models/embedding/nemotron_8b.py')
| -rw-r--r-- | src/personalization/models/embedding/nemotron_8b.py | 63 |
1 file changed, 63 insertions, 0 deletions
diff --git a/src/personalization/models/embedding/nemotron_8b.py b/src/personalization/models/embedding/nemotron_8b.py new file mode 100644 index 0000000..6348aee --- /dev/null +++ b/src/personalization/models/embedding/nemotron_8b.py @@ -0,0 +1,63 @@ +from __future__ import annotations + +from typing import List, Sequence + +import torch +from transformers import AutoModel, AutoTokenizer + +from personalization.config.registry import choose_dtype, choose_device_map +from personalization.config.settings import LocalModelsConfig +from .base import EmbeddingModel, _mean_pool, _maybe_normalize + + +class LlamaEmbedNemotron8B(EmbeddingModel): + def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None: + self.tokenizer = AutoTokenizer.from_pretrained( + model_path, use_fast=True, trust_remote_code=True + ) + self.model = AutoModel.from_pretrained( + model_path, + dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "LlamaEmbedNemotron8B": + if not cfg.embedding or not cfg.embedding.nemotron: + raise ValueError("Embedding config for nemotron is missing") + spec = cfg.embedding.nemotron + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls(spec.local_path, dtype=dtype, device_map=device_map) + + @torch.inference_mode() + def encode( + self, + texts: Sequence[str], + batch_size: int = 8, + max_length: int = 512, + normalize: bool = True, + return_tensor: bool = False, + ) -> List[List[float]] | torch.Tensor: + device = next(self.model.parameters()).device + outputs: List[torch.Tensor] = [] + for i in range(0, len(texts), batch_size): + batch = list(texts[i : i + batch_size]) + enc = self.tokenizer( + batch, + padding=True, + truncation=True, + max_length=max_length, + return_tensors="pt", + ).to(device) + model_out = self.model(**enc, output_hidden_states=False, return_dict=True) + pooled = 
_mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined] + pooled = _maybe_normalize(pooled, normalize) + outputs.append(pooled) + emb = torch.cat(outputs, dim=0) + if return_tensor: + return emb + return emb.cpu().to(torch.float32).tolist() + + |
