class Reranker(Protocol):
    """Structural type for rerankers: anything exposing a compatible ``score``.

    Implementations assign a relevance score to each candidate document for a
    single query; larger scores mean more relevant.
    """

    def score(
        self,
        query: str,
        docs: List[str],
        **kwargs,
    ) -> List[float]:
        """Return one relevance score per document in ``docs``.

        The result is aligned by index with ``docs`` and has exactly
        ``len(docs)`` entries; higher scores indicate higher relevance.
        """
        ...
class BGEReranker(Reranker):
    """
    Cross-encoder reranker backed by BAAI's BGE reranker checkpoints.

    Considerably smaller than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16
    ):
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # A concrete "cuda:N" target cannot be combined with accelerate's
        # device_map, so load with device_map=None and move weights explicitly.
        wants_explicit_gpu = bool(device_map) and device_map.startswith("cuda:")
        if wants_explicit_gpu:
            model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=None,
            )
            self.model = model.to(device_map)
            self.device = device_map
        else:
            self.model = AutoModelForSequenceClassification.from_pretrained(
                model_path,
                torch_dtype=dtype,
                device_map=device_map,
            )
            self.device = next(self.model.parameters()).device

        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        **kwargs,
    ) -> List[float]:
        """
        Score every document in ``docs`` against ``query`` with the cross-encoder.

        Args:
            query: The query string
            docs: Candidate documents to score
            batch_size: Number of query-doc pairs scored per forward pass

        Returns:
            One relevance score per document (higher = more relevant)
        """
        if not docs:
            return []

        # Cross-encoders consume (query, doc) pairs jointly.
        qd_pairs = [[query, doc] for doc in docs]
        results: List[float] = []

        with torch.no_grad():
            for start in range(0, len(qd_pairs), batch_size):
                chunk = qd_pairs[start:start + batch_size]

                encoded = self.tokenizer(
                    chunk,
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                )
                encoded = {name: t.to(self.device) for name, t in encoded.items()}

                logits = self.model(**encoded).logits
                chunk_scores = logits.squeeze(-1).float().cpu().tolist()

                # squeeze() can collapse a single-pair batch to a bare float.
                if isinstance(chunk_scores, float):
                    chunk_scores = [chunk_scores]

                results.extend(chunk_scores)

        return results
self.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=None, + trust_remote_code=True, + low_cpu_mem_usage=True, + ) + self.model = self.model.to(device_map) + else: + # Use accelerate's auto device mapping + self.model = AutoModelForCausalLM.from_pretrained( + model_path, + torch_dtype=dtype, + device_map=device_map, + trust_remote_code=True, + ) + + self.yes_token_id = self.tokenizer("yes", add_special_tokens=False).input_ids[0] + + @classmethod + def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Reranker": + if not cfg.reranker or not cfg.reranker.qwen3_8b: + raise ValueError("Reranker config for qwen3_8b is missing") + spec = cfg.reranker.qwen3_8b + dtype = choose_dtype(spec.dtype) + device_map = choose_device_map(spec.device_map) + return cls(spec.local_path, device_map=device_map, dtype=dtype) + + def _build_prompt(self, query: str, doc: str) -> str: + return ( + "You are a reranker. " + "Given a user query and a memory note, answer 'yes' if the note is helpful " + "for answering the query, otherwise answer 'no'.\n\n" + f"Query: {query}\n" + f"Note: {doc}\n" + "Answer with a single token: yes or no." + ) + + @torch.inference_mode() + def score(self, query: str, docs: List[str], batch_size: int = 8, **kwargs) -> List[float]: + scores = [] + for i in range(0, len(docs), batch_size): + batch_docs = docs[i : i + batch_size] + prompts = [self._build_prompt(query, d) for d in batch_docs] + + inputs = self.tokenizer( + prompts, + return_tensors="pt", + padding=True, + truncation=True, + max_length=512 + ).to(self.model.device) + + outputs = self.model(**inputs) + # Take logits of the last token + # shape: [batch, seq_len, vocab_size] + logits = outputs.logits + + # We want the logits for the token position immediately after the prompt ends. + # But since we generated inputs directly from tokenizer(prompts), + # we look at the last position of the input. 
+ # For causal LM, we usually look at the logits of the last token + # to predict the *next* token (which we hope is 'yes' or 'no'). + + # Get logits for the next token prediction (last position) + # For each sequence in batch, select the last token's logits + # inputs['input_ids'] shape: [B, L] + # logits shape: [B, L, V] + # We want logits[:, -1, :] + + last_token_logits = logits[:, -1, :] + + # Calculate log prob of 'yes' + # We can use log_softmax over the vocab dimension + log_probs = torch.log_softmax(last_token_logits, dim=-1) + yes_log_probs = log_probs[:, self.yes_token_id] + + scores.extend(yes_log_probs.float().cpu().numpy().tolist()) + + return scores + -- cgit v1.2.3