summaryrefslogtreecommitdiff
path: root/src/personalization/models/llm
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
commitf918fc90b8d71d1287590b016d926268be573de0 (patch)
treed9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/llm
parent680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff)
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for: - Qwen3/Nemotron embedding models - BGE/Qwen3 rerankers - vLLM/Llama/Qwen LLM backends - GPT-4o/LLM-based preference extractors Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/llm')
-rw-r--r--src/personalization/models/llm/__init__.py4
-rw-r--r--src/personalization/models/llm/base.py29
-rw-r--r--src/personalization/models/llm/llama_instruct.py129
-rw-r--r--src/personalization/models/llm/prompt_builder.py0
-rw-r--r--src/personalization/models/llm/qwen_instruct.py164
-rw-r--r--src/personalization/models/llm/vllm_chat.py217
6 files changed, 543 insertions, 0 deletions
diff --git a/src/personalization/models/llm/__init__.py b/src/personalization/models/llm/__init__.py
new file mode 100644
index 0000000..3f1af81
--- /dev/null
+++ b/src/personalization/models/llm/__init__.py
@@ -0,0 +1,4 @@
+from .qwen_instruct import QwenInstruct
+
+__all__ = ["QwenInstruct"]
+
diff --git a/src/personalization/models/llm/base.py b/src/personalization/models/llm/base.py
new file mode 100644
index 0000000..72b6ca8
--- /dev/null
+++ b/src/personalization/models/llm/base.py
@@ -0,0 +1,29 @@
+from typing import List, Protocol, Optional
+from personalization.types import ChatTurn
+
class ChatModel(Protocol):
    """Structural interface for chat-completion backends.

    Any class that provides a matching ``answer`` method satisfies this
    protocol (duck typing; no inheritance required). Implementations in
    this package include local transformers models and a vLLM HTTP client.
    """

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """
        Generate an assistant response given conversation history and memory notes.

        Args:
            history: The conversation history ending with the current user turn.
            memory_notes: List of retrieved memory content strings.
            max_new_tokens: Max tokens to generate.
            temperature: Sampling temperature.
            top_p: Top-p (nucleus) sampling threshold.
            top_k: Top-k sampling cutoff; None disables top-k filtering.

        Returns:
            The generated assistant response text.
        """
        ...
