summary refs log tree commit diff
path: root/src/personalization/models/embedding
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/embedding')
-rw-r--r--src/personalization/models/embedding/__init__.py11
-rw-r--r--src/personalization/models/embedding/base.py37
-rw-r--r--src/personalization/models/embedding/nemotron_8b.py63
-rw-r--r--src/personalization/models/embedding/qwen3_8b.py89
4 files changed, 200 insertions, 0 deletions
diff --git a/src/personalization/models/embedding/__init__.py b/src/personalization/models/embedding/__init__.py
new file mode 100644
index 0000000..05221aa
--- /dev/null
+++ b/src/personalization/models/embedding/__init__.py
@@ -0,0 +1,11 @@
+from .base import EmbeddingModel
+from .qwen3_8b import Qwen3Embedding8B
+from .nemotron_8b import LlamaEmbedNemotron8B
+
+__all__ = [
+ "EmbeddingModel",
+ "Qwen3Embedding8B",
+ "LlamaEmbedNemotron8B",
+]
+
+
diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py
new file mode 100644
index 0000000..9f9d4d1
--- /dev/null
+++ b/src/personalization/models/embedding/base.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Sequence
+
+import torch
+
+
+class EmbeddingModel(ABC):
+ @abstractmethod
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ """Encode a batch of texts into dense embeddings."""
+ raise NotImplementedError
+
+
+def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+ # last_hidden_state: [batch, seq_len, hidden]
+ # attention_mask: [batch, seq_len]
+ mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1]
+ summed = (last_hidden_state * mask).sum(dim=1)
+ counts = mask.sum(dim=1).clamp_min(1e-6)
+ return summed / counts
+
+
+def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
+ if not normalize:
+ return x
+ return torch.nn.functional.normalize(x, p=2, dim=-1)
+
+
diff --git a/src/personalization/models/embedding/nemotron_8b.py b/src/personalization/models/embedding/nemotron_8b.py
new file mode 100644
index 0000000..6348aee
--- /dev/null
+++ b/src/personalization/models/embedding/nemotron_8b.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import EmbeddingModel, _mean_pool, _maybe_normalize
+
+
+class LlamaEmbedNemotron8B(EmbeddingModel):
+ def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=True
+ )
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "LlamaEmbedNemotron8B":
+ if not cfg.embedding or not cfg.embedding.nemotron:
+ raise ValueError("Embedding config for nemotron is missing")
+ spec = cfg.embedding.nemotron
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+ @torch.inference_mode()
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ device = next(self.model.parameters()).device
+ outputs: List[torch.Tensor] = []
+ for i in range(0, len(texts), batch_size):
+ batch = list(texts[i : i + batch_size])
+ enc = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ ).to(device)
+ model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
+ pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined]
+ pooled = _maybe_normalize(pooled, normalize)
+ outputs.append(pooled)
+ emb = torch.cat(outputs, dim=0)
+ if return_tensor:
+ return emb
+ return emb.cpu().to(torch.float32).tolist()
+
+
diff --git a/src/personalization/models/embedding/qwen3_8b.py b/src/personalization/models/embedding/qwen3_8b.py
new file mode 100644
index 0000000..fb02e67
--- /dev/null
+++ b/src/personalization/models/embedding/qwen3_8b.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import EmbeddingModel, _mean_pool, _maybe_normalize
+
+
+class Qwen3Embedding8B(EmbeddingModel):
+ def __init__(
+ self,
+ model_path: str,
+ dtype: torch.dtype,
+ device_map: str = "auto",
+ trust_remote_code: bool = True,
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
+ if device_map and device_map.startswith("cuda:"):
+ # Load to CPU first, then move to specific GPU
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=None, # Don't use accelerate's device_map
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ self.model = self.model.to(device_map)
+ else:
+ # Use accelerate's auto device mapping
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B":
+ if not cfg.embedding or not cfg.embedding.qwen3:
+ raise ValueError("Embedding config for qwen3 is missing")
+ spec = cfg.embedding.qwen3
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(
+ spec.local_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ @torch.inference_mode()
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ device = next(self.model.parameters()).device
+ outputs: List[torch.Tensor] = []
+ for i in range(0, len(texts), batch_size):
+ batch = list(texts[i : i + batch_size])
+ enc = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ ).to(device)
+ model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
+ pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined]
+ pooled = _maybe_normalize(pooled, normalize)
+ outputs.append(pooled)
+ emb = torch.cat(outputs, dim=0)
+ if return_tensor:
+ return emb
+ return emb.cpu().to(torch.float32).tolist()
+
+