summaryrefslogtreecommitdiff
path: root/putnam-bench-anon/loader/cross_provider.py
blob: afd833c484cb93f72faa501e42cb71851cd1cf7b (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""
Cross-provider model loader implementation.
Allows using different providers for solving and grading tasks.
"""

from typing import Dict, Optional, Tuple, Any
from .base import ModelLoader


class CrossProviderLoader(ModelLoader):
    """Wrapper that allows using different providers for solving and grading.

    API calls are dispatched to ``solver_loader`` or ``grader_loader`` by
    comparing the requested model name with the configured solver/grader
    model names (see ``_call_api``). Health checks, cost estimates and async
    context management touch the grader loader only when it is a distinct
    object from the solver loader (``is_cross_provider``).
    """
    
    def __init__(self, 
                 solver_loader: ModelLoader,
                 grader_loader: Optional[ModelLoader] = None,
                 **kwargs):
        """
        Initialize cross-provider loader.
        
        Args:
            solver_loader: ModelLoader instance for solving problems
            grader_loader: ModelLoader instance for grading (if None, uses solver_loader)
            **kwargs: Additional arguments passed to parent class
        """
        self.solver_loader = solver_loader
        # Compare against None explicitly: a loader object that happens to be
        # falsy (e.g. defines __len__/__bool__) must not be silently swapped for
        # the solver loader while is_cross_provider below still reports True.
        self.grader_loader = solver_loader if grader_loader is None else grader_loader
        
        # Initialize parent with combined model info
        super().__init__(
            solver_model=solver_loader.solver_model,
            grader_model=self.grader_loader.grader_model,
            **kwargs
        )
        
        # Cross-provider means a second, distinct loader was supplied.
        self.is_cross_provider = grader_loader is not None and grader_loader != solver_loader

        # _call_api routes by model name and checks the solver model first, so a
        # shared model name makes every grading call go to the solver provider.
        if self.is_cross_provider and self.solver_model == self.grader_model:
            import warnings
            warnings.warn(
                f"Solver and grader loaders share the model name {self.solver_model!r}; "
                "because API calls are routed by model name, grading requests will be "
                "sent to the solver provider.",
                stacklevel=2,
            )
    
    async def _call_api(self, 
                       model: str, 
                       messages: list[Dict[str, str]], 
                       temperature: float = 0.0) -> Tuple[Optional[str], str]:
        """
        Route API calls to the appropriate provider based on the model.
        
        The solver model is checked before the grader model, so when both
        names are equal the solver loader wins.
        
        Args:
            model: Model name to use
            messages: List of messages in chat format
            temperature: Temperature for generation
            
        Returns:
            Tuple of (response_content, raw_response)
            
        Raises:
            ValueError: If ``model`` matches neither the solver nor the grader model.
        """
        if model == self.solver_model:
            return await self.solver_loader._call_api(model, messages, temperature)
        if model == self.grader_model:
            return await self.grader_loader._call_api(model, messages, temperature)
        # Fallback: compare against the wrapped loaders' own model attributes,
        # in case the names stored on this wrapper differ from them.
        if hasattr(self.solver_loader, 'solver_model') and model == self.solver_loader.solver_model:
            return await self.solver_loader._call_api(model, messages, temperature)
        if hasattr(self.grader_loader, 'grader_model') and model == self.grader_loader.grader_model:
            return await self.grader_loader._call_api(model, messages, temperature)
        raise ValueError(f"Model {model} not found in either solver or grader loader")
    
    def get_model_info(self) -> Dict[str, Any]:
        """Get information about the configured models and providers.

        Returns:
            Dictionary with both model names, each loader's ``provider``
            (``"unknown"`` if absent), the cross-provider flag, and the raw
            ``get_model_info()`` output of each wrapped loader.
        """
        solver_info = self.solver_loader.get_model_info()
        grader_info = self.grader_loader.get_model_info()
        
        return {
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "solver_provider": solver_info.get("provider", "unknown"),
            "grader_provider": grader_info.get("provider", "unknown"),
            "is_cross_provider": self.is_cross_provider,
            "solver_info": solver_info,
            "grader_info": grader_info
        }
    
    async def health_check(self) -> bool:
        """
        Perform health checks on both providers.
        
        The grader loader is only checked when it is distinct from the solver
        loader; otherwise the solver check covers both roles.
        
        Returns:
            True if both providers are healthy, False otherwise
        """
        print("🔍 Checking solver provider health...")
        solver_health = await self.solver_loader.health_check()
        
        if not self.is_cross_provider:
            return solver_health
        
        print("🔍 Checking grader provider health...")
        grader_health = await self.grader_loader.health_check()
        return solver_health and grader_health
    
    async def estimate_cost(self, 
                          num_problems: int, 
                          avg_problem_length: int = 1000,
                          avg_solution_length: int = 2000) -> Dict[str, Any]:
        """
        Estimate costs for both providers.
        
        Args:
            num_problems: Number of problems to process
            avg_problem_length: Average length of problem statements in characters
            avg_solution_length: Average length of solutions in characters
            
        Returns:
            In single-provider mode, the solver loader's estimate unchanged.
            In cross-provider mode, a combined dictionary holding the solve cost
            (from the solver loader) and the grade cost (from the grader loader)
            under both ``solver_cost``/``grader_cost`` and ``solve_cost``/``grade_cost``
            keys, plus ``total_cost`` and provider/model metadata.
        """
        solver_costs = await self.solver_loader.estimate_cost(
            num_problems, avg_problem_length, avg_solution_length
        )
        
        if not self.is_cross_provider:
            # Single provider: its estimate already covers solving and grading.
            return solver_costs
        
        grader_costs = await self.grader_loader.estimate_cost(
            num_problems, avg_problem_length, avg_solution_length
        )
        
        # Only the solve part of the solver estimate and the grade part of the
        # grader estimate are relevant when the roles are split.
        solve_cost = solver_costs.get("solve_cost", 0)
        grade_cost = grader_costs.get("grade_cost", 0)
        
        return {
            "solver_cost": solve_cost,
            "grader_cost": grade_cost,
            # Same key names as the per-provider estimates (which single-provider
            # mode returns verbatim), so callers need not special-case the mode.
            "solve_cost": solve_cost,
            "grade_cost": grade_cost,
            "total_cost": solve_cost + grade_cost,
            "solver_provider": self.solver_loader.get_model_info().get("provider"),
            "grader_provider": self.grader_loader.get_model_info().get("provider"),
            "solver_model": self.solver_model,
            "grader_model": self.grader_model,
            "num_problems": num_problems,
            "note": "Cross-provider costs combined"
        }
    
    async def __aenter__(self):
        """Async context manager entry.

        Enters the solver loader, then (in cross-provider mode) the grader
        loader. If entering the grader fails, the already-entered solver loader
        is exited before the exception propagates, so it is not leaked.
        """
        solver_entered = False
        if hasattr(self.solver_loader, '__aenter__'):
            await self.solver_loader.__aenter__()
            solver_entered = True
        if self.is_cross_provider and hasattr(self.grader_loader, '__aenter__'):
            try:
                await self.grader_loader.__aenter__()
            except BaseException as exc:
                if solver_entered and hasattr(self.solver_loader, '__aexit__'):
                    await self.solver_loader.__aexit__(type(exc), exc, exc.__traceback__)
                raise
        return self
    
    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Async context manager exit.

        Exits the solver loader and (in cross-provider mode) the grader loader.
        The grader is exited even if the solver's ``__aexit__`` raises.
        Returns None, so exceptions from the ``async with`` body are never
        suppressed.
        """
        try:
            if hasattr(self.solver_loader, '__aexit__'):
                await self.solver_loader.__aexit__(exc_type, exc_val, exc_tb)
        finally:
            if self.is_cross_provider and hasattr(self.grader_loader, '__aexit__'):
                await self.grader_loader.__aexit__(exc_type, exc_val, exc_tb)