# path: root/src/personalization/models/llm/vllm_chat.py
# blob: b5c3a05e7c4e046f777792aac32eef329d552cc2 (plain)
"""
vLLM-based ChatModel implementation for high-throughput inference.

This provides the same interface as LlamaChatModel but uses vLLM HTTP API
for much faster inference (3000+ sessions/hr vs 20 sessions/hr).
"""

from typing import List, Optional
import time
import requests

from personalization.models.llm.base import ChatModel
from personalization.types import ChatTurn


class VLLMChatModel(ChatModel):
    """
    ChatModel implementation using vLLM HTTP API.

    This is a drop-in replacement for LlamaChatModel that uses vLLM
    for much faster inference.
    """

    def __init__(
        self,
        vllm_url: str = "http://localhost:8003/v1",
        model_name: str = None,
        max_context_length: int = 8192,
        timeout: int = 120,
    ):
        self.vllm_url = vllm_url.rstrip('/')
        self.model_name = model_name
        self.max_context_length = max_context_length
        self.timeout = timeout

        # Discover model name if not provided
        if self.model_name is None:
            self._discover_model()

    def _discover_model(self):
        """Discover the model name from the vLLM server.

        Queries ``{vllm_url}/models`` up to 30 times and stores the ``id`` of
        the first listed model in ``self.model_name``. Every unsuccessful
        attempt -- a failed request, an unparsable response, or a response
        that lists no models -- is followed by a capped exponential backoff
        before the next try. If no attempt succeeds, the name falls back to
        ``"default"`` and a warning with the last failure is printed.
        """
        max_retries = 30
        last_error = None
        for attempt in range(max_retries):
            try:
                response = requests.get(f"{self.vllm_url}/models", timeout=10)
                response.raise_for_status()
                models = response.json()
                data = models.get("data")
                if data:
                    self.model_name = data[0]["id"]
                    return
                last_error = "server returned an empty model list"
            except Exception as e:
                # Discovery is best-effort: remember the failure for the
                # final warning and fall through to the backoff below.
                last_error = e

            # Back off after *any* failed attempt, including an empty model
            # list, so the server is never polled in a tight loop.
            if attempt < max_retries - 1:
                wait_time = min(2 ** attempt * 0.5, 10)
                time.sleep(wait_time)

        # Fallback
        self.model_name = "default"
        print(f"[VLLMChatModel] Warning: Could not discover model, using '{self.model_name}'")
        print(f"[VLLMChatModel] Last discovery error: {last_error}")

    def health_check(self) -> bool:
        """Check if the vLLM server is healthy.

        Requests ``/health`` on the API base URL with a trailing ``/v1``
        segment removed.

        Returns:
            True if the endpoint responds with HTTP 200; False for any other
            status code or if the request itself fails.
        """
        # Strip only a trailing "/v1". A blanket str.replace would also
        # remove "/v1" from inside the host or path (for example
        # "http://v1-host/v1" would become "http:/-host").
        root_url = self.vllm_url
        if root_url.endswith("/v1"):
            root_url = root_url[:-len("/v1")]

        try:
            response = requests.get(f"{root_url}/health", timeout=5)
        except requests.exceptions.RequestException:
            # Request-level failures mean "not healthy"; unlike a bare
            # except, KeyboardInterrupt and SystemExit still propagate.
            return False
        return response.status_code == 200

    def _estimate_tokens(self, text: str) -> int:
        """Estimate token count using character-based heuristic.

        For Llama models, ~4 characters per token is a reasonable estimate.
        We use 3.5 to be conservative (slightly overestimate tokens).
        """
        return int(len(text) / 3.5)

    def _build_messages(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
    ) -> List[dict]:
        """Build messages list for chat completion API with auto-truncation.

        If the context exceeds max_context_length, older conversation turns
        are removed to keep only the most recent context that fits.
        """
        # Use CollaborativeAgents-style system prompt
        if memory_notes:
            bullet = "\n".join(f"- {n}" for n in memory_notes)
            system_content = (
                "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
                "# User Preferences\n"
                "The user has a set of preferences for how you should behave. If you do not follow these preferences, "
                "the user will be unable to learn from your response and you will need to adjust your response to adhere "
                "to these preferences (so it is best to follow them initially).\n"
                "Based on your past interactions with the user, you have maintained a set of notes about the user's preferences:\n"
                f"{bullet}\n\n"
                "# Conversation Guidelines:\n"
                "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
                "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
                "- Your goal is to help the user solve their problem. Adhere to their preferences and do your best to help them solve their problem.\n"
            )
        else:
            # Vanilla mode - no preferences
            system_content = (
                "You are a collaborative AI agent helping users solve writing, question answering, math, and coding problems.\n\n"
                "# Conversation Guidelines:\n"
                "- If the user's message is unclear, lacks details, or is ambiguous (e.g. length of an essay, format requirements, "
                "specific constraints), do not make assumptions. Ask for clarification and ensure you have enough information before providing an answer.\n"
                "- Your goal is to help the user solve their problem. Do your best to help them.\n"
            )
        system_message = {"role": "system", "content": system_content}

        # Calculate available tokens for conversation history
        # Reserve space for: system prompt + max_new_tokens + safety margin
        system_tokens = self._estimate_tokens(system_content)
        available_tokens = self.max_context_length - system_tokens - max_new_tokens - 100  # 100 token safety margin

        # Build conversation messages from history
        conversation_messages = []
        for turn in history:
            conversation_messages.append({"role": turn.role, "content": turn.text})

        # Check if truncation is needed
        total_conv_tokens = sum(self._estimate_tokens(m["content"]) for m in conversation_messages)

        if total_conv_tokens > available_tokens:
            # Truncate from the beginning (keep recent messages)
            truncated_messages = []
            current_tokens = 0

            # Iterate from most recent to oldest
            for msg in reversed(conversation_messages):
                msg_tokens = self._estimate_tokens(msg["content"])
                if current_tokens + msg_tokens <= available_tokens:
                    truncated_messages.insert(0, msg)
                    current_tokens += msg_tokens
                else:
                    # Stop adding older messages
                    break

            conversation_messages = truncated_messages
            if len(truncated_messages) < len(history):
                print(f"[VLLMChatModel] Truncated context: kept {len(truncated_messages)}/{len(history)} turns "
                      f"({current_tokens}/{total_conv_tokens} estimated tokens)")

        messages = [system_message] + conversation_messages
        return messages

    def build_messages(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
    ) -> List[dict]:
        """Public method to build messages without calling the API.

        Used for batch processing where messages are collected first,
        then sent in batch to vLLM for concurrent processing.
        """
        return self._build_messages(history, memory_notes, max_new_tokens)

    def answer(
        self,
        history: List[ChatTurn],
        memory_notes: List[str],
        max_new_tokens: int = 512,
        temperature: float = 0.7,
        top_p: float = 0.9,
        top_k: Optional[int] = None,
    ) -> str:
        """Generate a response using vLLM HTTP API.

        Builds the message list (with history truncation) and posts it to
        ``{vllm_url}/chat/completions``, making up to 5 attempts.

        Args:
            history: Conversation turns, oldest first.
            memory_notes: Preference notes embedded in the system prompt.
            max_new_tokens: Upper bound on generated tokens; also reserved
                when truncating the history.
            temperature: Sampling temperature sent to the server.
            top_p: Nucleus sampling value sent to the server.
            top_k: Optional top-k sampling value; only sent when not None.

        Returns:
            The ``content`` of the first choice in the server response.

        Raises:
            RuntimeError: On a non-retryable HTTP error, on a timeout or
                connection error that persists through the final attempt, or
                when all attempts are used up.
        """
        messages = self._build_messages(history, memory_notes, max_new_tokens)

        payload = {
            "model": self.model_name,
            "messages": messages,
            "max_tokens": max_new_tokens,
            "temperature": temperature,
            "top_p": top_p,
        }
        # Forward top_k only when the caller set it, so the server default
        # applies otherwise (previously the argument was silently ignored).
        if top_k is not None:
            payload["top_k"] = top_k

        # Retry with exponential backoff
        max_retries = 5
        for attempt in range(max_retries):
            try:
                response = requests.post(
                    f"{self.vllm_url}/chat/completions",
                    json=payload,
                    timeout=self.timeout
                )
            except requests.exceptions.Timeout as e:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise RuntimeError("vLLM request timeout") from e
            except requests.exceptions.ConnectionError as e:
                if attempt < max_retries - 1:
                    time.sleep(2 ** attempt)
                    continue
                raise RuntimeError(f"vLLM connection error: {e}") from e

            if response.status_code == 200:
                result = response.json()
                return result["choices"][0]["message"]["content"]

            if response.status_code == 400:
                error_text = response.text
                # Handle context length error: halve the *current* request
                # size on each retry (floor of 64) instead of recomputing
                # from the original value, which resent the same size.
                current_max_tokens = payload["max_tokens"]
                if "max_tokens" in error_text and current_max_tokens > 64:
                    payload["max_tokens"] = max(64, current_max_tokens // 2)
                    continue
                raise RuntimeError(f"vLLM error: {error_text[:200]}")

            raise RuntimeError(f"vLLM HTTP {response.status_code}: {response.text[:200]}")

        raise RuntimeError("Max retries exceeded for vLLM request")