summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/cross_provider.py
diff options
context:
space:
mode:
Diffstat (limited to 'putnam-bench-anon/loader/cross_provider.py')
-rw-r--r--putnam-bench-anon/loader/cross_provider.py155
1 files changed, 155 insertions, 0 deletions
diff --git a/putnam-bench-anon/loader/cross_provider.py b/putnam-bench-anon/loader/cross_provider.py
new file mode 100644
index 0000000..afd833c
--- /dev/null
+++ b/putnam-bench-anon/loader/cross_provider.py
@@ -0,0 +1,155 @@
+"""
+Cross-provider model loader implementation.
+Allows using different providers for solving and grading tasks.
+"""
+
+from typing import Dict, Optional, Tuple, Any
+from .base import ModelLoader
+
+
+class CrossProviderLoader(ModelLoader):
+ """Wrapper that allows using different providers for solving and grading."""
+
+ def __init__(self,
+ solver_loader: ModelLoader,
+ grader_loader: Optional[ModelLoader] = None,
+ **kwargs):
+ """
+ Initialize cross-provider loader.
+
+ Args:
+ solver_loader: ModelLoader instance for solving problems
+ grader_loader: ModelLoader instance for grading (if None, uses solver_loader)
+ **kwargs: Additional arguments passed to parent class
+ """
+ # If no grader loader specified, use the solver loader for both
+ self.solver_loader = solver_loader
+ self.grader_loader = grader_loader or solver_loader
+
+ # Initialize parent with combined model info
+ super().__init__(
+ solver_model=solver_loader.solver_model,
+ grader_model=self.grader_loader.grader_model,
+ **kwargs
+ )
+
+ # Track if we're using cross-provider
+ self.is_cross_provider = grader_loader is not None and grader_loader != solver_loader
+
+ async def _call_api(self,
+ model: str,
+ messages: list[Dict[str, str]],
+ temperature: float = 0.0) -> Tuple[Optional[str], str]:
+ """
+ Route API calls to the appropriate provider based on the model.
+
+ Args:
+ model: Model name to use
+ messages: List of messages in chat format
+ temperature: Temperature for generation
+
+ Returns:
+ Tuple of (response_content, raw_response)
+ """
+ # Determine which loader to use based on the model
+ if model == self.solver_model:
+ return await self.solver_loader._call_api(model, messages, temperature)
+ elif model == self.grader_model:
+ return await self.grader_loader._call_api(model, messages, temperature)
+ else:
+ # Try to determine based on which loader has the model
+ if hasattr(self.solver_loader, 'solver_model') and model == self.solver_loader.solver_model:
+ return await self.solver_loader._call_api(model, messages, temperature)
+ elif hasattr(self.grader_loader, 'grader_model') and model == self.grader_loader.grader_model:
+ return await self.grader_loader._call_api(model, messages, temperature)
+ else:
+ raise ValueError(f"Model {model} not found in either solver or grader loader")
+
+ def get_model_info(self) -> Dict[str, Any]:
+ """Get information about the configured models and providers."""
+ solver_info = self.solver_loader.get_model_info()
+ grader_info = self.grader_loader.get_model_info()
+
+ return {
+ "solver_model": self.solver_model,
+ "grader_model": self.grader_model,
+ "solver_provider": solver_info.get("provider", "unknown"),
+ "grader_provider": grader_info.get("provider", "unknown"),
+ "is_cross_provider": self.is_cross_provider,
+ "solver_info": solver_info,
+ "grader_info": grader_info
+ }
+
+ async def health_check(self) -> bool:
+ """
+ Perform health checks on both providers.
+
+ Returns:
+ True if both providers are healthy, False otherwise
+ """
+ print("🔍 Checking solver provider health...")
+ solver_health = await self.solver_loader.health_check()
+
+ if self.is_cross_provider:
+ print("🔍 Checking grader provider health...")
+ grader_health = await self.grader_loader.health_check()
+ return solver_health and grader_health
+ else:
+ return solver_health
+
+ async def estimate_cost(self,
+ num_problems: int,
+ avg_problem_length: int = 1000,
+ avg_solution_length: int = 2000) -> Dict[str, float]:
+ """
+ Estimate costs for both providers.
+
+ Args:
+ num_problems: Number of problems to process
+ avg_problem_length: Average length of problem statements in characters
+ avg_solution_length: Average length of solutions in characters
+
+ Returns:
+ Dictionary with combined cost estimates
+ """
+ # Get solver costs
+ solver_costs = await self.solver_loader.estimate_cost(
+ num_problems, avg_problem_length, avg_solution_length
+ )
+
+ if self.is_cross_provider:
+ # Get grader costs separately
+ grader_costs = await self.grader_loader.estimate_cost(
+ num_problems, avg_problem_length, avg_solution_length
+ )
+
+ # Combine costs
+ return {
+ "solver_cost": solver_costs.get("solve_cost", 0),
+ "grader_cost": grader_costs.get("grade_cost", 0),
+ "total_cost": solver_costs.get("solve_cost", 0) + grader_costs.get("grade_cost", 0),
+ "solver_provider": self.solver_loader.get_model_info().get("provider"),
+ "grader_provider": self.grader_loader.get_model_info().get("provider"),
+ "solver_model": self.solver_model,
+ "grader_model": self.grader_model,
+ "num_problems": num_problems,
+ "note": "Cross-provider costs combined"
+ }
+ else:
+ # Single provider costs
+ return solver_costs
+
+ async def __aenter__(self):
+ """Async context manager entry."""
+ if hasattr(self.solver_loader, '__aenter__'):
+ await self.solver_loader.__aenter__()
+ if self.is_cross_provider and hasattr(self.grader_loader, '__aenter__'):
+ await self.grader_loader.__aenter__()
+ return self
+
+ async def __aexit__(self, exc_type, exc_val, exc_tb):
+ """Async context manager exit."""
+ if hasattr(self.solver_loader, '__aexit__'):
+ await self.solver_loader.__aexit__(exc_type, exc_val, exc_tb)
+ if self.is_cross_provider and hasattr(self.grader_loader, '__aexit__'):
+ await self.grader_loader.__aexit__(exc_type, exc_val, exc_tb) \ No newline at end of file