diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-04-08 22:06:05 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-04-08 22:06:05 -0500 |
| commit | 05704d0eb2fa59fe727652465b07db40bcb06c38 (patch) | |
| tree | 8904aca836cf552fd1a5ae8c2174e9f91e70bbbc /putnam-bench-anon/loader/vllm_local.py | |
Initial release: GAP framework
- Full pipeline: variant generation, multi-judge verification, evaluation
- Loaders for OpenAI / Anthropic / Google / xAI / OpenRouter / vLLM
- Framework-level mechanism analyses: paired structural overlap, repairability rescue, self-correction probe, cross-model agreement, topic x problem-type interaction
- Unicode -> bare-LaTeX cleaner + audit + spot-check
- Mirrors https://huggingface.co/datasets/blackhao0426/PutnamGAP
Diffstat (limited to 'putnam-bench-anon/loader/vllm_local.py')
| -rw-r--r-- | putnam-bench-anon/loader/vllm_local.py | 224 |
1 file changed, 224 insertions, 0 deletions
"""
VLLM local model loader implementation.
Handles API calls to locally deployed VLLM services with OpenAI-compatible endpoints.
"""

import asyncio
import random
from typing import Dict, List, Tuple, Optional

try:
    from openai import AsyncOpenAI, RateLimitError, APIError, APIConnectionError
except ImportError:
    # openai is optional at import time; VLLMModelLoader.__init__ raises a
    # clear ImportError with install instructions if it is actually needed.
    AsyncOpenAI = None
    RateLimitError = Exception
    APIError = Exception
    APIConnectionError = Exception

from .base import ModelLoader
from .prompts import RESPONSE_FORMAT


