"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""

from typing import List

import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer

from .base import Reranker


class BGEReranker(Reranker):
    """
    BGE Reranker using cross-encoder architecture.

    Much lighter than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        """
        Load the tokenizer and the sequence-classification model.

        Args:
            model_path: Hugging Face model id or local path.
            device_map: Either an explicit "cuda:N" device, in which case the
                whole model is loaded unplaced and then moved onto that single
                device, or any value accepted by
                ``from_pretrained(device_map=...)`` (e.g. "auto").
            dtype: Weight dtype passed to ``from_pretrained``.
                NOTE(review): float16 weights placed on CPU may be slow or hit
                unsupported ops on some torch builds; consider float32 there.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        if device_map and device_map.startswith("cuda:"):
            # Pin the whole model to one specific GPU instead of letting
            # from_pretrained/accelerate decide placement.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            # Inputs are sent to the device of the first parameter
            # (presumably the embedding layer when the model is dispatched
            # across devices -- verify for unusual device maps).
            self.device = next(self.model.parameters()).device

        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        *,
        max_length: int = 512,
        **kwargs,
    ) -> List[float]:
        """
        Score documents against a query using the cross-encoder.

        Args:
            query: The query string.
            docs: List of document strings to score.
            batch_size: Number of (query, doc) pairs per forward pass; must be >= 1.
            max_length: Maximum token length of each tokenized (query, doc)
                pair; longer pairs are truncated. Defaults to 512.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            One raw relevance logit per document, in input order
            (higher = more relevant).

        Raises:
            ValueError: If ``batch_size`` is less than 1, or if the model does
                not produce exactly one logit per pair.
        """
        if not docs:
            return []
        if batch_size < 1:
            # A negative step would make range() empty and silently return no
            # scores; zero would fail inside range() with an opaque message.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # The tokenizer encodes each [query, doc] pair as one joint sequence.
        pairs = [[query, doc] for doc in docs]
        all_scores: List[float] = []

        with torch.no_grad():
            for start in range(0, len(pairs), batch_size):
                batch = pairs[start:start + batch_size]

                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                logits = self.model(**inputs).logits
                # Sequence-classification logits are (batch, num_labels); a
                # relevance score requires exactly one label per pair.
                if logits.dim() != 2 or logits.size(-1) != 1:
                    raise ValueError(
                        "expected reranker logits of shape (batch, 1), "
                        f"got {tuple(logits.shape)}"
                    )
                # reshape(-1) always yields a 1-D tensor, so tolist() returns a
                # flat list even for a single-pair batch.
                all_scores.extend(logits.reshape(-1).float().cpu().tolist())

        return all_scores