summaryrefslogtreecommitdiff
path: root/src/personalization/models/reranker/bge_reranker.py
blob: a672f0a7079479669f75f17d6c5840a6a0d3cc6f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""

from typing import List
import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
from .base import Reranker


class BGEReranker(Reranker):
    """
    BGE Reranker using cross-encoder architecture.

    Each (query, document) pair is tokenized jointly and passed through a
    sequence-classification model; the resulting logit is the relevance score.

    Much lighter than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
        max_length: int = 512,
    ):
        """
        Load the tokenizer and the cross-encoder model and put it in eval mode.

        Args:
            model_path: Hugging Face model id or local path passed to
                ``from_pretrained``.
            device_map: If it is an explicit ``"cuda:N"`` string, the model is
                loaded without a device map and then moved to that device.
                Any other value (e.g. ``"auto"``) is forwarded unchanged as
                ``device_map`` to ``from_pretrained``.
            dtype: Weight dtype passed as ``torch_dtype`` to ``from_pretrained``.
                NOTE(review): the float16 default may be slow or unsupported for
                some ops when the model ends up on CPU — confirm for CPU-only use.
            max_length: Maximum token length of each tokenized (query, doc)
                pair; longer pairs are truncated. Defaults to 512, the value
                previously hard-coded.

        Raises:
            ValueError: If ``max_length`` is less than 1.
        """
        if max_length < 1:
            raise ValueError(f"max_length must be >= 1, got {max_length}")
        self.max_length = max_length

        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Handle specific device assignment: an explicit "cuda:N" pins the
        # whole model to one GPU instead of letting from_pretrained place it.
        if device_map and device_map.startswith("cuda:"):
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            # Inputs are sent to the device holding the first parameter.
            self.device = next(self.model.parameters()).device

        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        **kwargs,
    ) -> List[float]:
        """
        Score documents against a query using the cross-encoder.

        Args:
            query: The query string.
            docs: List of document strings to score.
            batch_size: Number of (query, doc) pairs per forward pass; must be
                at least 1.
            **kwargs: Accepted for interface compatibility; ignored.

        Returns:
            One raw (un-normalized) relevance logit per document, in the same
            order as ``docs`` (higher = more relevant). An empty list when
            ``docs`` is empty.

        Raises:
            ValueError: If ``docs`` is non-empty and ``batch_size`` is less
                than 1.
        """
        if not docs:
            return []
        if batch_size < 1:
            # A negative step makes range() yield nothing, which would silently
            # return no scores for a non-empty ``docs`` list.
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        # Create query-doc pairs; the tokenizer encodes each 2-element list as
        # a text pair.
        pairs = [[query, doc] for doc in docs]

        all_scores: List[float] = []

        with torch.no_grad():
            for start in range(0, len(pairs), batch_size):
                batch = pairs[start:start + batch_size]

                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=self.max_length,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                logits = self.model(**inputs).logits
                # Assumes a single-logit classification head, i.e. logits of
                # shape (batch, 1), so squeeze(-1) yields one score per pair.
                # .float() upcasts (e.g. fp16) logits before converting to
                # Python floats.
                scores = logits.squeeze(-1).float().cpu().tolist()

                # tolist() on a 0-d tensor returns a bare float; keep a list.
                if isinstance(scores, float):
                    scores = [scores]

                all_scores.extend(scores)

        return all_scores