import os
from typing import List, Dict

from openai import OpenAI
from loguru import logger
from tqdm.contrib.concurrent import thread_map

from OpenCodeEval.backend.base import Generator


class OpenaiGenerator(Generator):
    def __init__(
        self,
        model_name: str,
        model_type: str = "Instruction",
        batch_size: int = 1,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_tokens: int = 1024,
        num_samples: int = 1,
    ) -> None:
        super().__init__(model_name)
        print(f"Initializing a decoder model in openai: {model_name} ...")

        self.model_type = model_type
        self.batch_size = batch_size
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.num_samples = num_samples

        # Base URL and API key come from the environment, so the backend works
        # with the official API as well as OpenAI-compatible servers.
        self.client = OpenAI(
            base_url=os.getenv("OPENAI_BASE_URL"),
            api_key=os.getenv("OPENAI_API_KEY"),
        )
    def is_chat(self) -> bool:
        return self.model_type == "Chat"
    def set_stop(self, eos: List[str]):
        # Stop sequences are stored for interface compatibility; the request
        # below currently leaves `stop` commented out.
        self.eos = eos
    def connect_server(self, prompt: Dict) -> List[Dict]:
        """Request `num_samples` completions for one prompt, retrying on failure."""
        num_tries = 5
        result = None
        for _ in range(num_tries):
            try:
                result = self.client.chat.completions.create(
                    model=self.model_name,
                    messages=[
                        {"role": "user", "content": prompt['prompt']}
                    ],
                    n=self.num_samples,
                    stream=False,
                    # stop = self.eos,
                    temperature=self.temperature,
                    # top_p = self.top_p,
                    max_tokens=self.max_tokens,
                    extra_headers={'apikey': os.getenv("OPENAI_API_KEY")},
                )
            except Exception as e:
                # Retry on transient API errors instead of aborting the task.
                logger.error("Error: {}".format(e))
                continue
            if all(choice.message.content for choice in result.choices):
                break
            logger.warning("No content in the response, retrying...")

        if result is not None:
            results = [
                dict(
                    task_id=prompt['task_id'],
                    completion_id=i,
                    completion=choice.message.content or ""
                )
                for i, choice in enumerate(result.choices)
            ]
        else:
            # All attempts failed: return empty completions so downstream
            # evaluation still sees `num_samples` entries for this task.
            results = [
                dict(
                    task_id=prompt['task_id'],
                    completion_id=i,
                    completion=''
                )
                for i in range(self.num_samples)
            ]
        return results
    def generate(
        self,
        prompt_set: List[Dict],
        response_prefix: str = "",
        response_suffix: str = ""
    ) -> List[Dict]:
        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])

        # Fan the prompts out over `batch_size` worker threads.
        results = thread_map(
            self.connect_server,
            prompt_set,
            max_workers=self.batch_size,
            desc="Generating Completions"
        )

        # Each worker returns a list of completions; flatten to one list.
        generations = [item for sublist in results for item in sublist]
        return generations
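
# A minimal usage sketch of the class above. The model name, task id, and
# prompt here are hypothetical placeholders; it assumes OPENAI_API_KEY (and
# optionally OPENAI_BASE_URL) are set in the environment.
if __name__ == "__main__":
    generator = OpenaiGenerator(
        model_name="gpt-4o-mini",  # hypothetical model name
        model_type="Chat",
        batch_size=4,              # four concurrent request threads
        num_samples=1,
    )
    prompts = [
        {"task_id": "demo/0", "prompt": "Write a Python function that reverses a string."},
    ]
    for generation in generator.generate(prompts):
        print(generation["task_id"], generation["completion"])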