summaryrefslogtreecommitdiff
path: root/src/personalization/models/reranker/bge_reranker.py
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/reranker/bge_reranker.py')
-rw-r--r-- src/personalization/models/reranker/bge_reranker.py | 95
1 file changed, 95 insertions, 0 deletions
diff --git a/src/personalization/models/reranker/bge_reranker.py b/src/personalization/models/reranker/bge_reranker.py
new file mode 100644
index 0000000..a672f0a
--- /dev/null
+++ b/src/personalization/models/reranker/bge_reranker.py
@@ -0,0 +1,95 @@
+"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""
+
+from typing import List
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from .base import Reranker
+
+
class BGEReranker(Reranker):
    """
    Cross-encoder reranker built on the BGE model family.

    Considerably lighter than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # An explicit "cuda:N" target is applied manually via .to(); any other
        # value (e.g. "auto") is delegated to transformers' device_map logic.
        wants_explicit_cuda = bool(device_map) and device_map.startswith("cuda:")
        if wants_explicit_cuda:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            # Resolve the device actually chosen by the placement machinery.
            self.device = next(self.model.parameters()).device

        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        **kwargs,
    ) -> List[float]:
        """
        Compute cross-encoder relevance scores for each document.

        Args:
            query: The query string
            docs: List of document strings to score
            batch_size: Batch size for processing

        Returns:
            List of relevance scores (higher = more relevant)
        """
        if not docs:
            return []

        # Each document is scored jointly with the query as a text pair.
        pairs = [[query, text] for text in docs]
        results: List[float] = []

        with torch.no_grad():
            for start in range(0, len(pairs), batch_size):
                chunk = pairs[start:start + batch_size]

                encoded = self.tokenizer(
                    chunk,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                )
                encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

                logits = self.model(**encoded).logits
                chunk_scores = logits.squeeze(-1).float().cpu().tolist()

                # tolist() yields a bare float for a 0-d tensor; normalize to a list.
                if isinstance(chunk_scores, float):
                    chunk_scores = [chunk_scores]

                results.extend(chunk_scores)

        return results