| author | = <=> | 2025-06-04 11:49:37 +0800 |
|---|---|---|
| committer | = <=> | 2025-06-04 11:49:37 +0800 |
| commit | 947d9dfdf16ae37109898111a5caacae7377b96d (patch) | |
| tree | ff4e884020fb7d968a6192106f370b215647f569 /code_eval/OpenCodeEval/backend/openai.py | |
| parent | 5e163b529a78d528b745b8b57ba794b7b2bba97a (diff) | |
update code and kk eval
Diffstat (limited to 'code_eval/OpenCodeEval/backend/openai.py')

| mode       | path                                     | lines |
|------------|------------------------------------------|-------|
| -rw-r--r-- | code_eval/OpenCodeEval/backend/openai.py | 115   |

1 file changed, 115 insertions, 0 deletions
```diff
diff --git a/code_eval/OpenCodeEval/backend/openai.py b/code_eval/OpenCodeEval/backend/openai.py
new file mode 100644
index 0000000..afa45b2
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/openai.py
@@ -0,0 +1,115 @@
+import os
+
+from typing import List, Dict
+from openai import OpenAI
+from loguru import logger
+from tqdm.contrib.concurrent import thread_map
+
+from OpenCodeEval.backend.base import Generator
+
+class OpenaiGenerator(Generator):
+
+    def __init__(
+        self,
+        model_name: str,
+        model_type: str = "Instruction",
+        batch_size : int = 1,
+        temperature : float = 0.0,
+        top_p : float = 1.0,
+        max_tokens : int = 1024,
+        num_samples : int = 1,
+    ) -> None:
+
+        super().__init__(model_name)
+
+        print("Initializing a decoder model in openai: {} ...".format(model_name))
+        self.model_type = model_type
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+        self.num_samples = num_samples
+
+        self.client = OpenAI(
+            base_url = os.getenv("OPENAI_BASE_URL"),
+            api_key = os.getenv("OPENAI_API_KEY")
+        )
+
+    def is_chat(self) -> bool:
+        if self.model_type == "Chat":
+            return True
+        else:
+            return False
+
+    def set_stop(self, eos: List[str]):
+        self.eos = eos
+
+    def connect_server(self, prompt):
+
+        num_tries = 5
+
+        try:
+
+            for _ in range(num_tries):
+
+                result = self.client.chat.completions.create(
+                    model = self.model_name,
+                    messages=[
+                        {"role": "user", "content": prompt['prompt']}
+                    ],
+                    n = self.num_samples,
+                    stream = False,
+                    # stop = self.eos,
+                    temperature = self.temperature,
+                    # top_p = self.top_p,
+                    max_tokens = self.max_tokens,
+                    extra_headers = {'apikey':os.getenv("OPENAI_API_KEY")},
+                )
+
+                if all(choice.message.content for choice in result.choices):
+                    break
+                else:
+                    logger.warning("No content in the response, retrying...")
+
+            results = [
+                dict(
+                    task_id=prompt['task_id'],
+                    completion_id=i,
+                    completion=choice.message.content
+                )
+                for i, choice in enumerate(result.choices)
+            ]
+
+        except Exception as e:
+            logger.error("Error: {}".format(e))
+            results = [
+                dict(
+                    task_id=prompt['task_id'],
+                    completion_id=i,
+                    completion = ''
+                )
+                for i in range(self.num_samples)
+            ]
+
+        return results
+
+    def generate(
+        self,
+        prompt_set: List[Dict],
+        response_prefix: str = "",
+        response_suffix: str = ""
+    ) -> List[Dict]:
+
+        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])
+
+        results = thread_map(
+            self.connect_server,
+            prompt_set,
+            max_workers = self.batch_size,
+            desc="Generating Completions"
+        )
+
+
+        generations = [item for sublist in results for item in sublist]
+
+        return generations
\ No newline at end of file
```
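As a quick orientation for the new backend, below is a minimal sketch of how it might be driven end to end. It is illustrative only: the model id, environment values, and prompt text are placeholder assumptions, not part of the commit, and it assumes the `OpenCodeEval` package is on the import path with `OPENAI_BASE_URL` and `OPENAI_API_KEY` pointing at a reachable OpenAI-compatible endpoint.

```python
import os

# Placeholder endpoint/credentials -- the real values are deployment-specific.
os.environ.setdefault("OPENAI_BASE_URL", "http://localhost:8000/v1")
os.environ.setdefault("OPENAI_API_KEY", "sk-placeholder")

from OpenCodeEval.backend.openai import OpenaiGenerator

generator = OpenaiGenerator(
    model_name="gpt-4o-mini",   # placeholder model id
    model_type="Chat",
    batch_size=4,               # fan-out width for thread_map in generate()
    temperature=0.0,
    max_tokens=1024,
    num_samples=1,
)

# connect_server() reads 'task_id' and 'prompt' from each dict.
prompt_set = [
    {"task_id": 0, "prompt": "Write a function that returns the nth Fibonacci number."},
]

for item in generator.generate(prompt_set):
    print(item["task_id"], item["completion_id"], item["completion"][:80])
```

One behavioral detail worth noting: `connect_server` retries up to five times when a response comes back with empty content, but the `try` wraps the whole retry loop, so a raised exception (a network error, for example) skips the remaining attempts and falls through to the empty-completion fallback.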
