summaryrefslogtreecommitdiff
path: root/src/personalization/models/reranker
diff options
context:
space:
mode:
Diffstat (limited to 'src/personalization/models/reranker')
-rw-r--r--src/personalization/models/reranker/__init__.py0
-rw-r--r--src/personalization/models/reranker/base.py16
-rw-r--r--src/personalization/models/reranker/bge_reranker.py95
-rw-r--r--src/personalization/models/reranker/nemotron_reranker.py0
-rw-r--r--src/personalization/models/reranker/qwen3_reranker.py96
5 files changed, 207 insertions, 0 deletions
diff --git a/src/personalization/models/reranker/__init__.py b/src/personalization/models/reranker/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/reranker/__init__.py
diff --git a/src/personalization/models/reranker/base.py b/src/personalization/models/reranker/base.py
new file mode 100644
index 0000000..34cf6ce
--- /dev/null
+++ b/src/personalization/models/reranker/base.py
@@ -0,0 +1,16 @@
+from typing import List, Protocol
+
class Reranker(Protocol):
    """Structural interface for query/document relevance scorers."""

    def score(
        self,
        query: str,
        docs: List[str],
        **kwargs,
    ) -> List[float]:
        """Assign a relevance score to every document in *docs* for *query*.

        Implementations return one float per document, in the same order
        and of the same length as *docs*; a larger value means the document
        is more relevant to the query.
        """
        ...

diff --git a/src/personalization/models/reranker/bge_reranker.py b/src/personalization/models/reranker/bge_reranker.py
new file mode 100644
index 0000000..a672f0a
--- /dev/null
+++ b/src/personalization/models/reranker/bge_reranker.py
@@ -0,0 +1,95 @@
+"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""
+
+from typing import List
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from .base import Reranker
+
+
class BGEReranker(Reranker):
    """
    BGE Reranker using cross-encoder architecture.

    Much lighter than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16,
    ):
        """Load the tokenizer and sequence-classification model.

        Args:
            model_path: HF hub id or local path of the BGE cross-encoder.
            device_map: "auto" for accelerate placement, or an explicit
                device string such as "cuda:1" to pin to one GPU.
            dtype: Weight dtype. NOTE(review): float16 on a CPU-only host
                is slow and sometimes unsupported — confirm callers pass
                float32 in that case.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        if device_map and device_map.startswith("cuda:"):
            # Explicit single-GPU pin: load without a device_map first so
            # accelerate does not shard the model, then move it wholesale.
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            # Inputs must be sent wherever the (possibly sharded) model lives.
            self.device = next(self.model.parameters()).device

        self.model.eval()  # disable dropout for deterministic scoring

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        **kwargs,
    ) -> List[float]:
        """
        Score documents using the cross-encoder.

        Args:
            query: The query string
            docs: List of document strings to score
            batch_size: Batch size for processing

        Returns:
            List of relevance scores (higher = more relevant), one per doc.
        """
        if not docs:
            return []

        # Cross-encoder input: one (query, doc) pair per candidate.
        pairs = [[query, doc] for doc in docs]

        all_scores: List[float] = []

        # inference_mode == no_grad plus disabled version tracking; this also
        # matches the sibling Qwen3 reranker's decorator.
        with torch.inference_mode():
            for start in range(0, len(pairs), batch_size):
                batch = pairs[start:start + batch_size]

                inputs = self.tokenizer(
                    batch,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt",
                )
                inputs = {k: v.to(self.device) for k, v in inputs.items()}

                logits = self.model(**inputs).logits
                # logits is [B, 1]; squeeze(-1) yields a 1-D tensor, and
                # .tolist() on a 1-D tensor always returns a list (even for
                # B == 1), so no scalar special-case is required.
                all_scores.extend(logits.squeeze(-1).float().cpu().tolist())

        return all_scores