+
diff --git a/src/personalization/models/llm/llama_instruct.py b/src/personalization/models/llm/llama_instruct.py
new file mode 100644
index 0000000..bdf0dff
--- /dev/null
+++ b/src/personalization/models/llm/llama_instruct.py
@@ -0,0 +1,129 @@
+from typing import List, Optional
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
class LlamaChatModel(ChatModel):
    """Llama-Instruct chat model served locally via HuggingFace transformers.

    Implements the ChatModel protocol: a prompt is built from the
    conversation history plus retrieved memory notes, and a single
    assistant reply is generated.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: str = "bfloat16",  # Keep type hint as str for legacy, but handle torch.dtype
        max_context_length: int = 8192,
    ):
        """Load the tokenizer and model weights.

        Args:
            model_path: Local path or hub id of the Llama checkpoint.
            device: "cuda:N" for an explicit GPU, otherwise any device_map
                value transformers accepts (e.g. "cuda", "auto").
            dtype: torch dtype name (e.g. "bfloat16") or a torch.dtype object.
            max_context_length: Token budget used to truncate the prompt.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)

        # Accept both a dtype name string and an actual torch.dtype object.
        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

        if device and device.startswith("cuda:"):
            # Explicit single-GPU target: load to CPU first, then move, so
            # accelerate's device mapping cannot scatter the weights elsewhere.
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=None,
                low_cpu_mem_usage=True,
            )
            self.model = self.model.to(device)
        else:
            # Delegate placement to transformers/accelerate ("cuda", "auto", ...).
            self.model = AutoModelForCausalLM.from_pretrained(
                model_path,
                torch_dtype=torch_dtype,
                device_map=device,
            )

        self.max_context_length = max_context_length
        # Llama tokenizers often ship without a pad token; generate() needs one.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render history + memory notes into a single prompt string.

        Prefers the tokenizer's chat template (required for the exact Llama-3
        special tokens); falls back to a generic "role: text" transcript when
        no template API is available.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        if hasattr(self.tokenizer, "apply_chat_template"):
            messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]
            for turn in history:
                messages.append({"role": turn.role, "content": turn.text})
            return self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)

        # Fallback: generic transcript over the last 8 turns.
        # NOTE(review): Llama-3 instruct checkpoints need the exact
        # chat-template tokens; this path is best-effort only.
        history_lines = []
        for turn in history[-8:]:
            role_tag = "user" if turn.role == "user" else "assistant"
            history_lines.append(f"{role_tag}: {turn.text}")

        return (
            "System: You are a helpful assistant.\n"
            + memory_block
            + "\n".join(history_lines)
            + "\nassistant:"
        )

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate one assistant reply.

        Args:
            history: Conversation ending with the current user turn.
            memory_notes: Retrieved memory strings injected into the system prompt.
            max_new_tokens: Generation budget.
            temperature: Sampling temperature; <= 0 selects greedy decoding.
            top_p: Nucleus sampling threshold (only used when sampling).
            top_k: Optional top-k cutoff (only used when sampling).

        Returns:
            The generated assistant text with the prompt stripped.
        """
        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True,
                                max_length=self.max_context_length).to(self.model.device)

        gen_kwargs = {"max_new_tokens": max_new_tokens}
        if temperature > 0:
            gen_kwargs["do_sample"] = True
            gen_kwargs["temperature"] = temperature
            gen_kwargs["top_p"] = top_p
            if top_k is not None:
                gen_kwargs["top_k"] = top_k
        else:
            # Greedy decoding: omit the sampling knobs entirely so transformers
            # does not warn about unused temperature/top_p.
            gen_kwargs["do_sample"] = False

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                eos_token_id=self.tokenizer.eos_token_id,
                **gen_kwargs,
            )

        # Slice off the prompt by token position rather than string matching:
        # the decoded prompt may not match the input text once special tokens
        # are skipped. (Previously the full sequence was also decoded into an
        # unused variable; that dead work is removed.)
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()
+
diff --git a/src/personalization/models/llm/prompt_builder.py b/src/personalization/models/llm/prompt_builder.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/src/personalization/models/llm/prompt_builder.py
diff --git a/src/personalization/models/llm/qwen_instruct.py b/src/personalization/models/llm/qwen_instruct.py
new file mode 100644
index 0000000..cf2047d
--- /dev/null
+++ b/src/personalization/models/llm/qwen_instruct.py
@@ -0,0 +1,164 @@
+from typing import List, Optional, Dict, Any
+from transformers import AutoModelForCausalLM, AutoTokenizer
+import torch
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+from personalization.config.settings import LocalModelsConfig
+from personalization.config.registry import choose_dtype, choose_device_map
+
class QwenInstruct(ChatModel):
    """Qwen-Instruct chat model served locally via HuggingFace transformers.

    Implements the ChatModel protocol and additionally exposes a raw
    ``generate`` helper for legacy prompt-string callers.
    """

    def __init__(
        self,
        model_path: str,
        device: str = "cuda",
        dtype: torch.dtype = torch.bfloat16,
        max_context_length: int = 4096,
    ):
        """Load the tokenizer and model weights.

        Args:
            model_path: Local path or hub id of the Qwen checkpoint.
            device: device_map value passed to transformers ("cuda", "auto", ...).
            dtype: a torch.dtype, or a dtype name string ("bfloat16") which is
                converted here. Accepting both keeps config-driven callers safe.
            max_context_length: Token budget used to truncate the prompt.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            use_fast=True,
            trust_remote_code=True,
        )
        # Be tolerant of config-supplied dtype names as well as torch.dtype
        # objects (mirrors LlamaChatModel's handling).
        torch_dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
        self.model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=torch_dtype,
            device_map=device,
            trust_remote_code=True,
        )
        self.max_context_length = max_context_length
        # Qwen tokenizers may lack a pad token; generation needs one.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    # Legacy helper for manual generation without a chat template.
    @torch.inference_mode()
    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 256,
        temperature: float = 0.7,
        top_p: float = 0.9,
        stop: Optional[List[str]] = None,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a completion for a raw prompt string.

        Args:
            prompt: The full prompt text.
            max_new_tokens: Generation budget.
            temperature: Sampling temperature; <= 0 selects greedy decoding.
            top_p: Nucleus sampling threshold.
            stop: Optional stop strings; output is cut at the first match.
            top_k: Optional top-k cutoff.

        Returns:
            Only the newly generated text (prompt not echoed).
        """
        inputs = self.tokenizer(prompt, return_tensors="pt").to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "top_p": top_p,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if top_k is not None:
            gen_kwargs["top_k"] = top_k

        outputs = self.model.generate(
            **inputs,
            **gen_kwargs
        )
        # Return only the newly generated portion, not the echoed prompt.
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        text = self.tokenizer.decode(gen_ids, skip_special_tokens=True)
        if stop:
            for s in stop:
                if s in text:
                    text = text.split(s)[0]
                    break
        return text

    def _build_prompt(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
    ) -> str:
        """Render history + memory notes into a ChatML prompt string.

        Qwen tokenizers provide ``apply_chat_template``; the ChatTurn history
        is mapped to an OpenAI-style messages list first.
        """
        memory_block = ""
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            memory_block = (
                "Here are the user's preferences and memories:\n"
                f"{bullet}\n\n"
            )

        messages = [{"role": "system", "content": "You are a helpful assistant.\n" + memory_block}]

        for turn in history:
            messages.append({"role": turn.role, "content": turn.text})

        return self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate one assistant reply (ChatModel protocol entry point).

        Args:
            history: Conversation ending with the current user turn. Legacy
                dict-shaped turns ({"role", "content"}) are auto-converted.
            memory_notes: Retrieved memory strings injected into the system prompt.
            max_new_tokens: Generation budget.
            temperature: Sampling temperature; <= 0 selects greedy decoding.
            top_p: Nucleus sampling threshold.
            top_k: Optional top-k cutoff.

        Returns:
            The generated assistant text with the prompt stripped.
        """
        # Migration shim: accept legacy list-of-dicts histories until all
        # callers pass ChatTurn objects.
        if history and isinstance(history[0], dict):
            history = [ChatTurn(
                user_id="unknown", session_id="unknown", turn_id=i,
                role=h["role"], text=h["content"]
            ) for i, h in enumerate(history)]

        prompt = self._build_prompt(history, memory_notes)
        inputs = self.tokenizer(prompt, return_tensors="pt", truncation=True,
                                max_length=self.max_context_length).to(self.model.device)

        gen_kwargs = {
            "max_new_tokens": max_new_tokens,
            "do_sample": temperature > 0,
            "temperature": temperature,
            "top_p": top_p,
            "pad_token_id": self.tokenizer.pad_token_id,
            "eos_token_id": self.tokenizer.eos_token_id,
        }
        if top_k is not None:
            gen_kwargs["top_k"] = top_k

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                **gen_kwargs,
            )

        # Slice off the prompt by token position; decoding the full sequence
        # and string-matching the prompt is unreliable once special tokens are
        # skipped. (An unused full-sequence decode was removed here.)
        input_len = inputs["input_ids"].shape[1]
        gen_ids = outputs[0][input_len:]
        return self.tokenizer.decode(gen_ids, skip_special_tokens=True).strip()

    # Factory method for legacy config loading.
    @classmethod
    def from_config(cls, cfg: LocalModelsConfig) -> "QwenInstruct":
        """Build a QwenInstruct from a LocalModelsConfig.

        BUG FIX: the converted dtype/device_map from choose_dtype /
        choose_device_map were previously computed and then discarded, and the
        raw config string was passed where a torch.dtype was expected. The
        converted values are now actually used.
        """
        spec = cfg.llm
        dtype = choose_dtype(spec.dtype)
        device_map = choose_device_map(spec.device_map)
        # The constructor takes a device_map-style string; fall back to "cuda"
        # when the helper returns something non-string.
        device = device_map if isinstance(device_map, str) else "cuda"

        return cls(
            model_path=spec.local_path,
            device=device,
            dtype=dtype,
        )
diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py
new file mode 100644
index 0000000..b5c3a05
--- /dev/null
+++ b/src/personalization/models/llm/vllm_chat.py
@@ -0,0 +1,217 @@
+"""
+vLLM-based ChatModel implementation for high-throughput inference.
+
+This provides the same interface as LlamaChatModel but uses vLLM HTTP API
+for much faster inference (3000+ sessions/hr vs 20 sessions/hr).
+"""
+
+from typing import List, Optional
+import time
+import requests
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
+
class VLLMChatModel(ChatModel):
    """
    ChatModel implementation using the vLLM HTTP API (OpenAI-compatible).

    This is a drop-in replacement for LlamaChatModel that uses vLLM
    for much faster inference.
    """

    def __init__(
        self,
        vllm_url: str = "http://localhost:8003/v1",
        model_name: Optional[str] = None,
        max_context_length: int = 8192,
        timeout: int = 120,
    ):
        """
        Args:
            vllm_url: Base URL of the vLLM OpenAI-compatible endpoint
                (trailing slash stripped).
            model_name: Served model id; auto-discovered from /models if None.
            max_context_length: Token budget for prompt truncation (estimated).
            timeout: Per-request timeout in seconds.
        """
        self.vllm_url = vllm_url.rstrip('/')
        self.model_name = model_name
        self.max_context_length = max_context_length
        self.timeout = timeout

        # Discover model name if not provided.
        if self.model_name is None:
            self._discover_model()

    def _discover_model(self):
        """Discover the model name from the vLLM server.

        Retries with capped exponential backoff so callers can construct this
        client before the server has finished starting up.
        """
        max_retries = 30
        for attempt in range(max_retries):
            try:
                response = requests.get(f"{self.vllm_url}/models", timeout=10)
                response.raise_for_status()
                models = response.json()
                if models.get("data") and len(models["data"]) > 0:
                    self.model_name = models["data"][0]["id"]
                    return
            except Exception:
                # Server not up yet (connection refused, timeout, bad JSON):
                # back off and retry; fall through to the fallback afterwards.
                if attempt < max_retries - 1:
                    wait_time = min(2 ** attempt * 0.5, 10)
                    time.sleep(wait_time)

        # Fallback: let the server resolve "default" (or fail loudly later).
        self.model_name = "default"
        print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'")

    def health_check(self) -> bool:
        """Check if the vLLM server is healthy (GET /health returns 200)."""
        try:
            response = requests.get(f"{self.vllm_url.replace('/v1', '')}/health", timeout=5)
            return response.status_code == 200
        except Exception:
            # Narrowed from a bare except: a KeyboardInterrupt/SystemExit
            # should not be swallowed by a health probe.
            return False

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using character-based heuristic.

        For Llama models, ~4 characters per token is a reasonable estimate.
        We use 3.5 to be conservative (slightly overestimate tokens).
        """
        return int(len(text) / 3.5)

    def _build_messages(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
    ) -> List[dict]:
        """Build messages list for chat completion API with auto-truncation.

        If the context exceeds max_context_length, older conversation turns
        are removed to keep only the most recent context that fits.
        """
        # Use CollaborativeAgents-style system prompt.
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            system_content = (
                "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
                "# User Preferences\n"
                "The user has a set of preferences for how you should behave. If you do not follow these preferences, "
                "the user will be unable to learn from your response and you will need to adjust your response to adhere "
                "to these preferences (so it is best to follow them initially).\n"
                "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
                f"{bullet}\n\n"
                "# Conversation Guidelines:\n"
                "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
                "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
                "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
            )
        else:
            # Vanilla mode - no preferences.
            system_content = (
                "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
                "# Conversation Guidelines:\n"
                "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
                "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
                "- Your goal is to help the user solve their problem. Do your best to help them.\n"
            )
        system_message = {"role": "system", "content": system_content}

        # Calculate available tokens for conversation history.
        # Reserve space for: system prompt + max_new_tokens + safety margin.
        system_tokens = self._estimate_tokens(system_content)
        available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100  # 100 token safety margin

        # Build conversation messages from history.
        conversation_messages = []
        for turn in history:
            conversation_messages.append({"role": turn.role, "content": turn.text})

        # Check if truncation is needed.
        total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages)

        if total_conv_tokens > available_tokens:
            # Truncate from the beginning (keep recent messages).
            truncated_messages = []
            current_tokens = 0

            # Iterate from most recent to oldest.
            for msg in reversed(conversation_messages):
                msg_tokens = self._estimate_tokens(msg["content"])
                if current_tokens + msg_tokens <= available_tokens:
                    truncated_messages.insert(0, msg)
                    current_tokens += msg_tokens
                else:
                    # Stop adding older messages.
                    break

            conversation_messages = truncated_messages
            if len(truncated_messages) < len(history):
                print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns "
                      f"({current_tokens}/{total_conv_tokens} estimated tokens)")

        messages = [system_message] + conversation_messages
        return messages

    def build_messages(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
    ) -> List[dict]:
        """Public method to build messages without calling the API.

        Used for batch processing where messages are collected first,
        then sent in batch to vLLM for concurrent processing.
        """
        return self._build_messages(history, memory_notes, max_new_tokens)

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a response using the vLLM HTTP API.

        Retries transient failures (timeout, connection error) with
        exponential backoff, and halves the generation budget on
        max_tokens-related HTTP 400 responses. ``top_k`` is accepted for
        protocol compatibility but not forwarded (not part of the
        OpenAI-compatible payload used here).
        """
        messages = self._build_messages(history, memory_notes, max_new_tokens)

        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }

        # Retry with exponential backoff.
        max_retries = 5
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{self.vllm_url}/chat/completions",
                    json=payload,
                    timeout=self.timeout
                )
            except requests.exceptions.Timeout:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise RuntimeError("vLLM request timeout")
            except requests.exceptions.ConnectionError as e:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise RuntimeError(f"vLLM connection error: {e}")

            if response.status_code == 200:
                result = response.json()
                return result["choices"][0]["message"]["content"]
            if response.status_code == 400:
                error_text = response.text
                # Handle context length error by shrinking the budget.
                # BUG FIX: previously only payload["max_tokens"] was updated,
                # so repeated 400s retried with the same value; now the local
                # budget is halved each time until the 64-token floor.
                if "max_tokens" in error_text and max_new_tokens > 64:
                    max_new_tokens = max(64, max_new_tokens // 2)
                    payload["max_tokens"] = max_new_tokens
                    continue
                raise RuntimeError(f"vLLM error: {error_text[:200]}")
            raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}")

        raise RuntimeError("Max retries exceeded for vLLM request")