1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
|
"""
xAI model loader implementation.
Handles API calls to xAI Grok models using OpenAI-compatible interface.
"""
import os
from typing import Dict, List, Optional, Tuple, Union

from .openai_client import OpenAIModelLoader
class XAIModelLoader(OpenAIModelLoader):
    """xAI implementation using the OpenAI-compatible API.

    Reuses ``OpenAIModelLoader`` and only customizes the endpoint, the API-key
    environment variable, error-log prefixes, provider metadata, a health
    check, and a rough cost estimator.
    """

    # Endpoint of xAI's OpenAI-compatible API. Used by both __init__ and
    # get_model_info so the reported URL always matches the one configured.
    XAI_BASE_URL = "https://api.x.ai/v1"

    # Estimated USD prices per 1K tokens. These are placeholders based on
    # similar model pricing, not official xAI prices -- update when known.
    XAI_PRICING_PER_1K_TOKENS = {
        "grok-3": {"input": 0.01, "output": 0.03},
        "grok-2": {"input": 0.005, "output": 0.015},
    }
    # Pricing entry used for any model missing from XAI_PRICING_PER_1K_TOKENS.
    XAI_DEFAULT_PRICING_MODEL = "grok-3"

    def __init__(self,
                 solver_model: str = "grok-3",
                 grader_model: str = "grok-3",
                 api_key: Optional[str] = None,
                 **kwargs):
        """
        Initialize xAI model loader.

        Args:
            solver_model: xAI model for solving problems (default: grok-3)
            grader_model: xAI model for grading solutions (default: grok-3)
            api_key: xAI API key (if None, uses XAI_API_KEY environment variable)
            **kwargs: Additional arguments passed to parent class
        """
        # Get API key from parameter or environment
        if api_key is None:
            api_key = os.getenv('XAI_API_KEY')
        # NOTE(review): if XAI_API_KEY is unset too, None is passed on. Depending
        # on how OpenAIModelLoader builds its client, that may fall back to
        # OPENAI_API_KEY and send that key to the xAI endpoint -- confirm.
        super().__init__(
            solver_model=solver_model,
            grader_model=grader_model,
            api_key=api_key,
            base_url=self.XAI_BASE_URL,
            **kwargs
        )

    async def _call_api(self,
                        model: str,
                        messages: List[Dict[str, str]],
                        temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Make an API call to xAI, logging failures with an xAI-specific prefix.

        Delegates to the parent implementation. Any exception it raises is
        printed (with "OpenAI API Error" rewritten as "xAI API Error") and then
        re-raised unchanged.

        Args:
            model: xAI model name
            messages: List of messages in chat format
            temperature: Temperature for generation

        Returns:
            Tuple of (response_content, raw_response) from the parent call.

        Raises:
            Whatever the parent implementation raises; the original exception
            object and traceback are preserved.
        """
        try:
            return await super()._call_api(model, messages, temperature)
        except Exception as e:
            # Only the logged text is rebranded; the exception itself is
            # re-raised as-is so callers can still catch the original types.
            error_msg = str(e).replace("OpenAI API Error", "xAI API Error")
            # Classify by substring of the exception's class name.
            error_type = type(e).__name__
            if "RateLimitError" in error_type:
                print(f"🚫 xAI RateLimitError: {error_msg}")
            elif "APIError" in error_type or "APIConnectionError" in error_type:
                print(f"❌ xAI API Error: {error_msg}")
            else:
                print(f"❌ Unexpected error in xAI API call: {error_msg}")
            raise

    def get_model_info(self) -> Dict[str, str]:
        """Return the configured solver/grader models, provider name and endpoint."""
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "provider": "xai",
            "base_url": self.XAI_BASE_URL,
        }

    async def health_check(self) -> bool:
        """
        Perform a simple health check to verify xAI API connectivity.

        Sends one short prompt to the solver model. The check passes when the
        reply contains "ok" (case-insensitive). This is a loose check, not a
        JSON validation.

        Returns:
            True if the API answered as expected, False on an unexpected reply
            or on any exception (exceptions are printed, not propagated).
        """
        try:
            test_messages = [
                {"role": "user", "content": "Hello, please respond with a simple JSON: {\"status\": \"ok\"}"}
            ]
            result, _ = await self._call_api(
                model=self.solver_model,
                messages=test_messages,
                temperature=0.0
            )
            if result and "ok" in result.lower():
                print(f"✅ xAI API health check passed for {self.solver_model}")
                return True
            print("⚠️ xAI API health check returned unexpected response")
            return False
        except Exception as e:
            # A health check reports failure instead of raising.
            print(f"❌ xAI API health check failed: {str(e)}")
            return False

    async def estimate_cost(self,
                            num_problems: int,
                            avg_problem_length: int = 1000,
                            avg_solution_length: int = 2000) -> Dict[str, Union[float, str]]:
        """
        Estimate the cost for processing a given number of problems with xAI models.

        Uses a rough 4-characters-per-token heuristic and the estimated prices
        in XAI_PRICING_PER_1K_TOKENS (unknown models use the default model's
        prices).

        Args:
            num_problems: Number of problems to process (must be >= 0)
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters

        Returns:
            Dictionary with solve/grade/total/per-problem costs (floats) plus
            "currency" and "note" strings. With zero problems, every cost is 0.0.

        Raises:
            ValueError: If num_problems is negative.
        """
        if num_problems < 0:
            raise ValueError(f"num_problems must be non-negative, got {num_problems}")

        # Rough token estimates (1 token ≈ 4 characters for English)
        tokens_per_solve = (avg_problem_length + avg_solution_length) // 4
        # The grading input counts the solution length twice.
        tokens_per_grade = (avg_problem_length + avg_solution_length * 2) // 4

        pricing = self.XAI_PRICING_PER_1K_TOKENS
        fallback_rates = pricing[self.XAI_DEFAULT_PRICING_MODEL]

        def get_model_cost(model: str, input_tokens: int, output_tokens: int) -> float:
            # Look up per-1K-token rates for `model`, falling back to the default.
            rates = pricing.get(model, fallback_rates)
            input_cost = (input_tokens / 1000) * rates["input"]
            output_cost = (output_tokens / 1000) * rates["output"]
            return input_cost + output_cost

        solve_input_tokens = tokens_per_solve * num_problems
        grade_input_tokens = tokens_per_grade * num_problems
        # Assume solve output is ~50% of input, grade output ~33% of input.
        solve_cost = get_model_cost(self.solver_model, solve_input_tokens, solve_input_tokens // 2)
        grade_cost = get_model_cost(self.grader_model, grade_input_tokens, grade_input_tokens // 3)
        total_cost = solve_cost + grade_cost

        # Avoid ZeroDivisionError when no problems are requested.
        cost_per_problem = total_cost / num_problems if num_problems else 0.0

        return {
            "solve_cost": round(solve_cost, 4),
            "grade_cost": round(grade_cost, 4),
            "total_cost": round(total_cost, 4),
            "cost_per_problem": round(cost_per_problem, 6),
            "currency": "USD",
            "note": "xAI pricing estimates - update with actual pricing"
        }
|