1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
|
"""
Cross-provider model loader implementation.
Allows using different providers for solving and grading tasks.
"""
from typing import Dict, Optional, Tuple, Any
from .base import ModelLoader
class CrossProviderLoader(ModelLoader):
    """Wrapper that allows using different providers for solving and grading.

    API calls are delegated to one of two wrapped ``ModelLoader`` instances,
    chosen by comparing the requested model name against the configured
    solver and grader model names.
    """

    def __init__(self,
                 solver_loader: ModelLoader,
                 grader_loader: Optional[ModelLoader] = None,
                 **kwargs):
        """
        Initialize cross-provider loader.

        Args:
            solver_loader: ModelLoader instance for solving problems
            grader_loader: ModelLoader instance for grading (if None, uses solver_loader)
            **kwargs: Additional arguments passed to parent class
        """
        self.solver_loader = solver_loader
        # Fall back to the solver loader only when no grader loader was given.
        # An explicit ``is None`` test (rather than truthiness) keeps this
        # choice consistent with ``is_cross_provider`` below even if a loader
        # object happens to evaluate as falsy.
        self.grader_loader = solver_loader if grader_loader is None else grader_loader

        # Initialize parent with combined model info
        super().__init__(
            solver_model=solver_loader.solver_model,
            grader_model=self.grader_loader.grader_model,
            **kwargs
        )

        # Cross-provider means two distinct loader objects are in play.
        # Identity is the relevant test: each distinct object has to be
        # entered, exited and health-checked on its own.
        self.is_cross_provider = self.grader_loader is not solver_loader

    async def _call_api(self,
                        model: str,
                        messages: list[Dict[str, str]],
                        temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Route API calls to the appropriate provider based on the model.

        The model name is compared first against this wrapper's own
        ``solver_model`` / ``grader_model`` and then against the names stored
        on the wrapped loaders. The solver is checked before the grader, so
        when both roles share a model name the call goes to the solver loader.

        Args:
            model: Model name to use
            messages: List of messages in chat format
            temperature: Temperature for generation

        Returns:
            Whatever the selected loader's ``_call_api`` returns
            (annotated as a tuple of (response_content, raw_response))

        Raises:
            ValueError: If the model matches neither the solver nor the grader.
        """
        if model == self.solver_model:
            return await self.solver_loader._call_api(model, messages, temperature)
        if model == self.grader_model:
            return await self.grader_loader._call_api(model, messages, temperature)

        # Fall back to the names recorded on the wrapped loaders themselves.
        if hasattr(self.solver_loader, 'solver_model') and model == self.solver_loader.solver_model:
            return await self.solver_loader._call_api(model, messages, temperature)
        if hasattr(self.grader_loader, 'grader_model') and model == self.grader_loader.grader_model:
            return await self.grader_loader._call_api(model, messages, temperature)

        raise ValueError(f"Model {model} not found in either solver or grader loader")

    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the configured models and providers.

        Returns:
            Dictionary with the solver/grader model names, each loader's
            ``provider`` entry (``"unknown"`` when absent), the cross-provider
            flag, and the full info dictionaries of both wrapped loaders.
        """
        solver_info = self.solver_loader.get_model_info()
        grader_info = self.grader_loader.get_model_info()

        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "solver_provider": solver_info.get("provider", "unknown"),
            "grader_provider": grader_info.get("provider", "unknown"),
            "is_cross_provider": self.is_cross_provider,
            "solver_info": solver_info,
            "grader_info": grader_info
        }

    async def health_check(self) -> bool:
        """
        Perform health checks on both providers.

        The grader loader is only checked separately in cross-provider mode;
        otherwise the solver loader's result is returned on its own.

        Returns:
            True if both providers are healthy, False otherwise
        """
        print("🔍 Checking solver provider health...")
        solver_health = await self.solver_loader.health_check()

        if not self.is_cross_provider:
            return solver_health

        print("🔍 Checking grader provider health...")
        grader_health = await self.grader_loader.health_check()
        return solver_health and grader_health

    async def estimate_cost(self,
                            num_problems: int,
                            avg_problem_length: int = 1000,
                            avg_solution_length: int = 2000) -> Dict[str, float]:
        """
        Estimate costs for both providers.

        Args:
            num_problems: Number of problems to process
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters

        Returns:
            In cross-provider mode, a dictionary combining the solver loader's
            ``solve_cost`` with the grader loader's ``grade_cost`` (each
            defaulting to 0 when missing). Otherwise the solver loader's
            estimate is returned unchanged.
        """
        solver_costs = await self.solver_loader.estimate_cost(
            num_problems, avg_problem_length, avg_solution_length
        )

        if not self.is_cross_provider:
            # Single provider costs
            return solver_costs

        # Get grader costs separately
        grader_costs = await self.grader_loader.estimate_cost(
            num_problems, avg_problem_length, avg_solution_length
        )

        # Only the solving share of the solver estimate and the grading share
        # of the grader estimate are relevant when the roles are split.
        solve_cost = solver_costs.get("solve_cost", 0)
        grade_cost = grader_costs.get("grade_cost", 0)

        return {
            "solver_cost": solve_cost,
            "grader_cost": grade_cost,
            "total_cost": solve_cost + grade_cost,
            "solver_provider": self.solver_loader.get_model_info().get("provider"),
            "grader_provider": self.grader_loader.get_model_info().get("provider"),
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "num_problems": num_problems,
            "note": "Cross-provider costs combined"
        }

    async def __aenter__(self):
        """Async context manager entry.

        Enters the solver loader and, in cross-provider mode, the grader
        loader. If entering the grader loader fails, the already-entered
        solver loader is exited before the error propagates, because
        ``async with`` does not call ``__aexit__`` when ``__aenter__`` raises.

        Returns:
            This loader instance.
        """
        if hasattr(self.solver_loader, '__aenter__'):
            await self.solver_loader.__aenter__()

        if self.is_cross_provider and hasattr(self.grader_loader, '__aenter__'):
            try:
                await self.grader_loader.__aenter__()
            except BaseException as exc:
                # Roll back the solver so a half-initialized wrapper does not
                # leave the solver provider open.
                if hasattr(self.solver_loader, '__aexit__'):
                    await self.solver_loader.__aexit__(
                        type(exc), exc, exc.__traceback__
                    )
                raise

        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit.

        Exits the wrapped loaders in reverse order of entry (grader first,
        then solver). The solver loader is exited in a ``finally`` clause so
        that an error raised while exiting the grader loader cannot prevent
        the solver loader from being exited.

        Returns None, so an exception raised inside the ``async with`` body
        is never suppressed by this wrapper.
        """
        try:
            if self.is_cross_provider and hasattr(self.grader_loader, '__aexit__'):
                await self.grader_loader.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            if hasattr(self.solver_loader, '__aexit__'):
                await self.solver_loader.__aexit__(exc_type, exc_val, exc_tb)
|