summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/vllm_local.py
diff options
context:
space:
mode:
Diffstat (limited to 'putnam-bench-anon/loader/vllm_local.py')
-rw-r--r--putnam-bench-anon/loader/vllm_local.py224
1 files changed, 224 insertions, 0 deletions
diff --git a/putnam-bench-anon/loader/vllm_local.py b/putnam-bench-anon/loader/vllm_local.py
new file mode 100644
index 0000000..bc8c4fb
--- /dev/null
+++ b/putnam-bench-anon/loader/vllm_local.py
@@ -0,0 +1,224 @@
+"""
+VLLM local model loader implementation.
+Handles API calls to locally deployed VLLM services with OpenAI-compatible endpoints.
+"""
+
+import asyncio
+import random
+from typing import Dict, List, Tuple, Optional
+
+try:
+ from openai import AsyncOpenAI, RateLimitError, APIError, APIConnectionError
+except ImportError:
+ AsyncOpenAI = None
+ RateLimitError = Exception
+ APIError = Exception
+ APIConnectionError = Exception
+
+from .base import ModelLoader
+from .prompts import RESPONSE_FORMAT
+
+
class VLLMModelLoader(ModelLoader):
    """VLLM local model implementation of the ModelLoader.

    Talks to a locally deployed VLLM server through its OpenAI-compatible
    REST endpoints, using the async OpenAI client for all calls.
    """

    def __init__(self,
                 solver_model: str = "meta-llama/Llama-3.2-3B-Instruct",
                 grader_model: str = "meta-llama/Llama-3.2-8B-Instruct",
                 base_url: str = "http://localhost:8000/v1",
                 api_key: str = "EMPTY",
                 **kwargs):
        """
        Initialize VLLM model loader.

        Args:
            solver_model: Model name for solving problems (default: Llama-3.2-3B-Instruct)
            grader_model: Model name for grading solutions (default: Llama-3.2-8B-Instruct)
            base_url: VLLM server URL (default: http://localhost:8000/v1)
            api_key: API key for VLLM server (default: "EMPTY" for local)
            **kwargs: Additional arguments passed to parent class

        Raises:
            ImportError: If the optional ``openai`` package is not installed.
        """
        if AsyncOpenAI is None:
            raise ImportError(
                "openai package is required for VLLMModelLoader. "
                "Install with: pip install openai"
            )

        super().__init__(solver_model, grader_model, **kwargs)

        # OpenAI-compatible async client pointed at the local VLLM server.
        self.client = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key
        )
        self.base_url = base_url

    async def _call_api(self,
                        model: str,
                        messages: List[Dict[str, str]],
                        temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Make an API call to the VLLM server.

        Args:
            model: Model name to use
            messages: List of messages in chat format
            temperature: Temperature for generation

        Returns:
            Tuple of (response_content, raw_response). This loader returns the
            extracted message content for both elements.

        Raises:
            RateLimitError, APIError, APIConnectionError: Re-raised after
                logging (and a short backoff for rate limits) so the caller's
                retry logic can take over.
        """
        try:
            api_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": 4000,
            }

            # Request structured JSON output only for deterministic calls.
            # NOTE(review): the original wrapped this dict assignment in
            # try/except: pass — a plain assignment can never raise, so that
            # was dead code. If the server rejects response_format, the API
            # call below raises and is handled by the except clauses.
            if temperature == 0.0:
                api_params["response_format"] = RESPONSE_FORMAT

            # Make the API call
            response = await self.client.chat.completions.create(**api_params)

            # The server may return None content for an empty completion;
            # normalize to an empty string.
            content = response.choices[0].message.content or ""

            return content, content

        except (RateLimitError, APIError, APIConnectionError) as e:
            # Log the failure, back off briefly if it looks like a rate
            # limit, then re-raise for the caller's retry logic.
            error_str = str(e)
            print(f"❌ VLLM API Error: {error_str}")

            if "rate" in error_str.lower() or "limit" in error_str.lower():
                # Small jittered sleep to avoid hammering the server.
                sleep_time = 2 + random.random()
                print(f" ⏰ Rate limited, sleeping {sleep_time:.1f}s")
                await asyncio.sleep(sleep_time)

            # Re-raise to trigger retry logic
            raise

        except Exception as e:
            # Unknown failure mode: log and propagate unchanged.
            print(f"❌ Unexpected error in VLLM API call: {str(e)}")
            raise

    def get_model_info(self) -> Dict[str, str]:
        """Get information about the configured models."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "vllm",
            "base_url": self.base_url
        }

    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify VLLM server connectivity.

        Sends a trivial prompt to the solver model and heuristically checks
        the response text.

        Returns:
            True if server is accessible and responded plausibly, False otherwise.
        """
        try:
            # Simple test call
            test_messages = [
                {"role": "user", "content": "Hello, please respond with a simple JSON: {\"status\": \"ok\"}"}
            ]

            result, _ = await self._call_api(
                model=self.solver_model,
                messages=test_messages,
                temperature=0.0
            )

            # Loose heuristic: any greeting-like or ok-like reply counts.
            if result and ("ok" in result.lower() or "hello" in result.lower()):
                print(f"✅ VLLM API health check passed for {self.solver_model}")
                return True
            else:
                print(f"⚠️ VLLM API health check returned unexpected response")
                return False

        except Exception as e:
            print(f"❌ VLLM API health check failed: {str(e)}")
            print(f" Make sure VLLM server is running at {self.base_url}")
            return False

    async def estimate_cost(self,
                            num_problems: int,
                            avg_problem_length: int = 1000,
                            avg_solution_length: int = 2000) -> Dict[str, float]:
        """
        Estimate the cost for processing a given number of problems.
        For local VLLM, cost is typically computational (time/energy) rather than monetary.

        Args:
            num_problems: Number of problems to process
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters

        Returns:
            Dictionary with cost estimates (computational cost in arbitrary units).
            ``cost_per_problem`` is 0.0 when ``num_problems`` is 0.
        """
        # Rough token estimates (1 token ≈ 4 characters for English)
        tokens_per_solve = (avg_problem_length + avg_solution_length) // 4
        # Grading sees the problem plus roughly two solution-lengths of text.
        tokens_per_grade = (avg_problem_length + avg_solution_length * 2) // 4

        # Computational cost estimation (arbitrary units based on model size)
        # Larger models consume more computational resources
        model_costs = {
            "llama-3.2-1b": 1.0,
            "llama-3.2-3b": 2.0,
            "llama-3.2-8b": 4.0,
            "llama-3.1-8b": 4.0,
            "llama-3.1-70b": 20.0,
            "mistral-7b": 3.0,
            "qwen2.5-7b": 3.0,
        }

        def get_model_cost(model: str) -> float:
            # Substring match against known model families; falls back to a
            # mid-range default for unrecognized models.
            model_lower = model.lower()
            for key, cost in model_costs.items():
                if key in model_lower:
                    return cost
            return 3.0  # Default cost for unknown models

        # Calculate computational costs
        solver_cost_factor = get_model_cost(self.solver_model)
        grader_cost_factor = get_model_cost(self.grader_model)

        solve_cost = tokens_per_solve * num_problems * solver_cost_factor / 1000
        grade_cost = tokens_per_grade * num_problems * grader_cost_factor / 1000

        total_cost = solve_cost + grade_cost

        # Guard against ZeroDivisionError when called with num_problems=0
        # (the original divided unconditionally).
        per_problem = total_cost / num_problems if num_problems else 0.0

        return {
            "solve_cost": round(solve_cost, 4),
            "grade_cost": round(grade_cost, 4),
            "total_cost": round(total_cost, 4),
            "cost_per_problem": round(per_problem, 6),
            "currency": "computational_units",
            "note": "Local VLLM costs are computational (time/energy) rather than monetary"
        }

    async def list_models(self) -> List[str]:
        """
        List available models on the VLLM server.

        Returns:
            List of available model names; falls back to the two configured
            model names if the server's /models endpoint is unreachable.
        """
        try:
            # Try to get models list from VLLM server
            models_response = await self.client.models.list()
            return [model.id for model in models_response.data]
        except Exception as e:
            # Best-effort fallback: report the models we were configured with.
            print(f"⚠️ Could not retrieve models list: {str(e)}")
            return [self.solver_model, self.grader_model]