summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/vllm_local.py
blob: bc8c4fba822d3fdb24acbb67f43e6e63f0781fbd (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
"""
VLLM local model loader implementation.
Handles API calls to locally deployed VLLM services with OpenAI-compatible endpoints.
"""

import asyncio
import random
from typing import Dict, List, Tuple, Optional

try:
    from openai import AsyncOpenAI, RateLimitError, APIError, APIConnectionError
except ImportError:
    AsyncOpenAI = None
    RateLimitError = Exception
    APIError = Exception
    APIConnectionError = Exception

from .base import ModelLoader
from .prompts import RESPONSE_FORMAT


class VLLMModelLoader(ModelLoader):
    """VLLM local model implementation of the ModelLoader.

    Requests are sent through ``openai.AsyncOpenAI`` pointed at a vLLM
    server's OpenAI-compatible endpoint (``base_url``).
    """

    # Relative computational cost factors used by ``estimate_cost``.
    # Larger models consume more computational resources. Keys are matched as
    # substrings of the lower-cased model name, in insertion order; the first
    # match wins.
    _COMPUTE_COST_FACTORS: Dict[str, float] = {
        "llama-3.2-1b": 1.0,
        "llama-3.2-3b": 2.0,
        "llama-3.2-8b": 4.0,
        "llama-3.1-8b": 4.0,
        "llama-3.1-70b": 20.0,
        "mistral-7b": 3.0,
        "qwen2.5-7b": 3.0,
    }
    # Factor used for models that match none of the keys above.
    _DEFAULT_COMPUTE_COST_FACTOR: float = 3.0

    def __init__(self,
                 solver_model: str = "meta-llama/Llama-3.2-3B-Instruct",
                 grader_model: str = "meta-llama/Llama-3.2-8B-Instruct",
                 base_url: str = "http://localhost:8000/v1",
                 api_key: str = "EMPTY",
                 **kwargs):
        """
        Initialize VLLM model loader.

        Args:
            solver_model: Model name for solving problems (default: Llama-3.2-3B-Instruct)
            grader_model: Model name for grading solutions (default: Llama-3.2-8B-Instruct).
                NOTE(review): the Llama 3.2 family may not include an 8B Instruct
                model (the 8B Instruct model is usually published as Llama 3.1) —
                confirm this default matches a model the server actually serves.
            base_url: VLLM server URL (default: http://localhost:8000/v1)
            api_key: API key for VLLM server (default: "EMPTY" for local)
            **kwargs: Additional arguments passed to parent class

        Raises:
            ImportError: If the ``openai`` package is not installed.
        """
        if AsyncOpenAI is None:
            raise ImportError(
                "openai package is required for VLLMModelLoader. "
                "Install with: pip install openai"
            )

        super().__init__(solver_model, grader_model, **kwargs)

        # Initialize OpenAI-compatible client for VLLM
        self.client = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key
        )
        self.base_url = base_url
        # Set to False once the server has been observed rejecting
        # ``response_format``, so later calls skip the doomed first attempt.
        self._response_format_supported = True

    async def _call_api(self,
                        model: str,
                        messages: List[Dict[str, str]],
                        temperature: float = 0.0,
                        max_tokens: int = 4000) -> Tuple[Optional[str], str]:
        """
        Make an API call to VLLM server.

        For ``temperature == 0.0`` the request asks for structured output via
        ``RESPONSE_FORMAT``. Many local models / server builds reject that
        parameter with HTTP 400. In that case the request is retried once
        without it, and the caller receives plain text to parse manually.

        Args:
            model: Model name to use
            messages: List of messages in chat format
            temperature: Temperature for generation
            max_tokens: Maximum number of tokens to generate (default 4000)

        Returns:
            Tuple of (response_content, raw_response). Both elements are the
            message content string ("" when the server returns no content).

        Raises:
            RateLimitError, APIError, APIConnectionError: Re-raised after logging
                so callers' retry logic can act. If the error text mentions
                "rate" or "limit", the call first sleeps for 2-3 seconds.
            Exception: Any other error is logged and re-raised.
        """
        api_params = {
            "model": model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        }

        # Structured JSON output is only requested for deterministic calls, and
        # only while the server has not yet been seen rejecting it. getattr
        # keeps this working for instances created without __init__.
        want_structured = (
            temperature == 0.0
            and getattr(self, "_response_format_supported", True)
        )

        try:
            if want_structured:
                try:
                    response = await self.client.chat.completions.create(
                        **api_params, response_format=RESPONSE_FORMAT
                    )
                except APIError as e:
                    # Only an HTTP 400 (bad request) can mean "response_format
                    # unsupported"; rate limits, connection errors, etc. go to
                    # the outer handler unchanged.
                    if getattr(e, "status_code", None) != 400:
                        raise
                    print(f"⚠️ VLLM server rejected response_format ({e}); "
                          f"retrying without structured output")
                    response = await self.client.chat.completions.create(**api_params)
                    # The plain request succeeded, so response_format was the
                    # culprit: stop sending it for this loader.
                    self._response_format_supported = False
            else:
                response = await self.client.chat.completions.create(**api_params)

            # Extract response content
            content = response.choices[0].message.content or ""

            return content, content

        except (RateLimitError, APIError, APIConnectionError) as e:
            # Handle various API errors
            error_str = str(e)
            print(f"❌ VLLM API Error: {error_str}")

            if "rate" in error_str.lower() or "limit" in error_str.lower():
                sleep_time = 2 + random.random()
                print(f"   ⏰ Rate limited, sleeping {sleep_time:.1f}s")
                await asyncio.sleep(sleep_time)

            # Re-raise to trigger retry logic
            raise

        except Exception as e:
            print(f"❌ Unexpected error in VLLM API call: {str(e)}")
            raise

    def get_model_info(self) -> Dict[str, str]:
        """Return the configured solver/grader model names, provider tag and server URL."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "vllm",
            "base_url": self.base_url
        }

    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify VLLM server connectivity.

        Sends one short prompt to the solver model. The check passes if the
        reply contains "ok" or "hello" (case-insensitive).

        Returns:
            True if server is accessible, False otherwise (errors are printed,
            never raised).
        """
        try:
            # Simple test call
            test_messages = [
                {"role": "user", "content": "Hello, please respond with a simple JSON: {\"status\": \"ok\"}"}
            ]

            result, _ = await self._call_api(
                model=self.solver_model,
                messages=test_messages,
                temperature=0.0
            )

            # Deliberately lenient: any reply echoing "ok" or "hello" counts.
            if result and ("ok" in result.lower() or "hello" in result.lower()):
                print(f"✅ VLLM API health check passed for {self.solver_model}")
                return True
            else:
                print("⚠️ VLLM API health check returned unexpected response")
                return False

        except Exception as e:
            print(f"❌ VLLM API health check failed: {str(e)}")
            print(f"   Make sure VLLM server is running at {self.base_url}")
            return False

    def _compute_cost_factor(self, model: str) -> float:
        """Return the relative compute-cost factor for *model* (first substring match)."""
        model_lower = model.lower()
        for key, cost in self._COMPUTE_COST_FACTORS.items():
            if key in model_lower:
                return cost
        return self._DEFAULT_COMPUTE_COST_FACTOR

    async def estimate_cost(self,
                            num_problems: int,
                            avg_problem_length: int = 1000,
                            avg_solution_length: int = 2000) -> Dict[str, object]:
        """
        Estimate the cost for processing a given number of problems.
        For local VLLM, cost is typically computational (time/energy) rather than monetary.

        Args:
            num_problems: Number of problems to process
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters

        Returns:
            Dictionary with numeric cost estimates (computational cost in
            arbitrary units) plus the string fields "currency" and "note".
            "cost_per_problem" is 0.0 when ``num_problems`` is 0.
        """
        # Rough token estimates (1 token ≈ 4 characters for English)
        tokens_per_solve = (avg_problem_length + avg_solution_length) // 4
        # Grading is estimated as the problem plus two solution-length texts.
        tokens_per_grade = (avg_problem_length + avg_solution_length * 2) // 4

        # Calculate computational costs
        solver_cost_factor = self._compute_cost_factor(self.solver_model)
        grader_cost_factor = self._compute_cost_factor(self.grader_model)

        solve_cost = tokens_per_solve * num_problems * solver_cost_factor / 1000
        grade_cost = tokens_per_grade * num_problems * grader_cost_factor / 1000

        total_cost = solve_cost + grade_cost
        # Guard against ZeroDivisionError for an empty batch.
        cost_per_problem = total_cost / num_problems if num_problems else 0.0

        return {
            "solve_cost": round(solve_cost, 4),
            "grade_cost": round(grade_cost, 4),
            "total_cost": round(total_cost, 4),
            "cost_per_problem": round(cost_per_problem, 6),
            "currency": "computational_units",
            "note": "Local VLLM costs are computational (time/energy) rather than monetary"
        }

    async def list_models(self) -> List[str]:
        """
        List available models on the VLLM server.

        Returns:
            Model ids reported by the server's models endpoint. If that request
            fails, the configured ``[solver_model, grader_model]`` is returned
            instead.
        """
        try:
            # Try to get models list from VLLM server
            models_response = await self.client.models.list()
            return [model.id for model in models_response.data]
        except Exception as e:
            print(f"⚠️ Could not retrieve models list: {str(e)}")
            return [self.solver_model, self.grader_model]