"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""
from typing import List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import Reranker
class BGEReranker(Reranker):
    """
    BGE Reranker using cross-encoder architecture.

    Much lighter than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16
    ):
        """
        Load the tokenizer and sequence-classification model.

        Args:
            model_path: HF hub id or local path of the reranker checkpoint.
            device_map: "auto", a HF device-map string, or an explicit
                device such as "cuda:1". Explicit "cuda:N" strings are not
                valid ``device_map`` values for ``from_pretrained``, so they
                are handled by loading first and moving the model afterwards.
            dtype: Torch dtype for the model weights.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        # Handle specific device assignment: "cuda:N" is not accepted as a
        # device_map by from_pretrained, so load without one and .to() it.
        if device_map and device_map.startswith("cuda:"):
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            # Record where the weights actually landed so inputs can follow.
            self.device = next(self.model.parameters()).device
        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        max_length: int = 512,
        **kwargs,
    ) -> List[float]:
        """
        Score documents using cross-encoder.

        Args:
            query: The query string
            docs: List of document strings to score
            batch_size: Batch size for processing
            max_length: Maximum token length per query-doc pair before
                truncation (was previously hard-coded to 512; the default
                preserves the old behavior)

        Returns:
            List of relevance scores (higher = more relevant), one per doc,
            in the same order as ``docs``
        """
        if not docs:
            return []
        # Cross-encoders consume (query, doc) pairs jointly.
        pairs = [[query, doc] for doc in docs]
        all_scores: List[float] = []
        with torch.no_grad():
            for start in range(0, len(pairs), batch_size):
                batch = pairs[start:start + batch_size]
                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=max_length,
                    return_tensors="pt"
                )
                # Move inputs to wherever the model weights live.
                inputs = {k: v.to(self.device) for k, v in inputs.items()}
                outputs = self.model(**inputs)
                scores = outputs.logits.squeeze(-1).float().cpu().tolist()
                # Defensive: if squeeze produced a 0-d tensor, tolist()
                # returns a bare float rather than a list.
                if isinstance(scores, float):
                    scores = [scores]
                all_scores.extend(scores)
        return all_scores