"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""

from typing import List

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .base import Reranker


class BGEReranker(Reranker):
    """
    BGE Reranker using cross-encoder architecture.

    Much lighter than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Load the tokenizer and sequence-classification model.

        Args:
            model_path: HF hub id or local path of the BGE reranker checkpoint.
            device_map: Either an accelerate-style map (e.g. "auto") or an
                explicit device string such as "cuda:0".
            dtype: Torch dtype used to load the model weights.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        if device_map and device_map.startswith("cuda:"):
            # Explicit single-GPU placement: load on CPU first (device_map=None)
            # then move the whole model to the requested device.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            # Let accelerate decide placement; record where the params ended up
            # so inputs can be moved to the same device in score().
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            self.device = next(self.model.parameters()).device

        # Inference only — disable dropout etc.
        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        **kwargs,
    ) -> List[float]:
        """
        Score documents against the query using the cross-encoder.

        Args:
            query: The query string.
            docs: List of document strings to score.
            batch_size: Number of (query, doc) pairs scored per forward pass.
            **kwargs: Ignored; accepted for interface compatibility.

        Returns:
            List of relevance scores, one per document, in input order
            (higher = more relevant). Empty input returns an empty list.
        """
        if not docs:
            return []

        # Cross-encoder input: one (query, doc) pair per candidate document.
        pairs = [[query, doc] for doc in docs]

        all_scores: List[float] = []

        with torch.no_grad():
            for i in range(0, len(pairs), batch_size):
                batch = pairs[i:i + batch_size]

                # 512 tokens matches the BGE reranker context window.
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                outputs = self.model(**inputs)
                # reshape(-1) flattens (batch, 1) logits to (batch,), so
                # tolist() always yields a list — no single-item special case.
                scores = outputs.logits.reshape(-1).float().cpu().tolist()

                all_scores.extend(scores)

        return all_scores