summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/openrouter_client.py
blob: 13cd7fa8a5e88f11629f9d37f6878d4b52503ed4 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
"""
OpenRouter model loader implementation.
Handles API calls to OpenRouter service using OpenAI-compatible interface.
OpenRouter provides access to multiple model providers through a single API.
"""

import os
from typing import Dict, Optional, List, Tuple

from .openai_client import OpenAIModelLoader


class OpenRouterModelLoader(OpenAIModelLoader):
    """OpenRouter implementation using the OpenAI-compatible API.

    Reuses the OpenAI loader's machinery, pointed at OpenRouter's endpoint
    (``BASE_URL``), and adds OpenRouter's optional attribution headers
    (``HTTP-Referer`` / ``X-Title``) to each chat-completion request.
    """

    # OpenAI-compatible endpoint. Shared by __init__ and get_model_info so the
    # URL sent to the client and the URL reported to callers cannot diverge.
    BASE_URL = "https://openrouter.ai/api/v1"

    # max_tokens sent with every request. Kept high to avoid truncated answers
    # (observed especially with models like Gemini); subclasses may override.
    DEFAULT_MAX_TOKENS = 32000

    def __init__(self,
                 solver_model: str = "openai/gpt-4o",
                 grader_model: str = "openai/gpt-4o",
                 api_key: Optional[str] = None,
                 site_url: Optional[str] = None,
                 site_name: Optional[str] = None,
                 **kwargs):
        """
        Initialize OpenRouter model loader.

        Args:
            solver_model: Model for solving problems (default: openai/gpt-4o).
                Format should be "provider/model-name"
                (e.g., "openai/gpt-4o", "anthropic/claude-3-opus").
            grader_model: Model for grading solutions (default: openai/gpt-4o).
                Format should be "provider/model-name".
            api_key: OpenRouter API key. If None or empty, the
                OPENROUTER_API_KEY environment variable is used.
            site_url: Optional site URL for rankings on openrouter.ai
                (sent as the HTTP-Referer header).
            site_name: Optional site name for rankings on openrouter.ai
                (sent as the X-Title header).
            **kwargs: Additional arguments passed to the parent class.

        Raises:
            ValueError: If no API key is given and OPENROUTER_API_KEY is unset
                or empty.
        """
        # Treat an explicit empty string like "not provided". Otherwise a blank
        # key would skip both the env fallback and the missing-key error.
        if not api_key:
            api_key = os.getenv('OPENROUTER_API_KEY')
        if not api_key:
            raise ValueError("OpenRouter API key not provided. Set OPENROUTER_API_KEY environment variable or pass api_key parameter")

        # Stored before super().__init__ so they exist however the parent
        # initializes; read by _call_api to build the attribution headers.
        self.site_url = site_url
        self.site_name = site_name

        super().__init__(
            solver_model=solver_model,
            grader_model=grader_model,
            api_key=api_key,
            base_url=self.BASE_URL,
            **kwargs
        )

    @staticmethod
    def _error_type_names(error: BaseException) -> List[str]:
        """Return the class names along ``error``'s MRO, most specific first.

        Matching on names (instead of ``isinstance``) avoids importing the
        ``openai`` package here, while still recognising subclasses. For
        example, in the openai SDK ``BadRequestError`` derives from
        ``APIStatusError``, which derives from ``APIError``.
        """
        return [cls.__name__ for cls in type(error).__mro__]

    async def _call_api(self,
                       model: str,
                       messages: List[Dict[str, str]],
                       temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Make a chat-completion call to OpenRouter with attribution headers.

        Args:
            model: Model name in format "provider/model-name".
            messages: List of messages in chat format.
            temperature: Temperature for generation.

        Returns:
            Tuple of (response_content, raw_response). Both slots carry the
            same message text.

        Raises:
            ValueError: If the response has no choices or its content is empty.
            Exception: Any error from the client is re-raised unchanged, after
                a prefixed message is printed.
        """
        try:
            # Optional headers used by openrouter.ai for app rankings.
            extra_headers = {}
            if self.site_url:
                extra_headers["HTTP-Referer"] = self.site_url
            if self.site_name:
                extra_headers["X-Title"] = self.site_name

            from .prompts import RESPONSE_FORMAT

            api_params = {
                "model": model,
                "messages": messages,
                "temperature": temperature,
                "max_tokens": self.DEFAULT_MAX_TOKENS,
                # Requested for all models; OpenRouter handles compatibility.
                "response_format": RESPONSE_FORMAT,
            }
            # Only send extra_headers when there is something to send, so the
            # request is identical to the no-header case otherwise.
            if extra_headers:
                api_params["extra_headers"] = extra_headers

            response = await self.client.chat.completions.create(**api_params)

            if not response or not response.choices:
                raise ValueError("Empty response from OpenRouter API")

            content = response.choices[0].message.content
            if not content:
                raise ValueError("Empty content in OpenRouter API response")

            return content, content

        except Exception as e:
            # Rebrand messages that the OpenAI-based client phrases as its own.
            error_msg = str(e).replace("OpenAI API Error", "OpenRouter API Error")

            # Classify by the whole class hierarchy, not just the leaf name.
            # Otherwise e.g. BadRequestError/AuthenticationError (APIError
            # subclasses) would be reported as "unexpected".
            type_names = self._error_type_names(e)
            if any("RateLimitError" in name for name in type_names):
                print(f"🚫 OpenRouter RateLimitError: {error_msg}")
            elif any("APIError" in name or "APIConnectionError" in name for name in type_names):
                print(f"❌ OpenRouter API Error: {error_msg}")
            else:
                print(f"❌ Unexpected error in OpenRouter API call: {error_msg}")
            raise

    def get_model_info(self) -> Dict[str, str]:
        """Return the configured solver/grader models plus provider and endpoint."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "openrouter",
            "base_url": self.BASE_URL,
        }

    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify OpenRouter API connectivity.

        Sends one short prompt to the solver model via _call_api.

        Returns:
            True if a non-empty response came back, False if the call raised
            (the failure is printed, not propagated).
        """
        try:
            test_messages = [
                {"role": "user", "content": "Hello, please respond with a simple JSON: {\"status\": \"ok\"}"}
            ]

            result, _ = await self._call_api(
                self.solver_model,
                test_messages,
                temperature=0.0
            )

            return result is not None

        except Exception as e:
            print(f"❌ OpenRouter health check failed: {e}")
            return False

    @staticmethod
    def get_available_models() -> List[str]:
        """
        Get a list of commonly available models on OpenRouter.

        Note: This is a hand-maintained, non-exhaustive list; entries may have
        since been renamed or retired. Check https://openrouter.ai/models for
        the current catalogue.

        Returns:
            List of model identifiers in "provider/model-name" format.
        """
        return [
            # OpenAI models
            "openai/gpt-4o",
            "openai/gpt-4o-mini",
            "openai/gpt-4-turbo",
            "openai/gpt-3.5-turbo",
            "openai/o1-preview",
            "openai/o1-mini",

            # Anthropic models
            "anthropic/claude-3-opus",
            "anthropic/claude-3-sonnet",
            "anthropic/claude-3-haiku",
            "anthropic/claude-2.1",
            "anthropic/claude-2",

            # Google models
            "google/gemini-pro",
            "google/gemini-pro-vision",
            "google/palm-2-codechat-bison",
            "google/palm-2-chat-bison",

            # Meta models
            "meta-llama/llama-3-70b-instruct",
            "meta-llama/llama-3-8b-instruct",
            "meta-llama/codellama-70b-instruct",

            # Mistral models
            "mistralai/mistral-large",
            "mistralai/mistral-medium",
            "mistralai/mistral-small",
            "mistralai/mistral-7b-instruct",
            "mistralai/mixtral-8x7b-instruct",

            # Other notable models
            "cohere/command-r-plus",
            "cohere/command-r",
            "databricks/dbrx-instruct",
            "deepseek/deepseek-coder",
            "deepseek/deepseek-chat",
            "qwen/qwen-2-72b-instruct",
            "qwen/qwen-1.5-110b-chat",
        ]