| author | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 15:43:42 -0600 |
|---|---|---|
| committer | YurenHao0426 <blackhao0426@gmail.com> | 2026-01-27 15:43:42 -0600 |
| commit | f918fc90b8d71d1287590b016d926268be573de0 (patch) | |
| tree | d9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/reranker/bge_reranker.py | |
| parent | 680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff) | |
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for:
- Qwen3/Nemotron embedding models
- BGE/Qwen3 rerankers
- vLLM/Llama/Qwen LLM backends
- GPT-4o/LLM-based preference extractors
Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/reranker/bge_reranker.py')
| -rw-r--r-- | src/personalization/models/reranker/bge_reranker.py | 95 |
1 file changed, 95 insertions, 0 deletions
diff --git a/src/personalization/models/reranker/bge_reranker.py b/src/personalization/models/reranker/bge_reranker.py
new file mode 100644
index 0000000..a672f0a
--- /dev/null
+++ b/src/personalization/models/reranker/bge_reranker.py
@@ -0,0 +1,95 @@
+"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""
+
+from typing import List
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from .base import Reranker
+
+
+class BGEReranker(Reranker):
+    """
+    BGE Reranker using cross-encoder architecture.
+
+    Much lighter than Qwen3-Reranker-8B:
+    - bge-reranker-base: 278M params
+    - bge-reranker-large: 560M params
+    """
+
+    def __init__(
+        self,
+        model_path: str = "BAAI/bge-reranker-base",
+        device_map: str = "auto",
+        dtype: torch.dtype = torch.float16
+    ):
+        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+        # Handle specific device assignment
+        if device_map and device_map.startswith("cuda:"):
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=dtype,
+                device_map=None,
+            )
+            self.model = self.model.to(device_map)
+            self.device = device_map
+        else:
+            self.model = AutoModelForSequenceClassification.from_pretrained(
+                model_path,
+                torch_dtype=dtype,
+                device_map=device_map,
+            )
+            self.device = next(self.model.parameters()).device
+
+        self.model.eval()
+
+    def score(
+        self,
+        query: str,
+        docs: List[str],
+        batch_size: int = 32,
+        **kwargs,
+    ) -> List[float]:
+        """
+        Score documents using cross-encoder.
+
+        Args:
+            query: The query string
+            docs: List of document strings to score
+            batch_size: Batch size for processing
+
+        Returns:
+            List of relevance scores (higher = more relevant)
+        """
+        if not docs:
+            return []
+
+        # Create query-doc pairs
+        pairs = [[query, doc] for doc in docs]
+
+        all_scores = []
+
+        with torch.no_grad():
+            for i in range(0, len(pairs), batch_size):
+                batch = pairs[i:i + batch_size]
+
+                # Tokenize
+                inputs = self.tokenizer(
+                    batch,
+                    padding=True,
+                    truncation=True,
+                    max_length=512,
+                    return_tensors="pt"
+                )
+                inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+                # Get scores
+                outputs = self.model(**inputs)
+                scores = outputs.logits.squeeze(-1).float().cpu().tolist()
+
+                # Handle single item case
+                if isinstance(scores, float):
+                    scores = [scores]
+
+                all_scores.extend(scores)
+
+        return all_scores
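For reference, a minimal usage sketch of the wrapper added above. It assumes the package is importable as `personalization.models.reranker.bge_reranker` (matching the `src/` layout) and that the `.base.Reranker` interface only requires `score`; the query and documents are made-up illustration data, not part of the commit:

```python
from personalization.models.reranker.bge_reranker import BGEReranker

# Load the lightweight 278M-parameter base model onto a single GPU.
reranker = BGEReranker(model_path="BAAI/bge-reranker-base", device_map="cuda:0")

# Hypothetical query and candidate documents, for illustration only.
query = "What is a cross-encoder reranker?"
docs = [
    "Cross-encoders score a query and a document jointly in one forward pass.",
    "A guide to watering houseplants in winter.",
]

scores = reranker.score(query, docs, batch_size=32)

# Higher score = more relevant; rank documents by descending score.
for doc, score in sorted(zip(docs, scores), key=lambda p: p[1], reverse=True):
    print(f"{score:+.3f}  {doc}")
```

Note that the returned values are raw logits from the classification head, not probabilities; apply `torch.sigmoid` if values in [0, 1] are needed. Because each call re-tokenizes every query-document pair, this wrapper trades throughput for simplicity relative to the heavier Qwen3-Reranker-8B mentioned in the class docstring.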
