summaryrefslogtreecommitdiff
path: root/src/personalization/models/llm/vllm_chat.py
diff options
context:
space:
mode:
authorYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
committerYurenHao0426 <blackhao0426@gmail.com>2026-01-27 15:43:42 -0600
commitf918fc90b8d71d1287590b016d926268be573de0 (patch)
treed9009c8612c8e7f866c31d22fb979892a5b55eeb /src/personalization/models/llm/vllm_chat.py
parent680513b7771a29f27cbbb3ffb009a69a913de6f9 (diff)
Add model wrapper modules (embedding, reranker, llm, preference_extractor)
Add Python wrappers for: - Qwen3/Nemotron embedding models - BGE/Qwen3 rerankers - vLLM/Llama/Qwen LLM backends - GPT-4o/LLM-based preference extractors Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
Diffstat (limited to 'src/personalization/models/llm/vllm_chat.py')
-rw-r--r--src/personalization/models/llm/vllm_chat.py217
1 files changed, 217 insertions, 0 deletions
diff --git a/src/personalization/models/llm/vllm_chat.py b/src/personalization/models/llm/vllm_chat.py
new file mode 100644
index 0000000..b5c3a05
--- /dev/null
+++ b/src/personalization/models/llm/vllm_chat.py
@@ -0,0 +1,217 @@
+"""
+vLLM-based ChatModel implementation for high-throughput inference.
+
+This provides the same interface as LlamaChatModel but uses vLLM HTTP API
+for much faster inference (3000+ sessions/hr vs 20 sessions/hr).
+"""
+
+from typing import List, Optional
+import time
+import requests
+
+from personalization.models.llm.base import ChatModel
+from personalization.types import ChatTurn
+
+
+class VLLMChatModel(ChatModel):
+ """
+ ChatModel implementation using vLLM HTTP API.
+
+ This is a drop-in replacement for LlamaChatModel that uses vLLM
+ for much faster inference.
+ """
+
+ def __init__(
+ self,
+ vllm_url: str = "http://localhost:8003/v1",
+ model_name: str = None,
+ max_context_length: int = 8192,
+ timeout: int = 120,
+ ):
+ self.vllm_url = vllm_url.rstrip('/')
+ self.model_name = model_name
+ self.max_context_length = max_context_length
+ self.timeout = timeout
+
+ # Discover model name if not provided
+ if self.model_name is None:
+ self._discover_model()
+
+ def _discover_model(self):
+ """Discover the model name from the vLLM server."""
+ max_retries = 30
+ for attempt in range(max_retries):
+ try:
+ response = requests.get(f"{self.vllm_url}/models", timeout=10)
+ response.raise_for_status()
+ models = response.json()
+ if models.get("data") and len(models["data"]) > 0:
+ self.model_name = models["data"][0]["id"]
+ return
+ except Exception as e:
+ if attempt < max_retries - 1:
+ wait_time = min(2 ** attempt * 0.5, 10)
+ time.sleep(wait_time)
+
+ # Fallback
+ self.model_name = "default"
+ print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'")
+
+ def health_check(self) -> bool:
+ """Check if the vLLM server is healthy."""
+ try:
+ response = requests.get(f"{self.vllm_url.replace('/v1', '')}/health", timeout=5)
+ return response.status_code == 200
+ except:
+ return False
+
+ def _estimate_tokens(self, text: str) -> int:
+ """Estimate token count using character-based heuristic.
+
+ For Llama models, ~4 characters per token is a reasonable estimate.
+ We use 3.5 to be conservative (slightly overestimate tokens).
+ """
+ return int(len(text) / 3.5)
+
+ def _build_messages(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ ) -> List[dict]:
+ """Build messages list for chat completion API with auto-truncation.
+
+ If the context exceeds max_context_length, older conversation turns
+ are removed to keep only the most recent context that fits.
+ """
+ # Use CollaborativeAgents-style system prompt
+ if memory_notes:
+ bullet = "\n".join(f"- {n}" for n in memory_notes)
+ system_content = (
+ "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
+ "# User Preferences\n"
+ "The user has a set of preferences for how you should behave. If you do not follow these preferences, "
+ "the user will be unable to learn from your response and you will need to adjust your response to adhere "
+ "to these preferences (so it is best to follow them initially).\n"
+ "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
+ f"{bullet}\n\n"
+ "# Conversation Guidelines:\n"
+ "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
+ "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
+ "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
+ )
+ else:
+ # Vanilla mode - no preferences
+ system_content = (
+ "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
+ "# Conversation Guidelines:\n"
+ "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
+ "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
+ "- Your goal is to help the user solve their problem. Do your best to help them.\n"
+ )
+ system_message = {"role": "system", "content": system_content}
+
+ # Calculate available tokens for conversation history
+ # Reserve space for: system prompt + max_new_tokens + safety margin
+ system_tokens = self._estimate_tokens(system_content)
+ available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100 # 100 token safety margin
+
+ # Build conversation messages from history
+ conversation_messages = []
+ for turn in history:
+ conversation_messages.append({"role": turn.role, "content": turn.text})
+
+ # Check if truncation is needed
+ total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages)
+
+ if total_conv_tokens > available_tokens:
+ # Truncate from the beginning (keep recent messages)
+ truncated_messages = []
+ current_tokens = 0
+
+ # Iterate from most recent to oldest
+ for msg in reversed(conversation_messages):
+ msg_tokens = self._estimate_tokens(msg["content"])
+ if current_tokens + msg_tokens <= available_tokens:
+ truncated_messages.insert(0, msg)
+ current_tokens += msg_tokens
+ else:
+ # Stop adding older messages
+ break
+
+ conversation_messages = truncated_messages
+ if len(truncated_messages) < len(history):
+ print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns "
+ f"({current_tokens}/{total_conv_tokens} estimated tokens)")
+
+ messages = [system_message] + conversation_messages
+ return messages
+
+ def build_messages(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ ) -> List[dict]:
+ """Public method to build messages without calling the API.
+
+ Used for batch processing where messages are collected first,
+ then sent in batch to vLLM for concurrent processing.
+ """
+ return self._build_messages(history, memory_notes, max_new_tokens)
+
+ def answer(
+ self,
+ history: List[ChatTurn],
+ memory_notes: List[str],
+ max_new_tokens: int = 512,
+ temperature: float = 0.7,
+ top_p: float = 0.9,
+ top_k: Optional[int] = None,
+ ) -> str:
+ """Generate a response using vLLM HTTP API."""
+ messages = self._build_messages(history, memory_notes, max_new_tokens)
+
+ payload = {
+ "model": self.model_name,
+ "messages": messages,
+ "max_tokens": max_new_tokens,
+ "temperature": temperature,
+ "top_p": top_p,
+ }
+
+ # Retry with exponential backoff
+ max_retries = 5
+ for attempt in range(max_retries):
+ try:
+ response = requests.post(
+ f"{self.vllm_url}/chat/completions",
+ json=payload,
+ timeout=self.timeout
+ )
+
+ if response.status_code == 200:
+ result = response.json()
+ return result["choices"][0]["message"]["content"]
+ elif response.status_code == 400:
+ error_text = response.text
+ # Handle context length error
+ if "max_tokens" in error_text and max_new_tokens > 64:
+ payload["max_tokens"] = max(64, max_new_tokens // 2)
+ continue
+ raise RuntimeError(f"vLLM error: {error_text[:200]}")
+ else:
+ raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}")
+
+ except requests.exceptions.Timeout:
+ if attempt < max_retries - 1:
+ time.sleep(2 ** attempt)
+ continue
+ raise RuntimeError("vLLM request timeout")
+ except requests.exceptions.ConnectionError as e:
+ if attempt < max_retries - 1:
+ time.sleep(2 ** attempt)
+ continue
+ raise RuntimeError(f"vLLM connection error: {e}")
+
+ raise RuntimeError("Max retries exceeded for vLLM request")