summaryrefslogtreecommitdiff
path: root/src/personalization/models/embedding/nemotron_8b.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/embedding/nemotron_8b.py')
-rw-r--r--src/personalization/models/embedding/nemotron_8b.py63
1 files changed, 63 insertions, 0 deletions
diff --git a/src/personalization/models/embedding/nemotron_8b.py b/src/personalization/models/embedding/nemotron_8b.py
new file mode 100644
index 0000000..6348aee
--- /dev/null
+++ b/src/personalization/models/embedding/nemotron_8b.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import EmbeddingModel, _mean_pool, _maybe_normalize
+
+
+class LlamaEmbedNemotron8B(EmbeddingModel):
+ def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=True
+ )
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "LlamaEmbedNemotron8B":
+ if not cfg.embedding or not cfg.embedding.nemotron:
+ raise ValueError("Embedding config for nemotron is missing")
+ spec = cfg.embedding.nemotron
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+ @torch.inference_mode()
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ device = next(self.model.parameters()).device
+ outputs: List[torch.Tensor] = []
+ for i in range(0, len(texts), batch_size):
+ batch = list(texts[i : i + batch_size])
+ enc = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ ).to(device)
+ model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
+ pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined]
+ pooled = _maybe_normalize(pooled, normalize)
+ outputs.append(pooled)
+ emb = torch.cat(outputs, dim=0)
+ if return_tensor:
+ return emb
+ return emb.cpu().to(torch.float32).tolist()
+
+