summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore9
-rw-r--r--src/personalization/models/__init__.py0
-rw-r--r--src/personalization/models/embedding/__init__.py11
-rw-r--r--src/personalization/models/embedding/base.py37
-rw-r--r--src/personalization/models/embedding/nemotron_8b.py63
-rw-r--r--src/personalization/models/embedding/qwen3_8b.py89
-rw-r--r--src/personalization/models/llm/__init__.py4
-rw-r--r--src/personalization/models/llm/base.py29
-rw-r--r--src/personalization/models/llm/llama_instruct.py129
-rw-r--r--src/personalization/models/llm/prompt_builder.py0
-rw-r--r--src/personalization/models/llm/qwen_instruct.py164
-rw-r--r--src/personalization/models/llm/vllm_chat.py217
-rw-r--r--src/personalization/models/preference_extractor/__init__.py5
-rw-r--r--src/personalization/models/preference_extractor/base.py17
-rw-r--r--src/personalization/models/preference_extractor/gpt4o_extractor.py97
-rw-r--r--src/personalization/models/preference_extractor/llm_extractor.py153
-rw-r--r--src/personalization/models/preference_extractor/rule_extractor.py152
-rw-r--r--src/personalization/models/reranker/__init__.py0
-rw-r--r--src/personalization/models/reranker/base.py16
-rw-r--r--src/personalization/models/reranker/bge_reranker.py95
-rw-r--r--src/personalization/models/reranker/nemotron_reranker.py0
-rw-r--r--src/personalization/models/reranker/qwen3_reranker.py96
22 files changed, 1381 insertions, 2 deletions
diff --git a/.gitignore b/.gitignore
index 27c4d29..4cbe1bf 100644
--- a/.gitignore
+++ b/.gitignore
@@ -12,8 +12,10 @@ venv/
.venv/
*.egg-info/
-# Models (Large model weights)
-models/
+# Models (Large model weights - top level only)
+/models/
+# But include src/personalization/models/ (Python wrappers)
+!src/personalization/models/
*.safetensors
*.bin
*.pt
@@ -53,3 +55,6 @@ collaborativeagents/results/
*.whl
*.tar
wandb/
+
+*.out
+*.err
diff --git a/src/personalization/models/__init__.py b/src/personalization/models/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/__init__.py
diff --git a/src/personalization/models/embedding/__init__.py b/src/personalization/models/embedding/__init__.py
new file mode 100644
index 0000000..05221aa
--- /dev/null
+++ b/src/personalization/models/embedding/__init__.py
@@ -0,0 +1,11 @@
+from .base import EmbeddingModel
+from .qwen3_8b import Qwen3Embedding8B
+from .nemotron_8b import LlamaEmbedNemotron8B
+
+__all__ = [
+ "EmbeddingModel",
+ "Qwen3Embedding8B",
+ "LlamaEmbedNemotron8B",
+]
+
+
diff --git a/src/personalization/models/embedding/base.py b/src/personalization/models/embedding/base.py
new file mode 100644
index 0000000..9f9d4d1
--- /dev/null
+++ b/src/personalization/models/embedding/base.py
@@ -0,0 +1,37 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Iterable, List, Sequence
+
+import torch
+
+
+class EmbeddingModel(ABC):
+ @abstractmethod
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ """Encode a batch of texts into dense embeddings."""
+ raise NotImplementedError
+
+
+def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+ # last_hidden_state: [batch, seq_len, hidden]
+ # attention_mask: [batch, seq_len]
+ mask = attention_mask.unsqueeze(-1).type_as(last_hidden_state) # [b, s, 1]
+ summed = (last_hidden_state * mask).sum(dim=1)
+ counts = mask.sum(dim=1).clamp_min(1e-6)
+ return summed / counts
+
+
+def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
+ if not normalize:
+ return x
+ return torch.nn.functional.normalize(x, p=2, dim=-1)
+
+
diff --git a/src/personalization/models/embedding/nemotron_8b.py b/src/personalization/models/embedding/nemotron_8b.py
new file mode 100644
index 0000000..6348aee
--- /dev/null
+++ b/src/personalization/models/embedding/nemotron_8b.py
@@ -0,0 +1,63 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import EmbeddingModel, _mean_pool, _maybe_normalize
+
+
+class LlamaEmbedNemotron8B(EmbeddingModel):
+ def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=True
+ )
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "LlamaEmbedNemotron8B":
+ if not cfg.embedding or not cfg.embedding.nemotron:
+ raise ValueError("Embedding config for nemotron is missing")
+ spec = cfg.embedding.nemotron
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+ @torch.inference_mode()
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ device = next(self.model.parameters()).device
+ outputs: List[torch.Tensor] = []
+ for i in range(0, len(texts), batch_size):
+ batch = list(texts[i : i + batch_size])
+ enc = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ ).to(device)
+ model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
+ pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined]
+ pooled = _maybe_normalize(pooled, normalize)
+ outputs.append(pooled)
+ emb = torch.cat(outputs, dim=0)
+ if return_tensor:
+ return emb
+ return emb.cpu().to(torch.float32).tolist()
+
+
diff --git a/src/personalization/models/embedding/qwen3_8b.py b/src/personalization/models/embedding/qwen3_8b.py
new file mode 100644
index 0000000..fb02e67
--- /dev/null
+++ b/src/personalization/models/embedding/qwen3_8b.py
@@ -0,0 +1,89 @@
+from __future__ import annotations
+
+from typing import List, Sequence
+
+import torch
+from transformers import AutoModel, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import EmbeddingModel, _mean_pool, _maybe_normalize
+
+
+class Qwen3Embedding8B(EmbeddingModel):
+ def __init__(
+ self,
+ model_path: str,
+ dtype: torch.dtype,
+ device_map: str = "auto",
+ trust_remote_code: bool = True,
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=trust_remote_code
+ )
+
+ # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
+ if device_map and device_map.startswith("cuda:"):
+ # Load to CPU first, then move to specific GPU
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=None, # Don't use accelerate's device_map
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+ self.model = self.model.to(device_map)
+ else:
+ # Use accelerate's auto device mapping
+ self.model = AutoModel.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=trust_remote_code,
+ low_cpu_mem_usage=True,
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B":
+ if not cfg.embedding or not cfg.embedding.qwen3:
+ raise ValueError("Embedding config for qwen3 is missing")
+ spec = cfg.embedding.qwen3
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(
+ spec.local_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ @torch.inference_mode()
+ def encode(
+ self,
+ texts: Sequence[str],
+ batch_size: int = 8,
+ max_length: int = 512,
+ normalize: bool = True,
+ return_tensor: bool = False,
+ ) -> List[List[float]] | torch.Tensor:
+ device = next(self.model.parameters()).device
+ outputs: List[torch.Tensor] = []
+ for i in range(0, len(texts), batch_size):
+ batch = list(texts[i : i + batch_size])
+ enc = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=max_length,
+ return_tensors="pt",
+ ).to(device)
+ model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
+ pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"]) # type: ignore[attr-defined]
+ pooled = _maybe_normalize(pooled, normalize)
+ outputs.append(pooled)
+ emb = torch.cat(outputs, dim=0)
+ if return_tensor:
+ return emb
+ return emb.cpu().to(torch.float32).tolist()
+
+
diff --git a/src/personalization/models/llm/__init__.py b/src/personalization/models/llm/__init__.py
new file mode 100644
index 0000000..3f1af81
--- /dev/null
+++ b/src/personalization/models/llm/__init__.py
@@ -0,0 +1,4 @@
+from .qwen_instruct import QwenInstruct
+
+__all__ = ["QwenInstruct"]
+
diff --git a/src/personalization/models/llm/base.py b/src/personalization/models/llm/base.py
new file mode 100644
index 0000000..72b6ca8
--- /dev/null
+++ b/src/personalization/models/llm/base.py
@@ -0,0 +1,29 @@
+from typing import List, Protocol, Optional
+from personalization.types import ChatTurn
+
+class ChatModel(Protocol):
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ """
+ Generate an assistant response given conversation history and memory notes.
+
+ Args:
+ history: The conversation history ending with the current user turn.
+ memory_notes: List of retrieved memory content strings.
+ max_new_tokens: Max tokens to generate.
+ temperature: Sampling temperature.
+ top_p: Top-p sampling.
+ top_k: Top-k sampling.
+
+ Returns:
+ The generated assistant response text.
+ """
+ ...
+
diff --git a/src/personalization/models/llm/llama_instruct.py b/src/personalization/models/llm/llama_instruct.py
new file mode 100644
index 0000000..bdf0dff
--- /dev/null
+++ b/src/personalization/models/llm/llama_instruct.py
@@ -0,0 +1,129 @@
+from typing import List, Optional
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
+class LlamaChatModel(ChatModel):
+ def __init__(
+ self,
+ model_path: str,
+ device: str = "cuda",
+ dtype: str = "bfloat16", # Keep type hint as str for legacy, but handle torch.dtype
+ max_context_length: int = 8192,
+ ):
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ # Handle dtype if it's already a torch.dtype object
+ if isinstance(dtype, str):
+ torch_dtype = getattr(torch, dtype)
+ else:
+ torch_dtype = dtype
+
+ # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
+ if device and device.startswith("cuda:"):
+ # Load to CPU first, then move to specific GPU
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=torch_dtype,
+ device_map=None,
+ low_cpu_mem_usage=True,
+ )
+ self.model = self.model.to(device)
+ else:
+ # Use accelerate's device mapping
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=torch_dtype,
+ device_map=device,
+ )
+
+ self.max_context_length = max_context_length
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ def _build_prompt(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ ) -> str:
+ memory_block = ""
+ if memory_notes:
+ bullet = "\n".join(f"- {n}" for n in memory_notes)
+ memory_block = (
+ "Here are the user's preferences and memories:\n"
+ f"{bullet}\n\n"
+ )
+
+ # Build prompt manually or use chat template if available.
+ # Llama 3 use specific tags.
+ # <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n ... <|eot_id|>
+ # But we can try to use tokenizer.apply_chat_template if it exists.
+
+ if hasattr(self.tokenizer, "apply_chat_template"):
+ messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
+ for turn in history:
+ messages.append({"role": turn.role, "content": turn.text})
+ return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+
+ # Fallback manual construction (simplified Llama 2/3 style or generic)
+ # This is risky for Llama 3 specifically which needs exact tokens.
+ # Let's assume apply_chat_template works for Llama-3-Instruct models.
+
+ # If fallback needed:
+ history_lines = []
+ for turn in history[-8:]:
+ role_tag = "user" if turn.role == "user" else "assistant"
+ # Generic format
+ history_lines.append(f"{role_tag}: {turn.text}")
+
+ prompt = (
+ "System: You are a helpful assistant.\n"
+ + memory_block
+ + "\n".join(history_lines)
+ + "\nassistant:"
+ )
+ return prompt
+
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ prompt = self._build_prompt(history, memory_notes)
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True,
+ max_length=self.max_context_length).to(self.model.device)
+
+ gen_kwargs = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": temperature > 0,
+ "temperature": temperature,
+ "top_p": top_p,
+ }
+ if top_k is not None:
+ gen_kwargs["top_k"] = top_k
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ eos_token_id=self.tokenizer.eos_token_id,
+ **gen_kwargs,
+ )
+ full = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+ # naive stripping
+ # With chat template, 'full' usually contains the whole conversation.
+ # We need to extract just the new part.
+ # But 'prompt' string might not match decoded output exactly due to special tokens skipping.
+ # Better: slice output ids.
+
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ answer_text = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+
+ return answer_text
+
diff --git a/src/personalization/models/llm/prompt_builder.py b/src/personalization/models/llm/prompt_builder.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/llm/prompt_builder.py
diff --git a/src/personalization/models/llm/qwen_instruct.py b/src/personalization/models/llm/qwen_instruct.py
new file mode 100644
index 0000000..cf2047d
--- /dev/null
+++ b/src/personalization/models/llm/qwen_instruct.py
@@ -0,0 +1,164 @@
+from typing import List, Optional, Dict, Any
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class QwenInstruct(ChatModel):
+ def __init__(
+ self,
+ model_path: str,
+ device: str = "cuda",
+ dtype: torch.dtype = torch.bfloat16,
+ max_context_length: int = 4096,
+ ):
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path,
+ use_fast=True,
+ trust_remote_code=True,
+ )
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype, # dtype is already torch.dtype, no getattr needed
+ device_map=device,
+ trust_remote_code=True,
+ )
+ self.max_context_length = max_context_length
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ # Legacy helper for manual generation without template
+ @torch.inference_mode()
+ def generate(
+ self,
+ prompt: str,
+ max_new_tokens: int = 256,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ stop: Optional[List[str]] = None,
+ top_k: Optional[int] = None,
+ ) -> str:
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+ gen_kwargs = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": temperature > 0,
+ "temperature": temperature,
+ "top_p": top_p,
+ "pad_token_id": self.tokenizer.pad_token_id,
+ "eos_token_id": self.tokenizer.eos_token_id,
+ }
+ if top_k is not None:
+ gen_kwargs["top_k"] = top_k
+
+ outputs = self.model.generate(
+ **inputs,
+ **gen_kwargs
+ )
+ # Return only the newly generated portion, not the echoed prompt
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+ if stop:
+ for s in stop:
+ if s in text:
+ text = text.split(s)[0]
+ break
+ return text
+
+ def _build_prompt(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ ) -> str:
+ """
+ Construct prompt using ChatML-like structure via apply_chat_template if available,
+ or manual construction. Qwen usually supports apply_chat_template.
+ We will map ChatTurn to messages list.
+ """
+ memory_block = ""
+ if memory_notes:
+ bullet = "\n".join(f"- {n}" for n in memory_notes)
+ memory_block = (
+ "Here are the user's preferences and memories:\n"
+ f"{bullet}\n\n"
+ )
+
+ messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
+
+ for turn in history:
+ messages.append({"role": turn.role, "content": turn.text})
+
+ return self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ # Compatibility check: if history is dict list (legacy), convert to ChatTurn
+ # This allows old code to work if not fully updated, though we should update callers.
+ # But ChatTurn is required by Protocol. We assume callers are updated.
+ if history and isinstance(history[0], dict):
+ # Auto-convert for safety during migration
+ history = [ChatTurn(
+ user_id="unknown", session_id="unknown", turn_id=i,
+ role=h["role"], text=h["content"]
+ ) for i, h in enumerate(history)]
+
+ prompt = self._build_prompt(history, memory_notes)
+ inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True,
+ max_length=self.max_context_length).to(self.model.device)
+
+ gen_kwargs = {
+ "max_new_tokens": max_new_tokens,
+ "do_sample": temperature > 0,
+ "temperature": temperature,
+ "top_p": top_p,
+ "pad_token_id": self.tokenizer.pad_token_id,
+ "eos_token_id": self.tokenizer.eos_token_id,
+ }
+ if top_k is not None:
+ gen_kwargs["top_k"] = top_k
+
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ **gen_kwargs,
+ )
+
+ full = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+ # remove prompt part manually since we didn't use self.generate helper here to keep full control
+ # input_ids length is inputs['input_ids'].shape[1]
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ answer_text = self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+
+ return answer_text
+
+ # Factory method for legacy config loading
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
+ spec = cfg.llm
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+        # device_map is usually handled by transformers when passed as the device_map argument.
+        # NOTE(review): the `dtype` and `device_map` values computed above are never used below —
+        # spec.dtype (a string) is forwarded although __init__ expects torch.dtype, and
+        # spec.device_map is passed raw. Consider passing the converted values instead.
+ device = spec.device_map if isinstance(spec.device_map, str) else "cuda"
+
+ return cls(
+ model_path=spec.local_path,
+ device=device, # Pass string
+            dtype=spec.dtype  # BUG(review): __init__ expects torch.dtype and does not convert strings — pass the converted `dtype` from above
+ )
diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py
new file mode 100644
index 0000000..b5c3a05
--- /dev/null
+++ b/src/personalization/models/llm/vllm_chat.py
@@ -0,0 +1,217 @@
+"""
+vLLM-based ChatModel implementation for high-throughput inference.
+
+This provides the same interface as LlamaChatModel but uses vLLM HTTP API
+for much faster inference (3000+ sessions/hr vs 20 sessions/hr).
+"""
+
+from typing import List, Optional
+import time
+import requests
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
+
+class VLLMChatModel(ChatModel):
+ """
+ ChatModel implementation using vLLM HTTP API.
+
+ This is a drop-in replacement for LlamaChatModel that uses vLLM
+ for much faster inference.
+ """
+
+ def __init__(
+ self,
+ vllm_url: str = "http://localhost:8003/v1",
+ model_name: str = None,
+ max_context_length: int = 8192,
+ timeout: int = 120,
+ ):
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+ self.max_context_length = max_context_length
+ self.timeout = timeout
+
+ # Discover model name if not provided
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Discover the model name from the vLLM server."""
+ max_retries = 30
+ for attempt in range(max_retries):
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ return
+ except Exception as e:
+ if attempt < max_retries - 1:
+ wait_time = min(2 ** attempt * 0.5, 10)
+ time.sleep(wait_time)
+
+ # Fallback
+ self.model_name = "default"
+ print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'")
+
+ def health_check(self) -> bool:
+ """Check if the vLLM server is healthy."""
+ try:
+ response = requests.get(f"{self.vllm_url.replace('/v1', '')}/health", timeout=5)
+ return response.status_code == 200
+ except:
+ return False
+
+ def _estimate_tokens(self, text: str) -> int:
+ """Estimate token count using character-based heuristic.
+
+ For Llama models, ~4 characters per token is a reasonable estimate.
+ We use 3.5 to be conservative (slightly overestimate tokens).
+ """
+ return int(len(text) / 3.5)
+
+ def _build_messages(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ ) -> List[dict]:
+ """Build messages list for chat completion API with auto-truncation.
+
+ If the context exceeds max_context_length, older conversation turns
+ are removed to keep only the most recent context that fits.
+ """
+ # Use CollaborativeAgents-style system prompt
+ if memory_notes:
+ bullet = "\n".join(f"- {n}" for n in memory_notes)
+ system_content = (
+ "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
+ "# User Preferences\n"
+ "The user has a set of preferences for how you should behave. If you do not follow these preferences, "
+ "the user will be unable to learn from your response and you will need to adjust your response to adhere "
+ "to these preferences (so it is best to follow them initially).\n"
+ "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
+ f"{bullet}\n\n"
+ "# Conversation Guidelines:\n"
+ "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
+ "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
+ "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
+ )
+ else:
+ # Vanilla mode - no preferences
+ system_content = (
+ "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
+ "# Conversation Guidelines:\n"
+ "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
+ "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
+ "- Your goal is to help the user solve their problem. Do your best to help them.\n"
+ )
+ system_message = {"role": "system", "content": system_content}
+
+ # Calculate available tokens for conversation history
+ # Reserve space for: system prompt + max_new_tokens + safety margin
+ system_tokens = self._estimate_tokens(system_content)
+ available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100 # 100 token safety margin
+
+ # Build conversation messages from history
+ conversation_messages = []
+ for turn in history:
+ conversation_messages.append({"role": turn.role, "content": turn.text})
+
+ # Check if truncation is needed
+ total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages)
+
+ if total_conv_tokens > available_tokens:
+ # Truncate from the beginning (keep recent messages)
+ truncated_messages = []
+ current_tokens = 0
+
+ # Iterate from most recent to oldest
+ for msg in reversed(conversation_messages):
+ msg_tokens = self._estimate_tokens(msg["content"])
+ if current_tokens + msg_tokens <= available_tokens:
+ truncated_messages.insert(0, msg)
+ current_tokens += msg_tokens
+ else:
+ # Stop adding older messages
+ break
+
+ conversation_messages = truncated_messages
+ if len(truncated_messages) < len(history):
+ print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns "
+ f"({current_tokens}/{total_conv_tokens} estimated tokens)")
+
+ messages = [system_message] + conversation_messages
+ return messages
+
+ def build_messages(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ ) -> List[dict]:
+ """Public method to build messages without calling the API.
+
+ Used for batch processing where messages are collected first,
+ then sent in batch to vLLM for concurrent processing.
+ """
+ return self._build_messages(history, memory_notes, max_new_tokens)
+
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ """Generate a response using vLLM HTTP API."""
+ messages = self._build_messages(history, memory_notes, max_new_tokens)
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": max_new_tokens,
+ "temperature": temperature,
+ "top_p": top_p,
+ }
+
+ # Retry with exponential backoff
+ max_retries = 5
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=self.timeout
+ )
+
+ if response.status_code == 200:
+ result = response.json()
+ return result["choices"][0]["message"]["content"]
+ elif response.status_code == 400:
+ error_text = response.text
+ # Handle context length error
+ if "max_tokens" in error_text and max_new_tokens > 64:
+ payload["max_tokens"] = max(64, max_new_tokens // 2)
+ continue
+ raise RuntimeError(f"vLLM error: {error_text[:200]}")
+ else:
+ raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}")
+
+ except requests.exceptions.Timeout:
+ if attempt < max_retries - 1:
+ time.sleep(2 ** attempt)
+ continue
+ raise RuntimeError("vLLM request timeout")
+ except requests.exceptions.ConnectionError as e:
+ if attempt < max_retries - 1:
+ time.sleep(2 ** attempt)
+ continue
+ raise RuntimeError(f"vLLM connection error: {e}")
+
+ raise RuntimeError("Max retries exceeded for vLLM request")
diff --git a/src/personalization/models/preference_extractor/__init__.py b/src/personalization/models/preference_extractor/__init__.py
new file mode 100644
index 0000000..65e2595
--- /dev/null
+++ b/src/personalization/models/preference_extractor/__init__.py
@@ -0,0 +1,5 @@
+from .rule_extractor import QwenRuleExtractor
+from .gpt4o_extractor import GPT4OExtractor
+from .base import PreferenceExtractor
+
+__all__ = ["QwenRuleExtractor", "GPT4OExtractor", "PreferenceExtractor"]
diff --git a/src/personalization/models/preference_extractor/base.py b/src/personalization/models/preference_extractor/base.py
new file mode 100644
index 0000000..850292f
--- /dev/null
+++ b/src/personalization/models/preference_extractor/base.py
@@ -0,0 +1,17 @@
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+
+class PreferenceExtractorBase(ABC):
+ @abstractmethod
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ """
+        Extract preferences from a window of chat turns (history + current query). NOTE(review): GPT4OExtractor.extract_turn takes a single ChatTurn, not a list — reconcile the signatures.
+ """
+ raise NotImplementedError
+
+# Alias for backward compatibility if needed,
+# though specific extractors should inherit from PreferenceExtractorBase now.
+PreferenceExtractor = PreferenceExtractorBase
diff --git a/src/personalization/models/preference_extractor/gpt4o_extractor.py b/src/personalization/models/preference_extractor/gpt4o_extractor.py
new file mode 100644
index 0000000..212bb13
--- /dev/null
+++ b/src/personalization/models/preference_extractor/gpt4o_extractor.py
@@ -0,0 +1,97 @@
+from __future__ import annotations
+
+import json
+import os
+from typing import Any, Dict, List
+
+from openai import OpenAI
+from personalization.config.settings import LocalModelsConfig
+from personalization.models.preference_extractor.base import PreferenceExtractorBase as PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ ChatTurn,
+ PreferenceList,
+ preference_list_json_schema,
+)
+
+
+class GPT4OExtractor(PreferenceExtractor):
+ def __init__(self, api_key: str, model: str = "gpt-4o") -> None:
+ self.client = OpenAI(api_key=api_key)
+ self.model = model
+
+ # Load system prompt template
+ template_path = "fine_tuning_prompt_template.txt"
+ if os.path.exists(template_path):
+ with open(template_path, "r", encoding="utf-8") as f:
+ self.system_prompt = f.read()
+ else:
+ # Fallback simple prompt if file missing
+ self.system_prompt = (
+ "You are a preference extraction assistant. "
+ "Extract user preferences from the query into a JSON object."
+ )
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor":
+ # We rely on env var for API key, config for other potential settings if needed
+ api_key = os.getenv("OPENAI_API_KEY")
+ if not api_key:
+ raise ValueError("OPENAI_API_KEY environment variable not set")
+ return cls(api_key=api_key)
+
+ def build_preference_prompt(self, query: str) -> str:
+ # GPT4OExtractor uses the system prompt loaded in __init__
+ return self.system_prompt
+
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ # Reuse logic but return raw dict
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": query},
+ ],
+ response_format={"type": "json_object"},
+ temperature=0.0,
+ )
+ content = response.choices[0].message.content
+ if content:
+ return json.loads(content)
+ except Exception as e:
+ print(f"Error calling GPT-4o: {e}")
+ return {"preferences": []}
+
+ def extract_turn(self, turn: ChatTurn) -> PreferenceList:
+ if turn.role != "user":
+ return PreferenceList(preferences=[])
+
+ try:
+ response = self.client.chat.completions.create(
+ model=self.model,
+ messages=[
+ {"role": "system", "content": self.system_prompt},
+ {"role": "user", "content": turn.text},
+ ],
+ response_format={"type": "json_object"},
+ temperature=0.0,
+ )
+
+ content = response.choices[0].message.content
+ if not content:
+ return PreferenceList(preferences=[])
+
+ data = json.loads(content)
+ # The prompt might return {"preferences": [...]}, validate it
+ return PreferenceList.model_validate(data)
+
+ except Exception as e:
+ print(f"Error calling GPT-4o: {e}")
+ return PreferenceList(preferences=[])
+
+ def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+ results = []
+ for turn in turns:
+ results.append(self.extract_turn(turn))
+ return results
+
diff --git a/src/personalization/models/preference_extractor/llm_extractor.py b/src/personalization/models/preference_extractor/llm_extractor.py
new file mode 100644
index 0000000..8f7a6cb
--- /dev/null
+++ b/src/personalization/models/preference_extractor/llm_extractor.py
@@ -0,0 +1,153 @@
+from typing import List, Dict, Any
+import torch
+import json
+import os
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.preference_extractor.base import PreferenceExtractorBase
+from personalization.retrieval.preference_store.schemas import ChatTurn, PreferenceList
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class PreferenceExtractorLLM(PreferenceExtractorBase):
+ def __init__(
+ self,
+ model_path: str,
+ prompt_template_path: str = "fine_tuning_prompt_template.txt",
+ device_map: str = "auto",
+ dtype: torch.dtype = torch.bfloat16,
+ max_new_tokens: int = 512,
+ ) -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ self.max_new_tokens = max_new_tokens
+
+ if os.path.exists(prompt_template_path):
+ with open(prompt_template_path, "r", encoding="utf-8") as f:
+ self.prompt_template = f.read()
+ else:
+ print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
+ self.prompt_template = "Extract user preferences from the following conversation."
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
+ # We need to access the specific extractor config by name
+ # Assuming cfg has a way to access extra configs or we update LocalModelsConfig to support multiple extractors
+ # For now, let's look for it in the 'preference_extractor' dict if it was a Dict, but it is a ModelSpec.
+ # We need to update LocalModelsConfig to support a dictionary of extractors or a specific one.
+ # Based on user design doc:
+ # preference_extractor:
+ # qwen3_0_6b_sft: ...
+
+ # We might need to manually parse the raw config or update settings.py
+ # Let's assume settings.py will be updated to hold a map or specific fields.
+ # For now, if we use the existing ModelSpec for preference_extractor in cfg, we assume it points to this model.
+
+ # BUT the design doc says "preference_extractor" in local_models.yaml will have "qwen3_0_6b_sft" key.
+ # The current settings.py defines preference_extractor as a single ModelSpec.
+ # We will need to update settings.py first to support multiple extractors or a dict.
+ # I will proceed implementing this class assuming arguments are passed, and update settings/registry later.
+
+ # This from_config might change depending on how settings.py is refactored.
+ # For now I will implement it assuming a direct ModelSpec is passed, or we handle it in registry.
+ pass
+ return None
+
+ def _build_prompt(self, turns: List[ChatTurn]) -> str:
+ # Construct messages list for chat template
+ messages = [{"role": "system", "content": self.prompt_template}]
+
+ # Window size 6
+ window = turns[-6:]
+
+ # Add conversation history
+ # We need to format the conversation as input context.
+ # Since the task is to extract preferences from the *whole* context (or latest turn?),
+ # usually we provide the conversation and ask for extraction.
+ # But LLaMA-Factory SFT usually expects:
+ # System: <template>
+ # User: <input>
+ # Assistant: <output>
+
+ # We should pack the conversation history into the User message?
+ # Or if we trained with multi-turn chat format?
+ # Assuming "Input" column in dataset was the conversation history.
+
+ history_texts = []
+ for t in window:
+ role = "User" if t.role == "user" else "Assistant"
+ history_texts.append(f"{role}: {t.text}")
+
+ conversation_text = "\n".join(history_texts)
+
+ # Construct the User input
+ # We append a trigger instruction if it wasn't part of the training input implicitly.
+ # But based on your template, the User Input Example was just the query "I am a Python developer..."
+ # So likely we should just feed the conversation text as the user message.
+
+ messages.append({"role": "user", "content": conversation_text})
+
+ # Apply chat template
+ prompt = self.tokenizer.apply_chat_template(
+ messages,
+ tokenize=False,
+ add_generation_prompt=True
+ )
+
+ return prompt
+
+ def _generate(self, prompt: str) -> str:
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+ with torch.no_grad():
+ outputs = self.model.generate(
+ **inputs,
+ max_new_tokens=self.max_new_tokens,
+ do_sample=False,
+ temperature=0.0,
+ eos_token_id=self.tokenizer.eos_token_id,
+ pad_token_id=self.tokenizer.pad_token_id,
+ )
+ full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
+ return full_text[len(prompt):]
+
+ def _parse_preferences(self, raw_output: str) -> PreferenceList:
+ start = raw_output.find("{")
+ end = raw_output.rfind("}")
+
+ if start == -1 or end == -1 or end <= start:
+ return PreferenceList(preferences=[])
+
+ json_str = raw_output[start:end+1]
+ try:
+ data = json.loads(json_str)
+ return PreferenceList.model_validate(data)
+ except Exception:
+ return PreferenceList(preferences=[])
+
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ prompt = self._build_prompt(turns)
+ raw_output = self._generate(prompt)
+ return self._parse_preferences(raw_output)
+
+ # Legacy support
+ def build_preference_prompt(self, query: str) -> str:
+ # Wrap query in a dummy turn
+ turn = ChatTurn(
+ user_id="dummy", session_id="dummy", turn_id=0,
+ role="user", text=query
+ )
+ return self._build_prompt([turn])
+
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ turn = ChatTurn(
+ user_id="dummy", session_id="dummy", turn_id=0,
+ role="user", text=query
+ )
+ prefs = self.extract_turn([turn])
+ return prefs.model_dump()
+
diff --git a/src/personalization/models/preference_extractor/rule_extractor.py b/src/personalization/models/preference_extractor/rule_extractor.py
new file mode 100644
index 0000000..0f743d9
--- /dev/null
+++ b/src/personalization/models/preference_extractor/rule_extractor.py
@@ -0,0 +1,152 @@
+from __future__ import annotations
+
+import json
+import re
+import os
+from typing import Any, Dict, List
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.config.registry import choose_dtype, choose_device_map
+from personalization.config.settings import LocalModelsConfig
+from .base import PreferenceExtractor
+from personalization.retrieval.preference_store.schemas import (
+ PreferenceList,
+ preference_list_json_schema,
+ ChatTurn,
+)
+
+# Hardcoded system prompt used at inference time.
+# This MUST stay byte-identical to the prompt used during SFT training
+# (see scripts/split_train_test.py); any drift degrades extraction quality.
+SFT_SYSTEM_PROMPT = (
+    "Extract user preferences from the query into JSON format based on the PreferenceList schema. "
+    "If no preferences are found, return {\"preferences\": []}."
+)
+
+class QwenRuleExtractor(PreferenceExtractor):
+ """
+ Extractor using a Fine-Tuned (SFT) Qwen model.
+ Despite the name 'RuleExtractor' (legacy), this now performs direct End-to-End extraction.
+ """
+ def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
+ self.tokenizer = AutoTokenizer.from_pretrained(
+ model_path, use_fast=True, trust_remote_code=True
+ )
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+ if self.tokenizer.pad_token_id is None:
+ self.tokenizer.pad_token = self.tokenizer.eos_token
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "QwenRuleExtractor":
+ spec = cfg.preference_extractor
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, dtype=dtype, device_map=device_map)
+
+ def build_preference_prompt(self, query: str) -> str:
+ """
+ Construct the prompt string using the tokenizer's chat template.
+ Matches the format seen during SFT training.
+ """
+ messages = [
+ {"role": "system", "content": SFT_SYSTEM_PROMPT},
+ {"role": "user", "content": query}
+ ]
+ prompt = self.tokenizer.apply_chat_template(
+ messages, tokenize=False, add_generation_prompt=True
+ )
+ return prompt
+
+ @torch.inference_mode()
+ def extract_preferences(self, query: str) -> Dict[str, Any]:
+ """
+ Directly extract preferences from query using the SFT model.
+ Returns a dict compatible with PreferenceList model (key: 'preferences').
+ """
+ prompt = self.build_preference_prompt(query)
+ inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)
+
+ outputs = self.model.generate(
+ **inputs,
+ do_sample=False, # Deterministic greedy decoding
+ max_new_tokens=512, # Allow enough space for JSON
+ pad_token_id=self.tokenizer.pad_token_id,
+ eos_token_id=self.tokenizer.eos_token_id,
+ )
+
+ input_len = inputs["input_ids"].shape[1]
+ gen_ids = outputs[0][input_len:]
+ text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
+
+ if os.getenv("PREF_DEBUG") == "1":
+ print(f"[debug][extractor] Raw output: {text}")
+
+ # Try parsing JSON
+ try:
+ # 1. Direct parse
+ data = json.loads(text)
+
+ # 2. Validate against schema structure
+ validated = PreferenceList.model_validate(data)
+ return validated.model_dump()
+
+ except Exception:
+ # Fallback: Try to find JSON blob if model outputted extra text (rare for SFT but possible)
+ extracted_json = self._extract_json_substring(text)
+ if extracted_json:
+ try:
+ data = json.loads(extracted_json)
+ validated = PreferenceList.model_validate(data)
+ return validated.model_dump()
+ except:
+ pass
+
+ # If all fails, return empty
+ return {"preferences": []}
+
+ def _extract_json_substring(self, text: str) -> str | None:
+ """Helper to find { ... } block in text."""
+ # Find first '{' and last '}'
+ start = text.find('{')
+ end = text.rfind('}')
+ if start != -1 and end != -1 and end > start:
+ return text[start : end + 1]
+ return None
+
+ def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
+ """
+ Extract preferences from the LAST user turn in the history.
+ We don't concat history because our SFT model was trained on single-turn extraction.
+ Using context might confuse it unless we trained it that way.
+ """
+ # Find the last user message
+ last_user_msg = None
+ for t in reversed(turns):
+ if t.role == "user":
+ last_user_msg = t.text
+ break
+
+ if not last_user_msg:
+ return PreferenceList(preferences=[])
+
+ result_dict = self.extract_preferences(last_user_msg)
+ return PreferenceList.model_validate(result_dict)
+
+ def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
+ """
+ Extract preferences from ALL user turns individually.
+ """
+ results = []
+ for turn in turns:
+ if turn.role == "user":
+ res = self.extract_preferences(turn.text)
+ results.append(PreferenceList.model_validate(res))
+ else:
+ results.append(PreferenceList(preferences=[]))
+ return results
diff --git a/src/personalization/models/reranker/__init__.py b/src/personalization/models/reranker/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/reranker/__init__.py
diff --git a/src/personalization/models/reranker/base.py b/src/personalization/models/reranker/base.py
new file mode 100644
index 0000000..34cf6ce
--- /dev/null
+++ b/src/personalization/models/reranker/base.py
@@ -0,0 +1,16 @@
+from typing import List, Protocol
+
+class Reranker(Protocol):
+ def score(
+ self,
+ query: str,
+ docs: List[str],
+ **kwargs,
+ ) -> List[float]:
+ """
+ Score multiple candidate documents for the same query.
+ Higher score indicates higher relevance.
+ Returns a list of floats with length equal to len(docs).
+ """
+ ...
+
diff --git a/src/personalization/models/reranker/bge_reranker.py b/src/personalization/models/reranker/bge_reranker.py
new file mode 100644
index 0000000..a672f0a
--- /dev/null
+++ b/src/personalization/models/reranker/bge_reranker.py
@@ -0,0 +1,95 @@
+"""BGE Reranker - lightweight 278M parameter cross-encoder reranker."""
+
+from typing import List
+import torch
+from transformers import AutoModelForSequenceClassification, AutoTokenizer
+from .base import Reranker
+
+
+class BGEReranker(Reranker):
+ """
+ BGE Reranker using cross-encoder architecture.
+
+ Much lighter than Qwen3-Reranker-8B:
+ - bge-reranker-base: 278M params
+ - bge-reranker-large: 560M params
+ """
+
+ def __init__(
+ self,
+ model_path: str = "BAAI/bge-reranker-base",
+ device_map: str = "auto",
+ dtype: torch.dtype = torch.float16
+ ):
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+
+ # Handle specific device assignment
+ if device_map and device_map.startswith("cuda:"):
+ self.model = AutoModelForSequenceClassification.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=None,
+ )
+ self.model = self.model.to(device_map)
+ self.device = device_map
+ else:
+ self.model = AutoModelForSequenceClassification.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ )
+ self.device = next(self.model.parameters()).device
+
+ self.model.eval()
+
+ def score(
+ self,
+ query: str,
+ docs: List[str],
+ batch_size: int = 32,
+ **kwargs,
+ ) -> List[float]:
+ """
+ Score documents using cross-encoder.
+
+ Args:
+ query: The query string
+ docs: List of document strings to score
+ batch_size: Batch size for processing
+
+ Returns:
+ List of relevance scores (higher = more relevant)
+ """
+ if not docs:
+ return []
+
+ # Create query-doc pairs
+ pairs = [[query, doc] for doc in docs]
+
+ all_scores = []
+
+ with torch.no_grad():
+ for i in range(0, len(pairs), batch_size):
+ batch = pairs[i:i + batch_size]
+
+ # Tokenize
+ inputs = self.tokenizer(
+ batch,
+ padding=True,
+ truncation=True,
+ max_length=512,
+ return_tensors="pt"
+ )
+ inputs = {k: v.to(self.device) for k, v in inputs.items()}
+
+ # Get scores
+ outputs = self.model(**inputs)
+ scores = outputs.logits.squeeze(-1).float().cpu().tolist()
+
+ # Handle single item case
+ if isinstance(scores, float):
+ scores = [scores]
+
+ all_scores.extend(scores)
+
+ return all_scores
diff --git a/src/personalization/models/reranker/nemotron_reranker.py b/src/personalization/models/reranker/nemotron_reranker.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/reranker/nemotron_reranker.py
diff --git a/src/personalization/models/reranker/qwen3_reranker.py b/src/personalization/models/reranker/qwen3_reranker.py
new file mode 100644
index 0000000..b648421
--- /dev/null
+++ b/src/personalization/models/reranker/qwen3_reranker.py
@@ -0,0 +1,96 @@
+from typing import List
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+from .base import Reranker
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
+class Qwen3Reranker(Reranker):
+ def __init__(self, model_path: str, device_map: str = "auto", dtype: torch.dtype = torch.bfloat16):
+ # Ensure we pass trust_remote_code=True for Qwen models
+ self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+
+ # Handle specific device assignment (e.g., "cuda:0", "cuda:1")
+ if device_map and device_map.startswith("cuda:"):
+ # Load to CPU first, then move to specific GPU
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=None,
+ trust_remote_code=True,
+ low_cpu_mem_usage=True,
+ )
+ self.model = self.model.to(device_map)
+ else:
+ # Use accelerate's auto device mapping
+ self.model = AutoModelForCausalLM.from_pretrained(
+ model_path,
+ torch_dtype=dtype,
+ device_map=device_map,
+ trust_remote_code=True,
+ )
+
+ self.yes_token_id = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
+
+ @classmethod
+ def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Reranker":
+ if not cfg.reranker or not cfg.reranker.qwen3_8b:
+ raise ValueError("Reranker config for qwen3_8b is missing")
+ spec = cfg.reranker.qwen3_8b
+ dtype = choose_dtype(spec.dtype)
+ device_map = choose_device_map(spec.device_map)
+ return cls(spec.local_path, device_map=device_map, dtype=dtype)
+
+ def _build_prompt(self, query: str, doc: str) -> str:
+ return (
+ "You are a reranker. "
+ "Given a user query and a memory note, answer 'yes' if the note is helpful "
+ "for answering the query, otherwise answer 'no'.\n\n"
+ f"Query: {query}\n"
+ f"Note: {doc}\n"
+ "Answer with a single token: yes or no."
+ )
+
+ @torch.inference_mode()
+ def score(self, query: str, docs: List[str], batch_size: int = 8, **kwargs) -> List[float]:
+ scores = []
+ for i in range(0, len(docs), batch_size):
+ batch_docs = docs[i : i + batch_size]
+ prompts = [self._build_prompt(query, d) for d in batch_docs]
+
+ inputs = self.tokenizer(
+ prompts,
+ return_tensors="pt",
+ padding=True,
+ truncation=True,
+ max_length=512
+ ).to(self.model.device)
+
+ outputs = self.model(**inputs)
+ # Take logits of the last token
+ # shape: [batch, seq_len, vocab_size]
+ logits = outputs.logits
+
+ # We want the logits for the token position immediately after the prompt ends.
+ # But since we generated inputs directly from tokenizer(prompts),
+ # we look at the last position of the input.
+ # For causal LM, we usually look at the logits of the last token
+ # to predict the *next* token (which we hope is 'yes' or 'no').
+
+ # Get logits for the next token prediction (last position)
+ # For each sequence in batch, select the last token's logits
+ # inputs['input_ids'] shape: [B, L]
+ # logits shape: [B, L, V]
+ # We want logits[:, -1, :]
+
+ last_token_logits = logits[:, -1, :]
+
+ # Calculate log prob of 'yes'
+ # We can use log_softmax over the vocab dimension
+ log_probs = torch.log_softmax(last_token_logits, dim=-1)
+ yes_log_probs = log_probs[:, self.yes_token_id]
+
+ scores.extend(yes_log_probs.float().cpu().numpy().tolist())
+
+ return scores
+