summaryrefslogtreecommitdiff
path: root/src/personalization/models/embedding/base.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
commitf918fc90b8d71d1287590b016d926268be573de0 (patch)
treed9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/embedding/base.py
parent680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff)
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for: - Qwen3/Nemotron embedding models - BGE/Qwen3 rerankers - vLLM/Llama/Qwen LLM backends - GPT-4o/LLM-based preference extractors Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/embedding/base.py')
-rw-r--r--src/personalization/models/embedding/base.py37
1 file changed, 37 insertions, 0 deletions
diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py
new file mode 100644
index 0000000..9f9d4d1
--- /dev/null
+++ b/src/personalization/models/embedding/base.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Sequence
+
+import torch
+
+
class EmbeddingModel(ABC):
    """Abstract interface for dense text-embedding backends.

    Subclasses implement :meth:`encode`; this base class defines only the
    contract, it holds no state and performs no computation itself.
    """

    @abstractmethod
    def encode(
        self,
        texts: Sequence[str],
        batch_size: int = 8,
        max_length: int = 512,
        normalize: bool = True,
        return_tensor: bool = False,
    ) -> List[List[float]] | torch.Tensor:
        """Encode a batch of texts into dense embeddings.

        Args:
            texts: Input strings to embed, one embedding per string.
            batch_size: Number of texts per forward pass.
            max_length: Token-length budget per text; presumably inputs are
                truncated to this by the concrete model — confirm per subclass.
            normalize: If True, embeddings are expected to be L2-normalized
                (see ``_maybe_normalize`` in this module) — verify in subclass.
            return_tensor: If True, return a ``torch.Tensor``; otherwise a
                plain ``List[List[float]]``.

        Returns:
            The embeddings for ``texts``, as a tensor or nested float lists
            depending on ``return_tensor``.
        """
        raise NotImplementedError
+
+
+def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+ # last_hidden_state: [batch, seq_len, hidden]
+ # attention_mask: [batch, seq_len]
+ mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1]
+ summed = (last_hidden_state * mask).sum(dim=1)
+ counts = mask.sum(dim=1).clamp_min(1e-6)
+ return summed / counts
+
+
+def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
+ if not normalize:
+ return x
+ return torch.nn.functional.normalize(x, p=2, dim=-1)
+
+