diff options
Diffstat (limited to 'src')
21 files changed, 1374 insertions, 0 deletions
# --- src/personalization/models/embedding/base.py ---


class EmbeddingModel(ABC):
    """Abstract interface shared by the dense text-embedding backends."""

    @abstractmethod
    def encode(
        self,
        texts: Sequence[str],
        batch_size: int = 8,
        max_length: int = 512,
        normalize: bool = True,
        return_tensor: bool = False,
    ) -> List[List[float]] | torch.Tensor:
        """Encode a batch of texts into dense embeddings.

        Args:
            texts: Strings to embed.
            batch_size: Number of texts per forward pass.
            max_length: Tokenizer truncation length.
            normalize: L2-normalize each embedding when True.
            return_tensor: Return a tensor instead of nested Python lists.
        """
        raise NotImplementedError


def _mean_pool(last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
    """Average token vectors over the sequence axis, ignoring masked positions.

    Args:
        last_hidden_state: Tensor indexed as [batch, seq_len, hidden].
        attention_mask: Tensor indexed as [batch, seq_len]; zero marks padding.

    Returns:
        A [batch, hidden] tensor in the dtype of ``last_hidden_state``.
        Rows whose mask is all zero come back as zeros.
    """
    out_dtype = last_hidden_state.dtype
    # Accumulate in at least float32: summing a long sequence directly in
    # float16 can overflow to inf, and bfloat16 sums lose precision.
    acc_dtype = torch.promote_types(out_dtype, torch.float32)
    hidden = last_hidden_state.to(acc_dtype)
    mask = attention_mask.unsqueeze(-1).to(acc_dtype)  # [b, s, 1]
    summed = (hidden * mask).sum(dim=1)
    # clamp avoids a 0/0 division for fully masked rows
    counts = mask.sum(dim=1).clamp_min(1e-6)
    return (summed / counts).to(out_dtype)


def _maybe_normalize(x: torch.Tensor, normalize: bool) -> torch.Tensor:
    """Return ``x`` L2-normalized along the last axis, or ``x`` itself when disabled."""
    if not normalize:
        return x
    return torch.nn.functional.normalize(x, p=2, dim=-1)
# --- src/personalization/models/embedding/nemotron_8b.py ---


class LlamaEmbedNemotron8B(EmbeddingModel):
    """Embedding backend that mean-pools the last hidden state of a local checkpoint."""

    def __init__(self, model_path: str, dtype: torch.dtype, device_map: str = "auto") -> None:
        """Load the tokenizer and model stored at ``model_path``.

        NOTE(review): this loader passes ``dtype=`` whereas Qwen3Embedding8B
        passes ``torch_dtype=``; confirm which keyword the pinned transformers
        release expects.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=True, trust_remote_code=True
        )
        self.model = AutoModel.from_pretrained(
            model_path,
            dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "LlamaEmbedNemotron8B":
        """Build the model from ``cfg.embedding.nemotron``.

        Raises:
            ValueError: If the nemotron embedding section is missing.
        """
        if not cfg.embedding or not cfg.embedding.nemotron:
            raise ValueError("Embedding config for nemotron is missing")
        spec = cfg.embedding.nemotron
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        return cls(spec.local_path, dtype=dtype, device_map=device_map)

    @torch.inference_mode()
    def encode(
        self,
        texts: Sequence[str],
        batch_size: int = 8,
        max_length: int = 512,
        normalize: bool = True,
        return_tensor: bool = False,
    ) -> List[List[float]] | torch.Tensor:
        """Embed ``texts`` in batches and return one vector per text.

        Returns nested float lists by default, or a single tensor on the
        model's device when ``return_tensor`` is True. Empty input yields an
        empty list / a tensor with zero rows instead of failing in torch.cat.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        items = list(texts)
        if not items and not return_tensor:
            return []
        device = next(self.model.parameters()).device
        if not items:
            # assumes the HF config exposes hidden_size; falls back to width 0
            width = getattr(getattr(self.model, "config", None), "hidden_size", 0) or 0
            return torch.empty((0, width), device=device)

        chunks: List[torch.Tensor] = []
        for start in range(0, len(items), batch_size):
            enc = self.tokenizer(
                items[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
            pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"])  # type: ignore[attr-defined]
            chunks.append(_maybe_normalize(pooled, normalize))
        emb = torch.cat(chunks, dim=0)
        if return_tensor:
            return emb
        return emb.cpu().to(torch.float32).tolist()


# --- src/personalization/models/embedding/qwen3_8b.py ---


class Qwen3Embedding8B(EmbeddingModel):
    """Embedding backend that mean-pools the last hidden state of a local Qwen3 checkpoint."""

    def __init__(
        self,
        model_path: str,
        dtype: torch.dtype,
        device_map: str = "auto",
        trust_remote_code: bool = True,
    ) -> None:
        """Load the tokenizer and model stored at ``model_path``.

        A ``device_map`` of the form ``"cuda:N"`` loads without accelerate's
        device map and then moves the whole model to that GPU; any other value
        is forwarded to ``from_pretrained`` as the device map.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path, use_fast=True, trust_remote_code=trust_remote_code
        )

        pin_to_gpu = bool(device_map) and device_map.startswith("cuda:")
        self.model = AutoModel.from_pretrained(
            model_path,
            torch_dtype=dtype,
            # None: don't use accelerate's device_map when pinning to one GPU
            device_map=None if pin_to_gpu else device_map,
            trust_remote_code=trust_remote_code,
            low_cpu_mem_usage=True,
        )
        if pin_to_gpu:
            self.model = self.model.to(device_map)

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "Qwen3Embedding8B":
        """Build the model from ``cfg.embedding.qwen3``.

        Raises:
            ValueError: If the qwen3 embedding section is missing.
        """
        if not cfg.embedding or not cfg.embedding.qwen3:
            raise ValueError("Embedding config for qwen3 is missing")
        spec = cfg.embedding.qwen3
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        return cls(
            spec.local_path,
            dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )

    @torch.inference_mode()
    def encode(
        self,
        texts: Sequence[str],
        batch_size: int = 8,
        max_length: int = 512,
        normalize: bool = True,
        return_tensor: bool = False,
    ) -> List[List[float]] | torch.Tensor:
        """Embed ``texts`` in batches and return one vector per text.

        Returns nested float lists by default, or a single tensor on the
        model's device when ``return_tensor`` is True. Empty input yields an
        empty list / a tensor with zero rows instead of failing in torch.cat.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        if batch_size < 1:
            raise ValueError("batch_size must be a positive integer")
        items = list(texts)
        if not items and not return_tensor:
            return []
        device = next(self.model.parameters()).device
        if not items:
            # assumes the HF config exposes hidden_size; falls back to width 0
            width = getattr(getattr(self.model, "config", None), "hidden_size", 0) or 0
            return torch.empty((0, width), device=device)

        chunks: List[torch.Tensor] = []
        for start in range(0, len(items), batch_size):
            enc = self.tokenizer(
                items[start : start + batch_size],
                padding=True,
                truncation=True,
                max_length=max_length,
                return_tensors="pt",
            ).to(device)
            model_out = self.model(**enc, output_hidden_states=False, return_dict=True)
            pooled = _mean_pool(model_out.last_hidden_state, enc["attention_mask"])  # type: ignore[attr-defined]
            chunks.append(_maybe_normalize(pooled, normalize))
        emb = torch.cat(chunks, dim=0)
        if return_tensor:
            return emb
        return emb.cpu().to(torch.float32).tolist()


# --- src/personalization/models/llm/base.py ---


class ChatModel(Protocol):
    """Structural interface every chat backend implements."""

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """
        Generate an assistant response given conversation history and memory notes.

        Args:
            history: The conversation history ending with the current user turn.
            memory_notes: List of retrieved memory content strings.
            max_new_tokens: Max tokens to generate.
            temperature: Sampling temperature.
            top_p: Top-p sampling.
            top_k: Top-k sampling.

        Returns:
            The generated assistant response text.
        """
        ...


# --- src/personalization/models/llm/llama_instruct.py ---


class LlamaChatModel(ChatModel):
    """ChatModel backed by a local causal LM loaded through transformers."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # str kept for legacy callers; a torch.dtype is accepted too
        max_context_length: int = 8192,
    ):
        """Load tokenizer and model; ``device`` of the form ``"cuda:N"`` pins to one GPU."""
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        if device and device.startswith("cuda:"):
            # Load without a device map, then move to the requested GPU.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device)
        else:
            # Let accelerate place the weights.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.max_context_length = max_context_length
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render system prompt, memory notes and history into one prompt string."""
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        # Preferred path: the tokenizer's own chat template emits the exact
        # special tokens the instruct model was trained with.
        if hasattr(self.tokenizer, "apply_chat_template"):
            messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
            messages.extend({"role": turn.role, "content": turn.text} for turn in history)
            return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Generic fallback (last 8 turns only); it does not reproduce
        # Llama 3's exact special tokens.
        history_lines = [
            f"{'user' if turn.role == 'user' else 'assistant'}: {turn.text}"
            for turn in history[-8:]
        ]
        return (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant reply for ``history``; returns only the new text."""
        prompt = self._build_prompt(history, memory_notes)
        # NOTE(review): truncation uses the tokenizer's configured truncation
        # side; if that is "right", an over-long prompt loses its most recent
        # turns and the generation prompt. Confirm the intended side.
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling knobs are only meaningful (and only valid) when sampling;
            # passing temperature=0 to generate() is rejected or warned about.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Slice by token count: the decoded prompt text need not match the
        # prompt string once special tokens are skipped.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()


# --- src/personalization/models/llm/qwen_instruct.py ---


class QwenInstruct(ChatModel):
    """ChatModel backed by a local Qwen instruct checkpoint."""

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_context_length: int = 4096,
    ):
        """Load tokenizer and model; ``dtype`` may be a torch.dtype or its name."""
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            trust_remote_code=True,
        )
        # Accept a dtype name as well, matching LlamaChatModel.
        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )
        self.max_context_length = max_context_length
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    # Legacy helper for manual generation without template
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> str:
        """Complete a raw ``prompt`` and return only the newly generated text.

        The text is cut at the earliest occurrence of any non-empty string in
        ``stop``.
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        outputs = self.model.generate(**inputs, **gen_kwargs)
        # Return only the newly generated portion, not the echoed prompt
        input_len = inputs["input_ids"].shape[1]
        text = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)
        if stop:
            cut_points = [pos for pos in (text.find(s) for s in stop if s) if pos != -1]
            if cut_points:
                text = text[: min(cut_points)]
        return text

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render system prompt, memory notes and history via the tokenizer's chat template."""
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
        messages.extend({"role": turn.role, "content": turn.text} for turn in history)

        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate the assistant reply for ``history``; returns only the new text.

        Legacy callers that still pass ``{"role", "content"}`` dicts are
        converted to ChatTurn objects on the fly.
        """
        if history and isinstance(history[0], dict):
            # Auto-convert for safety during migration
            history = [
                ChatTurn(
                    user_id="unknown", session_id="unknown", turn_id=i,
                    role=h["role"], text=h["content"],
                )
                for i, h in enumerate(history)
            ]

        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_context_length,
        ).to(self.model.device)

        do_sample = temperature > 0
        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": do_sample,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if do_sample:
            # Sampling knobs are only meaningful (and only valid) when sampling.
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(**inputs, **gen_kwargs)

        # Drop the echoed prompt by token count.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    # Factory method for legacy config loading
    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
        """Build the model from ``cfg.llm``."""
        spec = cfg.llm
        # Resolve the configured dtype the same way the embedding factories do,
        # instead of forwarding the raw config value.
        dtype = choose_dtype(spec.dtype)
        # The constructor takes a device string; anything else falls back to "cuda".
        device = spec.device_map if isinstance(spec.device_map, str) else "cuda"
        return cls(
            model_path=spec.local_path,
            device=device,
            dtype=dtype,
        )


# --- src/personalization/models/llm/vllm_chat.py ---
# vLLM-based ChatModel implementation for high-throughput inference.
#
# This provides the same interface as LlamaChatModel but uses vLLM HTTP API
# for much faster inference (3000+ sessions/hr vs 20 sessions/hr).


class VLLMChatModel(ChatModel):
    """
    ChatModel implementation using vLLM HTTP API.

    This is a drop-in replacement for LlamaChatModel that uses vLLM
    for much faster inference.
    """

    def __init__(
        self,
        vllm_url: str = "http://localhost:8003/v1",
        model_name: Optional[str] = None,
        max_context_length: int = 8192,
        timeout: int = 120,
    ):
        """Store connection settings; asks the server for a model name when none is given."""
        self.vllm_url = vllm_url.rstrip('/')
        self.model_name = model_name
        self.max_context_length = max_context_length
        self.timeout = timeout

        if self.model_name is None:
            self._discover_model()

    def _discover_model(self):
        """Discover the model name from the vLLM server.

        Polls ``/models`` with capped exponential backoff and falls back to
        the name ``"default"`` when nothing could be discovered.
        """
        max_retries = 30
        for attempt in range(max_retries):
            try:
                response = requests.get(f"{self.vllm_url}/models", timeout=10)
                response.raise_for_status()
                data = response.json().get("data")
                if data:
                    self.model_name = data[0]["id"]
                    return
            except Exception:
                # Best-effort probe: any failure just means "try again".
                pass
            # Back off after every unsuccessful attempt, including an empty
            # model list, so a starting server is not hammered in a tight loop.
            if attempt < max_retries - 1:
                time.sleep(min(2 ** attempt * 0.5, 10))

        # Fallback
        self.model_name = "default"
        print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'")

    def health_check(self) -> bool:
        """Check if the vLLM server is healthy."""
        base_url = self.vllm_url
        # Strip only a trailing "/v1"; str.replace would also mangle a "/v1"
        # appearing elsewhere in the URL.
        if base_url.endswith("/v1"):
            base_url = base_url[: -len("/v1")]
        try:
            response = requests.get(f"{base_url}/health", timeout=5)
        except requests.RequestException:
            return False
        return response.status_code == 200

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using character-based heuristic.

        For Llama models, ~4 characters per token is a reasonable estimate.
        We use 3.5 to be conservative (slightly overestimate tokens).
        """
        return int(len(text) / 3.5)

    @staticmethod
    def _system_prompt(memory_notes: List[str]) -> str:
        """Return the CollaborativeAgents-style system prompt, with notes when present."""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            return (
                "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
                "# User Preferences\n"
                "The user has a set of preferences for how you should behave. If you do not follow these preferences, "
                "the user will be unable to learn from your response and you will need to adjust your response to adhere "
                "to these preferences (so it is best to follow them initially).\n"
                "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
                f"{bullet}\n\n"
                "# Conversation Guidelines:\n"
                "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
                "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
                "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
            )
        # Vanilla mode - no preferences
        return (
            "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
            "# Conversation Guidelines:\n"
            "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
            "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
            "- Your goal is to help the user solve their problem. Do your best to help them.\n"
        )

    def _build_messages(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
    ) -> List[dict]:
        """Build messages list for chat completion API with auto-truncation.

        If the context exceeds max_context_length, older conversation turns
        are removed to keep only the most recent context that fits. The most
        recent turn is always kept, even when it alone exceeds the estimate,
        so the request never consists of the system prompt only.
        """
        system_content = self._system_prompt(memory_notes)
        system_message = {"role": "system", "content": system_content}

        # Reserve space for: system prompt + max_new_tokens + 100-token safety margin
        system_tokens = self._estimate_tokens(system_content)
        available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100

        conversation_messages = [{"role": turn.role, "content": turn.text} for turn in history]
        total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages)

        if total_conv_tokens > available_tokens:
            kept: List[dict] = []
            current_tokens = 0
            # Walk from the most recent message backwards until the budget is used up.
            for msg in reversed(conversation_messages):
                msg_tokens = self._estimate_tokens(msg["content"])
                if kept and current_tokens + msg_tokens > available_tokens:
                    break
                kept.append(msg)
                current_tokens += msg_tokens
            kept.reverse()

            conversation_messages = kept
            if len(kept) < len(history):
                print(f"[VLLMChatModel] Truncated context: kept {len(kept)}/{len(history)} turns "
                      f"({current_tokens}/{total_conv_tokens} estimated tokens)")

        return [system_message] + conversation_messages

    def build_messages(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
    ) -> List[dict]:
        """Public method to build messages without calling the API.

        Used for batch processing where messages are collected first,
        then sent in batch to vLLM for concurrent processing.
        """
        return self._build_messages(history, memory_notes, max_new_tokens)

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a response using vLLM HTTP API.

        Timeouts and connection errors are retried with exponential backoff.
        A 400 response mentioning ``max_tokens`` halves the requested
        completion length (never below 64) and retries.

        Raises:
            RuntimeError: On a non-retryable HTTP error or when retries run out.
        """
        messages = self._build_messages(history, memory_notes, max_new_tokens)

        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        if top_k is not None:
            # vLLM accepts top_k as an extra sampling parameter.
            payload["top_k"] = top_k

        max_retries = 5
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{self.vllm_url}/chat/completions",
                    json=payload,
                    timeout=self.timeout,
                )
            except requests.exceptions.Timeout as e:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise RuntimeError("vLLM request timeout") from e
            except requests.exceptions.ConnectionError as e:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise RuntimeError(f"vLLM connection error: {e}") from e

            if response.status_code == 200:
                result = response.json()
                return result["choices"][0]["message"]["content"]

            if response.status_code == 400:
                error_text = response.text
                current_max = payload["max_tokens"]
                # Context-length error: shrink from the value last sent so
                # successive retries keep getting smaller.
                if "max_tokens" in error_text and current_max > 64:
                    payload["max_tokens"] = max(64, current_max // 2)
                    continue
                raise RuntimeError(f"vLLM error: {error_text[:200]}")

            raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}")

        raise RuntimeError("Max retries exceeded for vLLM request")
# --- src/personalization/models/preference_extractor/base.py ---


class PreferenceExtractorBase(ABC):
    """Abstract interface for components that turn chat turns into preferences."""

    @abstractmethod
    def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList:
        """
        Extract preferences from a window of chat turns (history + current query).
        """
        raise NotImplementedError


# Alias for backward compatibility if needed,
# though specific extractors should inherit from PreferenceExtractorBase now.
PreferenceExtractor = PreferenceExtractorBase


# --- src/personalization/models/preference_extractor/gpt4o_extractor.py ---


class GPT4OExtractor(PreferenceExtractor):
    """Preference extractor that asks an OpenAI chat model for a JSON answer."""

    def __init__(self, api_key: str, model: str = "gpt-4o") -> None:
        """Create the API client and load the system prompt.

        The prompt is read from ``fine_tuning_prompt_template.txt`` in the
        current working directory when that file exists; otherwise a short
        built-in prompt is used.
        """
        self.client = OpenAI(api_key=api_key)
        self.model = model

        template_path = "fine_tuning_prompt_template.txt"
        if os.path.exists(template_path):
            with open(template_path, "r", encoding="utf-8") as f:
                self.system_prompt = f.read()
        else:
            # Fallback simple prompt if file missing
            self.system_prompt = (
                "You are a preference extraction assistant. "
                "Extract user preferences from the query into a JSON object."
            )

    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "GPT4OExtractor":
        """Build the extractor using the ``OPENAI_API_KEY`` environment variable.

        Raises:
            ValueError: If the environment variable is unset or empty.
        """
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY environment variable not set")
        return cls(api_key=api_key)

    def build_preference_prompt(self, query: str) -> str:
        """Return the system prompt loaded in ``__init__`` (``query`` is not used)."""
        return self.system_prompt

    def _complete(self, user_text: str) -> str | None:
        """Send one system+user exchange and return the raw JSON text of the reply."""
        response = self.client.chat.completions.create(
            model=self.model,
            messages=[
                {"role": "system", "content": self.system_prompt},
                {"role": "user", "content": user_text},
            ],
            response_format={"type": "json_object"},
            temperature=0.0,
        )
        return response.choices[0].message.content

    @staticmethod
    def _latest_turn(turn):
        """Return the turn to analyse: the last element of a list/tuple, else ``turn`` itself.

        An empty sequence yields None.
        """
        if isinstance(turn, (list, tuple)):
            return turn[-1] if turn else None
        return turn

    def extract_preferences(self, query: str) -> Dict[str, Any]:
        """Extract preferences for ``query`` as a raw dict.

        Any failure (API error, invalid JSON, empty reply) is reported on
        stdout and results in ``{"preferences": []}``.
        """
        try:
            content = self._complete(query)
            if content:
                return json.loads(content)
        except Exception as e:
            print(f"Error calling GPT-4o: {e}")
        return {"preferences": []}

    def extract_turn(self, turn: ChatTurn | List[ChatTurn]) -> PreferenceList:
        """Extract preferences from one user turn.

        Accepts a single turn (the historical call style) or a window of turns
        as declared by the base class; for a window the last turn is analysed.
        Non-user turns, empty windows, empty replies and any error produce an
        empty PreferenceList.
        """
        current = self._latest_turn(turn)
        if current is None or current.role != "user":
            return PreferenceList(preferences=[])

        try:
            content = self._complete(current.text)
            if not content:
                return PreferenceList(preferences=[])
            # The prompt might return {"preferences": [...]}, validate it
            return PreferenceList.model_validate(json.loads(content))
        except Exception as e:
            print(f"Error calling GPT-4o: {e}")
            return PreferenceList(preferences=[])

    def extract_session(self, turns: List[ChatTurn]) -> List[PreferenceList]:
        """Run ``extract_turn`` on every turn and return the results in order."""
        return [self.extract_turn(turn) for turn in turns]


# --- src/personalization/models/preference_extractor/llm_extractor.py ---


class PreferenceExtractorLLM(PreferenceExtractorBase):
    """Preference extractor backed by a local causal LM."""

    def __init__(
        self,
        model_path: str,
        prompt_template_path: str = "fine_tuning_prompt_template.txt",
        device_map: str = "auto",
        dtype: torch.dtype = torch.bfloat16,
        max_new_tokens: int = 512,
    ) -> None:
        """Load tokenizer, model and the system-prompt template.

        When ``prompt_template_path`` does not exist, a warning is printed and
        a one-line fallback prompt is used.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=device_map,
            trust_remote_code=True,
        )
        self.max_new_tokens = max_new_tokens
        # Generation passes pad_token_id; make sure one is defined.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

        if os.path.exists(prompt_template_path):
            with open(prompt_template_path, "r", encoding="utf-8") as f:
                self.prompt_template = f.read()
        else:
            print(f"Warning: Prompt template not found at {prompt_template_path}. Using fallback.")
            self.prompt_template = "Extract user preferences from the following conversation."
@classmethod
def from_config(cls, cfg: LocalModelsConfig, name: str = "qwen3_0_6b_sft") -> "PreferenceExtractorLLM":
    """Build the extractor from the local-models configuration.

    Previously this method returned ``None``, so callers only failed later
    with an unrelated ``AttributeError``.

    ``cfg.preference_extractor`` may be a single model spec or a dict of
    specs keyed by extractor name; in the dict case *name* selects the entry.

    NOTE(review): assumes the spec exposes ``local_path``, ``dtype`` and
    ``device_map``, as the other extractor's ``from_config`` reads them --
    confirm against settings.py once the multi-extractor layout is settled.

    Raises:
        ValueError: If no spec is configured, or *name* is not among the
            configured extractors.
    """
    spec = getattr(cfg, "preference_extractor", None)
    if isinstance(spec, dict):
        if name not in spec:
            raise ValueError(f"Preference extractor config '{name}' is missing")
        spec = spec[name]
    if spec is None:
        raise ValueError("Preference extractor config is missing")

    return cls(
        model_path=spec.local_path,
        device_map=choose_device_map(spec.device_map),
        dtype=choose_dtype(spec.dtype),
    )
+ # But LLaMA-Factory SFT usually expects: + # System: <template> + # User: <input> + # Assistant: <output> + + # We should pack the conversation history into the User message? + # Or if we trained with multi-turn chat format? + # Assuming "Input" column in dataset was the conversation history. + + history_texts = [] + for t in window: + role = "User" if t.role == "user" else "Assistant" + history_texts.append(f"{role}: {t.text}") + + conversation_text = "\n".join(history_texts) + + # Construct the User input + # We append a trigger instruction if it wasn't part of the training input implicitly. + # But based on your template, the User Input Example was just the query "I am a Python developer..." + # So likely we should just feed the conversation text as the user message. + + messages.append({"role": "user", "content": conversation_text}) + + # Apply chat template + prompt = self.tokenizer.apply_chat_template( + messages, + tokenize=False, + add_generation_prompt=True + ) + + return prompt + + def _generate(self, prompt: str) -> str: + inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device) + with torch.no_grad(): + outputs = self.model.generate( + **inputs, + max_new_tokens=self.max_new_tokens, + do_sample=False, + temperature=0.0, + eos_token_id=self.tokenizer.eos_token_id, + pad_token_id=self.tokenizer.pad_token_id, + ) + full_text = self.tokenizer.decode(outputs[0], skip_special_tokens=True) + return full_text[len(prompt):] + + def _parse_preferences(self, raw_output: str) -> PreferenceList: + start = raw_output.find("{") + end = raw_output.rfind("}") + + if start == -1 or end == -1 or end <= start: + return PreferenceList(preferences=[]) + + json_str = raw_output[start:end+1] + try: + data = json.loads(json_str) + return PreferenceList.model_validate(data) + except Exception: + return PreferenceList(preferences=[]) + + def extract_turn(self, turns: List[ChatTurn]) -> PreferenceList: + prompt = self._build_prompt(turns) + raw_output = 
def build_preference_prompt(self, query: str) -> str:
    """Legacy helper: render the extraction prompt for a single raw query.

    The query is wrapped in a placeholder user turn (dummy ids, turn 0) and
    handed to ``_build_prompt`` as a one-turn conversation.
    """
    placeholder_fields = {
        "user_id": "dummy",
        "session_id": "dummy",
        "turn_id": 0,
        "role": "user",
        "text": query,
    }
    single_turn_history = [ChatTurn(**placeholder_fields)]
    return self._build_prompt(single_turn_history)
@torch.inference_mode()
def extract_preferences(self, query: str) -> Dict[str, Any]:
    """Extract preferences from *query* with the fine-tuned model.

    The prompt is built with ``build_preference_prompt``, decoded greedily
    (up to 512 new tokens), and only the newly generated tokens are decoded.
    The text is parsed as JSON and validated with ``PreferenceList``. If the
    raw text does not parse or validate, the outermost ``{...}`` substring is
    tried as a second candidate.

    Set ``PREF_DEBUG=1`` in the environment to print the raw model output.

    Returns:
        The validated preference list as a plain dict, or
        ``{"preferences": []}`` when no candidate parses and validates.
    """
    prompt = self.build_preference_prompt(query)
    inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

    outputs = self.model.generate(
        **inputs,
        do_sample=False,  # Deterministic greedy decoding
        max_new_tokens=512,  # Allow enough space for JSON
        pad_token_id=self.tokenizer.pad_token_id,
        eos_token_id=self.tokenizer.eos_token_id,
    )

    # Keep only the tokens produced after the prompt.
    input_len = inputs["input_ids"].shape[1]
    text = self.tokenizer.decode(outputs[0][input_len:], skip_special_tokens=True)

    if os.getenv("PREF_DEBUG") == "1":
        print(f"[debug][extractor] Raw output: {text}")

    # Candidate 1: the raw text. Candidate 2: the outermost {...} block, in
    # case the model wrapped the JSON in extra text.
    candidates = [text]
    json_blob = self._extract_json_substring(text)
    if json_blob and json_blob != text:
        candidates.append(json_blob)

    for candidate in candidates:
        try:
            data = json.loads(candidate)
            return PreferenceList.model_validate(data).model_dump()
        except Exception:
            # Parse/validation failure is expected here; try the next
            # candidate. `Exception` (not a bare except) so that
            # KeyboardInterrupt/SystemExit still propagate.
            continue

    return {"preferences": []}
class Reranker(Protocol):
    """Structural interface for anything that can rank documents for a query."""

    def score(self, query: str, docs: List[str], **kwargs) -> List[float]:
        """Return one relevance score per entry of *docs* for *query*.

        Scores are aligned with *docs* (same length, same order); a larger
        value means the document is more relevant to the query.
        """
        ...
class BGEReranker(Reranker):
    """
    Cross-encoder reranker backed by a BGE sequence-classification model.

    Considerably smaller than Qwen3-Reranker-8B:
    - bge-reranker-base: 278M params
    - bge-reranker-large: 560M params
    """

    def __init__(
        self,
        model_path: str = "BAAI/bge-reranker-base",
        device_map: str = "auto",
        dtype: torch.dtype = torch.float16
    ):
        """Load tokenizer and model, placing the model according to *device_map*.

        A ``"cuda:N"`` value loads without a device map and then moves the
        whole model to that device; any other value is passed straight to
        ``from_pretrained``.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        pin_to_single_gpu = bool(device_map) and device_map.startswith("cuda:")
        self.model = AutoModelForSequenceClassification.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=None if pin_to_single_gpu else device_map,
        )

        if pin_to_single_gpu:
            self.model = self.model.to(device_map)
            self.device = device_map
        else:
            # Inputs are sent to wherever the first parameter was placed.
            self.device = next(self.model.parameters()).device

        self.model.eval()

    def score(
        self,
        query: str,
        docs: List[str],
        batch_size: int = 32,
        **kwargs,
    ) -> List[float]:
        """
        Score every document against the query with the cross-encoder.

        Args:
            query: The query string
            docs: Candidate documents to score
            batch_size: Number of query/document pairs per forward pass

        Returns:
            One relevance score per document (higher = more relevant)
        """
        if not docs:
            return []

        query_doc_pairs = [[query, candidate] for candidate in docs]
        collected = []

        with torch.no_grad():
            for offset in range(0, len(query_doc_pairs), batch_size):
                encoded = self.tokenizer(
                    query_doc_pairs[offset:offset + batch_size],
                    padding=True,
                    truncation=True,
                    max_length=512,
                    return_tensors="pt"
                )
                encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

                logits = self.model(**encoded).logits
                chunk_scores = logits.squeeze(-1).float().cpu().tolist()

                # tolist() yields a bare float for a 0-d tensor; wrap it.
                if isinstance(chunk_scores, float):
                    collected.append(chunk_scores)
                else:
                    collected.extend(chunk_scores)

        return collected
@torch.inference_mode()
def score(self, query: str, docs: List[str], batch_size: int = 8, **kwargs) -> List[float]:
    """Score each document by the log-probability that the next token is "yes".

    For every document a yes/no prompt is built with ``_build_prompt``. The
    prompts are run through the causal LM in batches, and the logits at each
    prompt's last real (attended) token are turned into a log-softmax over
    the vocabulary.

    The last real token is located from the attention mask instead of using
    position ``-1``. With right-side padding, ``logits[:, -1]`` belongs to a
    pad token for every prompt shorter than the longest one in the batch,
    which silently corrupts those scores.

    Args:
        query: The user query.
        docs: Candidate documents; an empty list yields ``[]``.
        batch_size: Number of prompts per forward pass.

    Returns:
        One log-probability per document, in the order of *docs*.
    """
    scores: List[float] = []
    for start in range(0, len(docs), batch_size):
        prompts = [self._build_prompt(query, doc) for doc in docs[start : start + batch_size]]

        inputs = self.tokenizer(
            prompts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512
        ).to(self.model.device)

        # logits: [batch, seq_len, vocab_size]
        logits = self.model(**inputs).logits

        # Index of the last attended token in every row. Flipping the mask
        # and taking the first 1 works for both left- and right-padded
        # batches (argmax returns the first maximal index).
        attention_mask = inputs["attention_mask"].long()
        seq_len = attention_mask.shape[1]
        last_index = seq_len - 1 - attention_mask.flip(dims=[1]).argmax(dim=1)
        row_index = torch.arange(logits.shape[0], device=logits.device)
        last_token_logits = logits[row_index, last_index.to(logits.device)]

        # Normalise in float32 so half-precision logits do not lose accuracy.
        log_probs = torch.log_softmax(last_token_logits.float(), dim=-1)
        scores.extend(log_probs[:, self.yes_token_id].cpu().tolist())

    return scores