diff --git a/src/personalization/models/reranker/nemotron_reranker.py b/src/personalization/models/reranker/nemotron_reranker.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/reranker/nemotron_reranker.py
diff --git a/src/personalization/models/reranker/qwen3_reranker.py b/src/personalization/models/reranker/qwen3_reranker.py
new file mode 100644
index 0000000..b648421
--- /dev/null
+++ b/src/personalization/models/reranker/qwen3_reranker.py
@@ -0,0 +1,96 @@
+from typing import List
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from .base import Reranker
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
class Qwen3Reranker(Reranker):
    """Causal-LM reranker.

    Relevance of a document is the log-probability of the token "yes"
    being generated immediately after a yes/no judgment prompt built from
    (query, doc).
    """

    def __init__(self, model_path: str, device_map: str = "auto", dtype: torch.dtype = torch.bfloat16):
        """Load tokenizer and causal LM.

        Args:
            model_path: HF hub id or local path of the Qwen3 reranker LM.
            device_map: "auto" for accelerate placement, or "cuda:N" to pin.
            dtype: Weight dtype (bfloat16 by default).
        """
        # trust_remote_code=True is required for Qwen custom code.
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

        # Causal-LM tokenizers sometimes ship without a pad token, which
        # makes `padding=True` raise; fall back to EOS in that case.
        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        # BUGFIX: with right padding, logits[:, -1, :] in score() reads the
        # prediction at a PAD position for every sequence shorter than the
        # batch max. Left padding puts the final prompt token at index -1
        # of every row, so the last-position logits are the true next-token
        # distribution for each prompt.
        self.tokenizer.padding_side = "left"

        # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
        if device_map and device_map.startswith("cuda:"):
            # Load without a device_map (so accelerate does not shard),
            # then move the whole model to the requested GPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
                trust_remote_code=True,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device_map)
        else:
            # Use accelerate's auto device mapping.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
                trust_remote_code=True,
            )

        # BUGFIX: eval() was missing — without it dropout stays active and
        # scores are nondeterministic. Matches BGEReranker.
        self.model.eval()

        # First sub-token of "yes" stands in for the full answer.
        # NOTE(review): assumes the model emits "yes" (not " yes") directly
        # after the prompt — confirm against the model's tokenization.
        self.yes_token_id = self.tokenizer("yes", add_special_tokens=False).input_ids[0]

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Reranker":
        """Alternate constructor from the project's LocalModelsConfig.

        Raises:
            ValueError: if the qwen3_8b reranker entry is absent.
        """
        if not cfg.reranker or not cfg.reranker.qwen3_8b:
            raise ValueError("Reranker config for qwen3_8b is missing")
        spec = cfg.reranker.qwen3_8b
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        return cls(spec.local_path, device_map=device_map, dtype=dtype)

    def _build_prompt(self, query: str, doc: str) -> str:
        # Runtime string kept byte-for-byte: it defines the scoring contract.
        return (
            "You are a reranker. "
            "Given a user query and a memory note, answer 'yes' if the note is helpful "
            "for answering the query, otherwise answer 'no'.\n\n"
            f"Query: {query}\n"
            f"Note: {doc}\n"
            "Answer with a single token: yes or no."
        )

    @torch.inference_mode()
    def score(self, query: str, docs: List[str], batch_size: int = 8, **kwargs) -> List[float]:
        """Return log P("yes" | prompt) for each document.

        Args:
            query: The query string.
            docs: Candidate documents; an empty list yields an empty result.
            batch_size: Number of prompts scored per forward pass.

        Returns:
            One log-probability per document (higher = more relevant).
        """
        scores: List[float] = []
        for i in range(0, len(docs), batch_size):
            batch_docs = docs[i : i + batch_size]
            prompts = [self._build_prompt(query, d) for d in batch_docs]

            # NOTE(review): truncation trims the end of long prompts, which
            # would cut the trailing "Answer with a single token" line for
            # very long docs — consider truncating the doc text instead.
            inputs = self.tokenizer(
                prompts,
                return_tensors="pt",
                padding=True,
                truncation=True,
                max_length=512,
            ).to(self.model.device)

            logits = self.model(**inputs).logits  # [B, L, V]

            # Left padding (set in __init__) guarantees that position -1 of
            # every row holds the prediction for the token that would follow
            # the prompt, i.e. the model's yes/no answer.
            last_token_logits = logits[:, -1, :]

            # Log-prob of 'yes' over the full vocabulary.
            log_probs = torch.log_softmax(last_token_logits, dim=-1)
            yes_log_probs = log_probs[:, self.yes_token_id]

            scores.extend(yes_log_probs.float().cpu().tolist())

        return scores
+