diff options
| author | Yuren Hao <yurenh2@illinois.edu> | 2026-04-08 22:06:05 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@illinois.edu> | 2026-04-08 22:06:05 -0500 |
| commit | 05704d0eb2fa59fe727652465b07db40bcb06c38 (patch) | |
| tree | 8904aca836cf552fd1a5ae8c2174e9f91e70bbbc /putnam-bench-anon/loader/cross_provider.py | |
Initial release: GAP framework
- Full pipeline: variant generation, multi-judge verification, evaluation
- Loaders for OpenAI / Anthropic / Google / xAI / OpenRouter / vLLM
- Framework-level mechanism analyses: paired structural overlap, repairability rescue, self-correction probe, cross-model agreement, topic x problem-type interaction
- Unicode -> bare-LaTeX cleaner + audit + spot-check
- Mirrors https://huggingface.co/datasets/blackhao0426/PutnamGAP
Diffstat (limited to 'putnam-bench-anon/loader/cross_provider.py')
| -rw-r--r-- | putnam-bench-anon/loader/cross_provider.py | 155 |
1 files changed, 155 insertions, 0 deletions
diff --git a/putnam-bench-anon/loader/cross_provider.py b/putnam-bench-anon/loader/cross_provider.py new file mode 100644 index 0000000..afd833c --- /dev/null +++ b/putnam-bench-anon/loader/cross_provider.py @@ -0,0 +1,155 @@ +""" +Cross-provider model loader implementation. +Allows using different providers for solving and grading tasks. +""" + +from typing import Dict, Optional, Tuple, Any +from .base import ModelLoader + + +class CrossProviderLoader(ModelLoader): + """Wrapper that allows using different providers for solving and grading.""" + + def __init__(self, + solver_loader: ModelLoader, + grader_loader: Optional[ModelLoader] = None, + **kwargs): + """ + Initialize cross-provider loader. + + Args: + solver_loader: ModelLoader instance for solving problems + grader_loader: ModelLoader instance for grading (if None, uses solver_loader) + **kwargs: Additional arguments passed to parent class + """ + # If no grader loader specified, use the solver loader for both + self.solver_loader = solver_loader + self.grader_loader = grader_loader or solver_loader + + # Initialize parent with combined model info + super().__init__( + solver_model=solver_loader.solver_model, + grader_model=self.grader_loader.grader_model, + **kwargs + ) + + # Track if we're using cross-provider + self.is_cross_provider = grader_loader is not None and grader_loader != solver_loader + + async def _call_api(self, + model: str, + messages: list[Dict[str, str]], + temperature: float = 0.0) -> Tuple[Optional[str], str]: + """ + Route API calls to the appropriate provider based on the model. + + Args: + model: Model name to use + messages: List of messages in chat format + temperature: Temperature for generation + + Returns: + Tuple of (response_content, raw_response) + """ + # Determine which loader to use based on the model + if model == self.solver_model: + return await self.solver_loader._call_api(model, messages, temperature) + elif model == self.grader_model: + return await self.grader_loader._call_api(model, messages, temperature) + else: + # Try to determine based on which loader has the model + if hasattr(self.solver_loader, 'solver_model') and model == self.solver_loader.solver_model: + return await self.solver_loader._call_api(model, messages, temperature) + elif hasattr(self.grader_loader, 'grader_model') and model == self.grader_loader.grader_model: + return await self.grader_loader._call_api(model, messages, temperature) + else: + raise ValueError(f"Model {model} not found in either solver or grader loader") + + def get_model_info(self) -> Dict[str, Any]: + """Get information about the configured models and providers.""" + solver_info = self.solver_loader.get_model_info() + grader_info = self.grader_loader.get_model_info() + + return { + "solver_model": self.solver_model, + "grader_model": self.grader_model, + "solver_provider": solver_info.get("provider", "unknown"), + "grader_provider": grader_info.get("provider", "unknown"), + "is_cross_provider": self.is_cross_provider, + "solver_info": solver_info, + "grader_info": grader_info + } + + async def health_check(self) -> bool: + """ + Perform health checks on both providers. + + Returns: + True if both providers are healthy, False otherwise + """ + print("🔍 Checking solver provider health...") + solver_health = await self.solver_loader.health_check() + + if self.is_cross_provider: + print("🔍 Checking grader provider health...") + grader_health = await self.grader_loader.health_check() + return solver_health and grader_health + else: + return solver_health + + async def estimate_cost(self, + num_problems: int, + avg_problem_length: int = 1000, + avg_solution_length: int = 2000) -> Dict[str, float]: + """ + Estimate costs for both providers. + + Args: + num_problems: Number of problems to process + avg_problem_length: Average length of problem statements in characters + avg_solution_length: Average length of solutions in characters + + Returns: + Dictionary with combined cost estimates + """ + # Get solver costs + solver_costs = await self.solver_loader.estimate_cost( + num_problems, avg_problem_length, avg_solution_length + ) + + if self.is_cross_provider: + # Get grader costs separately + grader_costs = await self.grader_loader.estimate_cost( + num_problems, avg_problem_length, avg_solution_length + ) + + # Combine costs + return { + "solver_cost": solver_costs.get("solve_cost", 0), + "grader_cost": grader_costs.get("grade_cost", 0), + "total_cost": solver_costs.get("solve_cost", 0) + grader_costs.get("grade_cost", 0), + "solver_provider": self.solver_loader.get_model_info().get("provider"), + "grader_provider": self.grader_loader.get_model_info().get("provider"), + "solver_model": self.solver_model, + "grader_model": self.grader_model, + "num_problems": num_problems, + "note": "Cross-provider costs combined" + } + else: + # Single provider costs + return solver_costs + + async def __aenter__(self): + """Async context manager entry.""" + if hasattr(self.solver_loader, '__aenter__'): + await self.solver_loader.__aenter__() + if self.is_cross_provider and hasattr(self.grader_loader, '__aenter__'): + await self.grader_loader.__aenter__() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Async context manager exit.""" + if hasattr(self.solver_loader, '__aexit__'): + await self.solver_loader.__aexit__(exc_type, exc_val, exc_tb) + if self.is_cross_provider and hasattr(self.grader_loader, '__aexit__'): + await self.grader_loader.__aexit__(exc_type, exc_val, exc_tb)
\ No newline at end of file |
