summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/vllm_direct.py
blob: b35d99bb113e4abb5b22c518520425e173d06e3f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
"""
VLLM direct Python API model loader implementation.
Uses VLLM's Python API directly without requiring a separate server process.
"""

import asyncio
import json
import re
from typing import Dict, List, Tuple, Optional, Any
import torch

try:
    from vllm import LLM, SamplingParams
    VLLM_AVAILABLE = True
except ImportError:
    LLM = None
    SamplingParams = None
    VLLM_AVAILABLE = False

from .base import ModelLoader
from .prompts import SOLVER_SYSTEM_PROMPT, PROOF_GRADER_SYSTEM_PROMPT


class VLLMDirectModelLoader(ModelLoader):
    """VLLM direct Python API implementation of the ModelLoader."""
    
    def __init__(self,
                 solver_model: str = "gpt2",
                 grader_model: str = "gpt2",
                 max_model_len: int = 512,
                 gpu_memory_utilization: float = 0.4,
                 device: str = "auto",
                 **kwargs):
        """
        Set up the loader without loading any model yet.

        Args:
            solver_model: Name of the model used to solve problems.
            grader_model: Name of the model used to grade solutions.
            max_model_len: Maximum sequence length handed to vLLM.
            gpu_memory_utilization: GPU memory fraction handed to vLLM.
            device: Device label ('auto', 'cuda', 'cpu'). It is stored and
                printed here; this constructor does not pass it to vLLM.
            **kwargs: Forwarded unchanged to the parent class constructor.

        Raises:
            ImportError: If the ``vllm`` package could not be imported.
        """
        # Fail fast: the loader is unusable without the optional dependency.
        if not VLLM_AVAILABLE:
            raise ImportError(
                "vllm package is required for VLLMDirectModelLoader. "
                "Install with: pip install vllm"
            )

        super().__init__(solver_model, grader_model, **kwargs)

        self.device = device
        self.max_model_len = max_model_len
        self.gpu_memory_utilization = gpu_memory_utilization

        # Engines are created on first use, so start with empty slots.
        self._loaded_models = []
        self._grader_llm = None
        self._solver_llm = None

        startup_banner = (
            "🔧 VLLM Direct loader initialized",
            f"   Device: {device}",
            f"   Max length: {max_model_len}",
            f"   GPU utilization: {gpu_memory_utilization}",
        )
        for banner_line in startup_banner:
            print(banner_line)
    
    def _get_vllm_config(self, model: str) -> Dict[str, Any]:
        """Get VLLM configuration for a model."""
        return {
            "model": model,
            "max_model_len": self.max_model_len,
            "gpu_memory_utilization": self.gpu_memory_utilization,
            "trust_remote_code": False,
            "enforce_eager": True,  # Disable graph optimization for faster startup
        }
    
    async def _load_model(self, model: str, purpose: str) -> LLM:
        """Construct a vLLM engine for *model* and record it as loaded.

        Args:
            model: Model name passed through ``_get_vllm_config``.
            purpose: Label used only in the progress message
                (for example "solver" or "grader").

        Returns:
            The newly constructed ``LLM`` instance.

        Raises:
            Exception: Any error raised during construction is reported on
                stdout and then re-raised unchanged.
        """
        print(f"📥 Loading {purpose} model: {model}")

        try:
            engine = LLM(**self._get_vllm_config(model))

            # Only successful loads are recorded.
            self._loaded_models.append(model)
            print(f"✅ Model loaded successfully: {model}")
            return engine

        except Exception as load_error:
            print(f"❌ Failed to load model {model}: {load_error}")
            raise
    
    async def _get_solver_model(self) -> LLM:
        """Get or load the solver model."""
        if self._solver_llm is None:
            self._solver_llm = await self._load_model(self.solver_model, "solver")
        return self._solver_llm
    
    async def _get_grader_model(self) -> LLM:
        """Get or load the grader model."""
        if self._grader_llm is None:
            # If solver and grader use the same model, reuse the instance
            if self.solver_model == self.grader_model and self._solver_llm is not None:
                print(f"♻️ Reusing solver model for grading: {self.grader_model}")
                self._grader_llm = self._solver_llm
            else:
                self._grader_llm = await self._load_model(self.grader_model, "grader")
        return self._grader_llm
    
    def _format_messages_as_prompt(self, messages: List[Dict[str, str]]) -> str:
        """Convert chat messages to a single prompt string."""
        prompt_parts = []
        
        for message in messages:
            role = message["role"]
            content = message["content"]
            
            if role == "system":
                prompt_parts.append(f"System: {content}")
            elif role == "user":
                prompt_parts.append(f"User: {content}")
            elif role == "assistant":
                prompt_parts.append(f"Assistant: {content}")
        
        # Add final assistant prompt
        if not messages[-1]["role"] == "assistant":
            prompt_parts.append("Assistant:")
        
        return "\n\n".join(prompt_parts)
    
    def _extract_json_from_response(self, response: str) -> Optional[Dict]:
        """Extract JSON from model response."""
        try:
            # Try to find JSON in the response
            json_match = re.search(r'\{.*\}', response, re.DOTALL)
            if json_match:
                json_str = json_match.group()
                return json.loads(json_str)
            
            # If no JSON found, try to parse the entire response
            return json.loads(response.strip())
            
        except json.JSONDecodeError:
            # If JSON parsing fails, return None
            return None
    
    async def _call_api(self, 
                       model: str, 
                       messages: List[Dict[str, str]], 
                       temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Make an inference call using VLLM.
        
        Args:
            model: Model name to use
            messages: List of messages in chat format
            temperature: Temperature for generation
            
        Returns:
            Tuple of (response_content, raw_response)
        """
        try:
            # Get the appropriate model instance
            if model == self.solver_model:
                llm = await self._get_solver_model()
            elif model == self.grader_model:
                llm = await self._get_grader_model()
            else:
                raise ValueError(f"Unknown model: {model}")
            
            # Convert messages to prompt
            prompt = self._format_messages_as_prompt(messages)
            
            # Set up sampling parameters
            sampling_params = SamplingParams(
                temperature=temperature,
                top_p=0.95,
                max_tokens=500,  # Reasonable limit for responses
                stop=["\nUser:", "\nSystem:"]  # Stop at new conversation turns
            )
            
            # Generate response
            outputs = llm.generate([prompt], sampling_params)
            
            if outputs and len(outputs) > 0:
                generated_text = outputs[0].outputs[0].text
                return generated_text.strip(), generated_text
            else:
                return None, ""
            
        except Exception as e:
            print(f"❌ VLLM inference error: {str(e)}")
            raise
    
    def get_model_info(self) -> Dict[str, str]:
        """Get information about the configured models."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "vllm_direct",
            "device": self.device,
            "loaded_models": self._loaded_models
        }
    
    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify VLLM functionality.
        
        Returns:
            True if models can be loaded and generate text, False otherwise
        """
        try:
            print(f"🔍 VLLM health check starting...")
            
            # Try to load and use the solver model
            test_messages = [
                {"role": "user", "content": "Hello! Please respond with 'Health check OK'."}
            ]
            
            result, _ = await self._call_api(
                model=self.solver_model,
                messages=test_messages,
                temperature=0.1
            )
            
            if result and len(result) > 0:
                print(f"✅ VLLM health check passed for {self.solver_model}")
                print(f"   Response: {result[:50]}...")
                return True
            else:
                print(f"❌ VLLM health check failed: empty response")
                return False
                
        except Exception as e:
            print(f"❌ VLLM health check failed: {str(e)}")
            return False
    
    async def estimate_cost(self, 
                          num_problems: int, 
                          avg_problem_length: int = 1000,
                          avg_solution_length: int = 2000) -> Dict[str, float]:
        """
        Estimate the cost for processing a given number of problems.
        For direct VLLM, cost is computational (time/energy).
        
        Args:
            num_problems: Number of problems to process
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters
            
        Returns:
            Dictionary with cost estimates
        """
        # Token estimates (1 token ≈ 4 characters)
        tokens_per_solve = (avg_problem_length + avg_solution_length) // 4
        tokens_per_grade = (avg_problem_length + avg_solution_length * 2) // 4
        
        # Model size cost factors (based on parameter count)
        model_costs = {
            "gpt2": 1.0,      # 124M params
            "distilgpt2": 0.5, # 82M params
            "microsoft/dialo": 1.2,  # DialoGPT variants
            "tinyllama": 2.0,  # 1.1B params
        }
        
        def get_model_cost(model: str) -> float:
            model_lower = model.lower()
            for key, cost in model_costs.items():
                if key in model_lower:
                    return cost
            return 1.5  # Default cost
        
        solver_cost_factor = get_model_cost(self.solver_model)
        grader_cost_factor = get_model_cost(self.grader_model)
        
        # Computational cost estimation (arbitrary units)
        solve_cost = tokens_per_solve * num_problems * solver_cost_factor / 10000
        grade_cost = tokens_per_grade * num_problems * grader_cost_factor / 10000
        
        total_cost = solve_cost + grade_cost
        
        return {
            "solve_cost": round(solve_cost, 4),
            "grade_cost": round(grade_cost, 4),
            "total_cost": round(total_cost, 4),
            "cost_per_problem": round(total_cost / num_problems, 6),
            "currency": "computational_units",
            "note": "Direct VLLM costs are computational (GPU time/energy)"
        }
    
    async def unload_all_models(self):
        """Unload all loaded models to free GPU memory."""
        try:
            print("🗑️ Unloading VLLM models...")
            
            # Clean up model instances
            if self._solver_llm is not None:
                del self._solver_llm
                self._solver_llm = None
            
            if self._grader_llm is not None and self._grader_llm != self._solver_llm:
                del self._grader_llm
                self._grader_llm = None
            
            # Clear CUDA cache
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            
            self._loaded_models.clear()
            print("✅ Models unloaded successfully")
            
        except Exception as e:
            print(f"⚠️ Error during model cleanup: {e}")