# path: root/code_eval/OpenCodeEval/backend/openai.py
# blob: afa45b2b61e503f6ffe03dab28cf8679aea68973 (plain)
import os

from typing import List, Dict
from openai import OpenAI
from loguru import logger
from tqdm.contrib.concurrent import thread_map

from OpenCodeEval.backend.base import Generator

class OpenaiGenerator(Generator):
    """Generator backend that requests completions from a chat completions API.

    The underlying ``OpenAI`` client is configured from the
    ``OPENAI_BASE_URL`` and ``OPENAI_API_KEY`` environment variables.
    Prompts are dispatched concurrently through ``thread_map`` using
    ``batch_size`` worker threads.
    """

    def __init__(
            self,
            model_name: str,
            model_type: str = "Instruction",
            batch_size: int = 1,
            temperature: float = 0.0,
            top_p: float = 1.0,
            max_tokens: int = 1024,
            num_samples: int = 1,
        ) -> None:
        """Store the sampling configuration and build the API client.

        Args:
            model_name: Model identifier sent with every request.
            model_type: Model flavour; ``"Chat"`` makes ``is_chat`` return True.
            batch_size: Number of worker threads used by ``generate``.
            temperature: Sampling temperature sent with every request.
            top_p: Stored on the instance; currently not sent with requests.
            max_tokens: Upper bound on generated tokens per completion.
            num_samples: Number of completions requested per prompt (``n``).
        """
        super().__init__(model_name)

        print("Initializing a decoder model in openai: {} ...".format(model_name))
        self.model_type = model_type
        self.batch_size = batch_size
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.num_samples = num_samples

        self.client = OpenAI(
            base_url=os.getenv("OPENAI_BASE_URL"),
            api_key=os.getenv("OPENAI_API_KEY"),
        )

    def is_chat(self) -> bool:
        """Return True when the configured ``model_type`` is ``"Chat"``."""
        return self.model_type == "Chat"

    def set_stop(self, eos: List[str]):
        """Record the stop sequences.

        The value is only stored on ``self.eos``; ``connect_server`` does not
        currently forward it to the API.
        """
        self.eos = eos

    def connect_server(self, prompt):
        """Request ``num_samples`` completions for a single prompt.

        The request is repeated up to five times while any returned choice has
        empty content. If any exception is raised, the prompt falls back to
        ``num_samples`` empty completions instead of propagating the error, so
        one failing prompt does not abort the whole batch.

        Args:
            prompt: Mapping with at least the keys ``'prompt'`` (the user
                message text) and ``'task_id'``.

        Returns:
            A list of dicts with keys ``task_id``, ``completion_id`` and
            ``completion``. ``completion`` is always a string; missing content
            is reported as ``''``.
        """
        num_tries = 5

        try:
            completions = []
            for _ in range(num_tries):
                result = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {"role": "user", "content": prompt['prompt']}
                    ],
                    n=self.num_samples,
                    stream=False,
                    # stop = self.eos,
                    temperature=self.temperature,
                    # top_p = self.top_p,
                    max_tokens=self.max_tokens,
                    extra_headers={'apikey': os.getenv("OPENAI_API_KEY")},
                )

                completions = [choice.message.content for choice in result.choices]
                if all(completions):
                    break
                logger.warning("No content in the response, retrying...")

        except Exception as e:
            # Deliberate best-effort: log and fall back to empty completions.
            logger.error("Error: {}".format(e))
            completions = [''] * self.num_samples

        # Content may still be None/empty once the retries are exhausted;
        # normalise it to '' so both paths always yield string completions.
        return [
            dict(
                task_id=prompt['task_id'],
                completion_id=i,
                completion=text or ''
            )
            for i, text in enumerate(completions)
        ]

    def generate(
        self,
        prompt_set: List[Dict],
        response_prefix: str = "",
        response_suffix: str = ""
    ) -> List[Dict]:
        """Generate completions for every prompt in ``prompt_set``.

        Args:
            prompt_set: Prompt dicts as accepted by ``connect_server``.
            response_prefix: Accepted but not used by this backend.
            response_suffix: Accepted but not used by this backend.

        Returns:
            A flat list with the completion dicts of all prompts, in the
            order of ``prompt_set``. An empty ``prompt_set`` yields ``[]``.
        """
        # Nothing to log or dispatch; also avoids indexing an empty list below.
        if not prompt_set:
            return []

        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])

        results = thread_map(
            self.connect_server,
            prompt_set,
            max_workers=self.batch_size,
            desc="Generating Completions"
        )

        # Each worker returns a list of completions; flatten them into one list.
        generations = [item for sublist in results for item in sublist]

        return generations