summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/xai_client.py
blob: 10c4cf4a41fcab56d14901fbbec10774af32a57f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
"""
xAI model loader implementation.
Handles API calls to xAI Grok models using OpenAI-compatible interface.
"""

import os
from typing import Dict, Optional, List, Tuple

from .openai_client import OpenAIModelLoader


class XAIModelLoader(OpenAIModelLoader):
    """xAI (Grok) model loader built on the OpenAI-compatible loader.

    All request handling is delegated to ``OpenAIModelLoader``; this subclass
    only points the client at the xAI endpoint, relabels log output so that
    failures are attributed to xAI, and adds a health check and a rough cost
    estimator.
    """

    # Endpoint of xAI's OpenAI-compatible API. Shared by ``__init__`` and
    # ``get_model_info`` so the two cannot drift apart.
    XAI_BASE_URL = "https://api.x.ai/v1"

    # Estimated USD prices per 1K tokens. These are placeholders based on
    # similar model pricing, not published xAI prices; update when available.
    _PRICING_PER_1K_TOKENS = {
        "grok-3": {"input": 0.01, "output": 0.03},
        "grok-2": {"input": 0.005, "output": 0.015},
    }

    # Pricing entry used for any model missing from the table above.
    _FALLBACK_PRICING_MODEL = "grok-3"

    def __init__(self,
                 solver_model: str = "grok-3",
                 grader_model: str = "grok-3",
                 api_key: Optional[str] = None,
                 **kwargs):
        """
        Initialize xAI model loader.

        Args:
            solver_model: xAI model for solving problems (default: grok-3)
            grader_model: xAI model for grading solutions (default: grok-3)
            api_key: xAI API key (if None, uses XAI_API_KEY environment variable)
            **kwargs: Additional arguments passed to parent class
        """
        # Fall back to the environment only when no key was passed explicitly.
        if api_key is None:
            api_key = os.getenv('XAI_API_KEY')

        super().__init__(
            solver_model=solver_model,
            grader_model=grader_model,
            api_key=api_key,
            base_url=self.XAI_BASE_URL,
            **kwargs
        )

    async def _call_api(self,
                       model: str,
                       messages: List[Dict[str, str]],
                       temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Make an API call to xAI, logging failures with an xAI-specific prefix.

        The call itself is performed by the parent implementation. Any
        exception it raises is logged here and then re-raised unchanged, so
        callers see exactly the exception the parent produced.

        Args:
            model: xAI model name
            messages: List of messages in chat format
            temperature: Temperature for generation

        Returns:
            Tuple of (response_content, raw_response) as returned by the
            parent implementation

        Raises:
            Exception: Whatever the parent ``_call_api`` raised.
        """
        try:
            return await super()._call_api(model, messages, temperature)
        except Exception as e:
            # Only the logged text is relabelled; the exception object that
            # propagates is left untouched.
            error_msg = str(e).replace("OpenAI API Error", "xAI API Error")

            # Classify by class name because the concrete exception types
            # are not imported in this module.
            exc_name = type(e).__name__
            if "RateLimitError" in exc_name:
                print(f"🚫 xAI RateLimitError: {error_msg}")
            elif "APIError" in exc_name or "APIConnectionError" in exc_name:
                print(f"❌ xAI API Error: {error_msg}")
            else:
                print(f"❌ Unexpected error in xAI API call: {error_msg}")
            raise

    def get_model_info(self) -> Dict[str, str]:
        """Get information about the configured models."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "xai",
            "base_url": self.XAI_BASE_URL
        }

    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify xAI API connectivity.

        Sends one short prompt to the solver model and checks that the reply
        contains "ok" (case-insensitive). Exceptions are caught and reported
        as a failed check rather than propagated.

        Returns:
            True if API is accessible, False otherwise
        """
        test_messages = [
            {"role": "user", "content": "Hello, please respond with a simple JSON: {\"status\": \"ok\"}"}
        ]

        try:
            result, _ = await self._call_api(
                model=self.solver_model,
                messages=test_messages,
                temperature=0.0
            )
        except Exception as e:
            print(f"❌ xAI API health check failed: {e}")
            return False

        if result and "ok" in result.lower():
            print(f"✅ xAI API health check passed for {self.solver_model}")
            return True

        print("⚠️ xAI API health check returned unexpected response")
        return False

    async def estimate_cost(self,
                          num_problems: int,
                          avg_problem_length: int = 1000,
                          avg_solution_length: int = 2000) -> Dict[str, float]:
        """
        Estimate the cost for processing a given number of problems with xAI models.

        Args:
            num_problems: Number of problems to process. Zero is allowed and
                yields an all-zero estimate.
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters

        Returns:
            Dictionary with the numeric keys "solve_cost", "grade_cost",
            "total_cost" and "cost_per_problem", plus the string keys
            "currency" and "note".
        """
        # Rough token estimates (1 token ≈ 4 characters for English)
        tokens_per_solve = (avg_problem_length + avg_solution_length) // 4
        tokens_per_grade = (avg_problem_length + avg_solution_length * 2) // 4

        def get_model_cost(model: str, input_tokens: int, output_tokens: int) -> float:
            # Models without a pricing entry are billed at the fallback rate.
            rates = self._PRICING_PER_1K_TOKENS.get(
                model, self._PRICING_PER_1K_TOKENS[self._FALLBACK_PRICING_MODEL]
            )
            input_cost = (input_tokens / 1000) * rates["input"]
            output_cost = (output_tokens / 1000) * rates["output"]
            return input_cost + output_cost

        solve_input_tokens = tokens_per_solve * num_problems
        grade_input_tokens = tokens_per_grade * num_problems

        solve_cost = get_model_cost(
            self.solver_model,
            solve_input_tokens,
            solve_input_tokens // 2  # Assume output is ~50% of input
        )
        grade_cost = get_model_cost(
            self.grader_model,
            grade_input_tokens,
            grade_input_tokens // 3  # Assume output is ~33% of input
        )
        total_cost = solve_cost + grade_cost

        # An empty batch has no per-problem cost; avoid dividing by zero.
        cost_per_problem = total_cost / num_problems if num_problems else 0.0

        return {
            "solve_cost": round(solve_cost, 4),
            "grade_cost": round(grade_cost, 4),
            "total_cost": round(total_cost, 4),
            "cost_per_problem": round(cost_per_problem, 6),
            "currency": "USD",
            "note": "xAI pricing estimates - update with actual pricing"
        }