class VLLMModelLoader(ModelLoader):
    """VLLM local model implementation of the ModelLoader.

    Talks to a locally deployed vLLM server through its OpenAI-compatible
    REST endpoint using the async OpenAI client.
    """

    def __init__(self,
                 solver_model: str = "meta-llama/Llama-3.2-3B-Instruct",
                 grader_model: str = "meta-llama/Llama-3.2-8B-Instruct",
                 base_url: str = "http://localhost:8000/v1",
                 api_key: str = "EMPTY",
                 **kwargs):
        """
        Initialize VLLM model loader.

        Args:
            solver_model: Model name for solving problems (default: Llama-3.2-3B-Instruct)
            grader_model: Model name for grading solutions (default: Llama-3.2-8B-Instruct)
            base_url: VLLM server URL (default: http://localhost:8000/v1)
            api_key: API key for VLLM server (default: "EMPTY" for local)
            **kwargs: Additional arguments passed to parent class

        Raises:
            ImportError: If the `openai` package is not installed.
        """
        if AsyncOpenAI is None:
            raise ImportError(
                "openai package is required for VLLMModelLoader. "
                "Install with: pip install openai"
            )

        super().__init__(solver_model, grader_model, **kwargs)

        # vLLM exposes an OpenAI-compatible API, so the stock async client works.
        self.client = AsyncOpenAI(
            base_url=base_url,
            api_key=api_key
        )
        self.base_url = base_url

    async def _call_api(self,
                        model: str,
                        messages: List[Dict[str, str]],
                        temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Make an API call to VLLM server.

        Args:
            model: Model name to use
            messages: List of messages in chat format
            temperature: Temperature for generation

        Returns:
            Tuple of (response_content, raw_response)

        Raises:
            RateLimitError / APIError / APIConnectionError: Re-raised after
                logging (and a short sleep on rate limits) so outer retry
                logic can take over.
        """
        try:
            # Prepare API call parameters
            api_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": 4000,
            }

            # Only request structured JSON output for deterministic runs.
            # NOTE: some local models may not honor response_format; if the
            # server rejects it, the error surfaces from the create() call
            # below — a try/except around this plain dict assignment (as in
            # the original) could never fire and was removed.
            if temperature == 0.0:
                api_params["response_format"] = RESPONSE_FORMAT

            # Make the API call
            response = await self.client.chat.completions.create(**api_params)

            # Extract response content; guard against a null content field.
            content = response.choices[0].message.content or ""

            return content, content

        except (RateLimitError, APIError, APIConnectionError) as e:
            # Handle various API errors
            error_str = str(e)
            print(f"❌ VLLM API Error: {error_str}")

            # Best-effort backoff when the message looks like a rate limit.
            if "rate" in error_str.lower() or "limit" in error_str.lower():
                sleep_time = 2 + random.random()
                print(f"   ⏰ Rate limited, sleeping {sleep_time:.1f}s")
                await asyncio.sleep(sleep_time)

            # Re-raise to trigger retry logic
            raise

        except Exception as e:
            print(f"❌ Unexpected error in VLLM API call: {str(e)}")
            raise

    def get_model_info(self) -> Dict[str, str]:
        """Get information about the configured models."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "vllm",
            "base_url": self.base_url
        }

    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify VLLM server connectivity.

        Returns:
            True if server is accessible, False otherwise
        """
        try:
            # Simple test call
            test_messages = [
                {"role": "user", "content": "Hello, please respond with a simple JSON: {\"status\": \"ok\"}"}
            ]

            result, _ = await self._call_api(
                model=self.solver_model,
                messages=test_messages,
                temperature=0.0
            )

            # Loose content check: any plausible reply counts as healthy.
            if result and ("ok" in result.lower() or "hello" in result.lower()):
                print(f"✅ VLLM API health check passed for {self.solver_model}")
                return True
            else:
                print(f"⚠️ VLLM API health check returned unexpected response")
                return False

        except Exception as e:
            print(f"❌ VLLM API health check failed: {str(e)}")
            print(f"   Make sure VLLM server is running at {self.base_url}")
            return False

    async def estimate_cost(self,
                            num_problems: int,
                            avg_problem_length: int = 1000,
                            avg_solution_length: int = 2000) -> Dict[str, float]:
        """
        Estimate the cost for processing a given number of problems.
        For local VLLM, cost is typically computational (time/energy) rather than monetary.

        Args:
            num_problems: Number of problems to process
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters

        Returns:
            Dictionary with cost estimates (computational cost in arbitrary units)
        """
        # Rough token estimates (1 token ≈ 4 characters for English)
        tokens_per_solve = (avg_problem_length + avg_solution_length) // 4
        tokens_per_grade = (avg_problem_length + avg_solution_length * 2) // 4

        # Computational cost estimation (arbitrary units based on model size)
        # Larger models consume more computational resources
        model_costs = {
            "llama-3.2-1b": 1.0,
            "llama-3.2-3b": 2.0,
            "llama-3.2-8b": 4.0,
            "llama-3.1-8b": 4.0,
            "llama-3.1-70b": 20.0,
            "mistral-7b": 3.0,
            "qwen2.5-7b": 3.0,
        }

        def get_model_cost(model: str) -> float:
            # Substring match so full HF repo names hit the size table.
            model_lower = model.lower()
            for key, cost in model_costs.items():
                if key in model_lower:
                    return cost
            return 3.0  # Default cost for unknown models

        # Calculate computational costs
        solver_cost_factor = get_model_cost(self.solver_model)
        grader_cost_factor = get_model_cost(self.grader_model)

        solve_cost = tokens_per_solve * num_problems * solver_cost_factor / 1000
        grade_cost = tokens_per_grade * num_problems * grader_cost_factor / 1000

        total_cost = solve_cost + grade_cost

        # Guard against num_problems == 0: the original divided unconditionally
        # and raised ZeroDivisionError for an empty batch.
        cost_per_problem = total_cost / num_problems if num_problems else 0.0

        return {
            "solve_cost": round(solve_cost, 4),
            "grade_cost": round(grade_cost, 4),
            "total_cost": round(total_cost, 4),
            "cost_per_problem": round(cost_per_problem, 6),
            "currency": "computational_units",
            "note": "Local VLLM costs are computational (time/energy) rather than monetary"
        }

    async def list_models(self) -> List[str]:
        """
        List available models on the VLLM server.

        Returns:
            List of available model names
        """
        try:
            # Try to get models list from VLLM server
            models_response = await self.client.models.list()
            return [model.id for model in models_response.data]
        except Exception as e:
            # Best effort: fall back to the two configured model names.
            print(f"⚠️ Could not retrieve models list: {str(e)}")
            return [self.solver_model, self.grader_model]
