Diffstat (limited to 'code_eval/OpenCodeEval/backend')
-rw-r--r--   code_eval/OpenCodeEval/backend/__init__.py     0
-rw-r--r--   code_eval/OpenCodeEval/backend/base.py         55
-rw-r--r--   code_eval/OpenCodeEval/backend/openai.py       115
-rw-r--r--   code_eval/OpenCodeEval/backend/sglang.py       148
-rw-r--r--   code_eval/OpenCodeEval/backend/vllm.py         144
5 files changed, 462 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/backend/__init__.py b/code_eval/OpenCodeEval/backend/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/__init__.py
diff --git a/code_eval/OpenCodeEval/backend/base.py b/code_eval/OpenCodeEval/backend/base.py
new file mode 100644
index 0000000..42091b9
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/base.py
@@ -0,0 +1,55 @@
+from typing import Callable
+from abc import ABC, abstractmethod
+
+def make_chat_template(
+    prompt: str,
+    response_prefix: str = "",
+    is_chat: bool = True,
+    tokenizer: Callable = None
+) -> str:
+
+    if is_chat:
+        prompt = tokenizer.apply_chat_template(
+            [
+                {"role": "user", "content": prompt},
+            ],
+            tokenize = False,
+            add_generation_prompt = True
+        ) + response_prefix
+        if tokenizer.bos_token and prompt.startswith(tokenizer.bos_token):
+            prompt = prompt[len(tokenizer.bos_token):]
+        return prompt
+    else:
+        return prompt
+
+class Generator(ABC):
+
+    model_name: str = None
+
+    def __init__(self, model_name: str) -> None:
+        """
+        :param model_name: str
+            name or path of the model to generate with; how the model
+            is loaded or queried is decided by the concrete backend
+            (OpenAI, vLLM, or SGLang)
+        """
+        self.model_name = model_name
+
+    def fewshot_examples(self):
+        """Loads and returns the few-shot examples for the task if they exist."""
+        pass
+
+    @abstractmethod
+    def set_stop(self):
+        """
+        Set the stop tokens for the model
+        """
+        pass
+
+    @abstractmethod
+    def generate(self):
+        """Generates completions for the given prompt set.
+        :param prompt_set: list[dict]
+            samples from the benchmark, each with a task_id and a prompt
+        """
+        pass
\ No newline at end of file
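
The base module above defines the shared Generator interface and the chat-prompt helper used by the concrete backends. Below is a minimal sketch (not part of the diff) of how make_chat_template is meant to be called, assuming a Hugging Face tokenizer that ships a chat template; the checkpoint name is only a placeholder.

# Sketch only: wrap a plain prompt with a tokenizer's chat template.
from transformers import AutoTokenizer
from OpenCodeEval.backend.base import make_chat_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-Coder-7B-Instruct")  # placeholder checkpoint
prompt = make_chat_template(
    prompt = "Write a Python function that reverses a string.",
    response_prefix = "",
    is_chat = True,
    tokenizer = tokenizer,
)
print(prompt)  # chat-formatted prompt, with a leading BOS token stripped if present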
diff --git a/code_eval/OpenCodeEval/backend/openai.py b/code_eval/OpenCodeEval/backend/openai.py
new file mode 100644
index 0000000..afa45b2
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/openai.py
@@ -0,0 +1,115 @@
+import os
+
+from typing import List, Dict
+from openai import OpenAI
+from loguru import logger
+from tqdm.contrib.concurrent import thread_map
+
+from OpenCodeEval.backend.base import Generator
+
+class OpenaiGenerator(Generator):
+
+    def __init__(
+        self,
+        model_name: str,
+        model_type: str = "Instruction",
+        batch_size : int = 1,
+        temperature : float = 0.0,
+        top_p : float = 1.0,
+        max_tokens : int = 1024,
+        num_samples : int = 1,
+    ) -> None:
+
+        super().__init__(model_name)
+
+        print("Initializing a decoder model in openai: {} ...".format(model_name))
+        self.model_type = model_type
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+        self.num_samples = num_samples
+
+        self.client = OpenAI(
+            base_url = os.getenv("OPENAI_BASE_URL"),
+            api_key = os.getenv("OPENAI_API_KEY")
+        )
+
+    def is_chat(self) -> bool:
+        if self.model_type == "Chat":
+            return True
+        else:
+            return False
+
+    def set_stop(self, eos: List[str]):
+        self.eos = eos
+
+    def connect_server(self, prompt):
+
+        num_tries = 5
+
+        try:
+
+            for _ in range(num_tries):
+
+                result = self.client.chat.completions.create(
+                    model = self.model_name,
+                    messages=[
+                        {"role": "user", "content": prompt['prompt']}
+                    ],
+                    n = self.num_samples,
+                    stream = False,
+                    # stop = self.eos,
+                    temperature = self.temperature,
+                    # top_p = self.top_p,
+                    max_tokens = self.max_tokens,
+                    extra_headers = {'apikey':os.getenv("OPENAI_API_KEY")},
+                )
+
+                if all(choice.message.content for choice in result.choices):
+                    break
+                else:
+                    logger.warning("No content in the response, retrying...")
+
+            results = [
+                dict(
+                    task_id=prompt['task_id'],
+                    completion_id=i,
+                    completion=choice.message.content
+                )
+                for i, choice in enumerate(result.choices)
+            ]
+
+        except Exception as e:
+            logger.error("Error: {}".format(e))
+            results = [
+                dict(
+                    task_id=prompt['task_id'],
+                    completion_id=i,
+                    completion = ''
+                )
+                for i in range(self.num_samples)
+            ]
+
+        return results
+
+    def generate(
+        self,
+        prompt_set: List[Dict],
+        response_prefix: str = "",
+        response_suffix: str = ""
+    ) -> List[Dict]:
+
+        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])
+
+        results = thread_map(
+            self.connect_server,
+            prompt_set,
+            max_workers = self.batch_size,
+            desc="Generating Completions"
+        )
+
+
+        generations = [item for sublist in results for item in sublist]
+
+        return generations
\ No newline at end of file
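
For orientation, a sketch (not part of the diff) of how OpenaiGenerator is driven: it reads OPENAI_BASE_URL and OPENAI_API_KEY from the environment and fans requests out with thread_map using batch_size workers. The model name and prompt below are placeholders.

# Sketch only; assumes OPENAI_API_KEY (and optionally OPENAI_BASE_URL)
# are already exported in the environment.
from OpenCodeEval.backend.openai import OpenaiGenerator

generator = OpenaiGenerator(
    model_name = "gpt-4o-mini",   # placeholder model name
    model_type = "Chat",
    batch_size = 4,               # number of concurrent requests
    num_samples = 1,
)
generator.set_stop([])  # stored on the generator; currently not forwarded to the API call
generations = generator.generate([
    {"task_id": "demo/0", "prompt": "Write a Python function that adds two numbers."},
])
print(generations[0]["completion"])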
diff --git a/code_eval/OpenCodeEval/backend/sglang.py b/code_eval/OpenCodeEval/backend/sglang.py
new file mode 100644
index 0000000..c759639
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/sglang.py
@@ -0,0 +1,148 @@
+import os
+
+from loguru import logger
+
+from tqdm import tqdm
+from typing import List, Dict
+
+from OpenCodeEval.backend.base import Generator, make_chat_template
+
+from OpenCodeEval.utils import refine_text
+
+class SglangGenerator(Generator):
+
+    def __init__(
+        self,
+        model_name: str,
+        tokenizer_name: str = None,
+        model_type: str = "Instruction",
+        dtype: str = "bfloat16",
+        batch_size : int = 1,
+        temperature : float = 0.0,
+        top_p : float = 1.0,
+        max_tokens : int = 1024,
+        num_samples: int = 200,
+        num_gpus: int = 1,
+        trust_remote_code: bool = True,
+    ) -> None:
+
+        super().__init__(model_name)
+
+        print("Initializing a decoder model in sglang: {} ...".format(model_name))
+        self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name
+        self.model_type = model_type
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+        self.num_samples = num_samples
+        self.dtype = dtype
+        self.num_gpus = num_gpus
+        self.trust_remote_code = trust_remote_code
+
+        from sglang import Engine
+
+        self.model = Engine(
+            model_path = self.model_name,
+            tokenizer_path = self.tokenizer_name,
+            dtype = self.dtype,
+            tp_size = self.num_gpus,
+            trust_remote_code = self.trust_remote_code
+        )
+
+        self.tokenizer = self.model.tokenizer_manager.tokenizer
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        self.sampling_params = dict(
+            n = self.num_samples,
+            temperature = self.temperature,
+            max_new_tokens = self.max_tokens,
+            top_p = self.top_p,
+        )
+
+    def is_chat(self) -> bool:
+        if self.model_type == "Chat":
+            if self.tokenizer.chat_template is None:
+                logger.error("When the model type is Chat, the chat template must be set for the tokenizer.")
+            else:
+                return True
+        else:
+            return False
+
+    def set_stop(self, eos: List[str]):
+        self.sampling_params['stop'] = eos
+
+    def release_memory(self):
+
+        import gc
+        import torch
+
+        from sglang.srt.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+        self.model.shutdown()
+
+        # destroy_model_parallel
+        # destroy_distributed_environment()
+
+        # del self.model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def generate(
+        self,
+        prompt_set: List[Dict],
+        response_prefix: str = "",
+        response_suffix: str = ""
+    ) -> List[Dict]:
+
+        if self.is_chat():
+            for prompt in prompt_set:
+                prompt['prompt'] = make_chat_template(
+                    prompt = prompt['prompt'],
+                    response_prefix = response_prefix,
+                    is_chat = self.is_chat(),
+                    tokenizer = self.tokenizer
+                )
+
+        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])
+
+        generation_set = []
+
+        for batch_start in tqdm(range(0, len(prompt_set), self.batch_size)):
+
+            batch_prompt = prompt_set[batch_start : batch_start + self.batch_size]
+            batch_outputs = self.model.generate(
+                [prompt['prompt'] for prompt in batch_prompt],
+                self.sampling_params
+            )
+
+            batch_outputs = [
+                [item["text"] for item in batch_outputs[i:i + self.num_samples]]
+                for i in range(0, len(batch_outputs), self.num_samples)
+            ]
+
+            for prompt, output in zip(batch_prompt, batch_outputs):
+
+                # Process completions with prefix/suffix and refine text based on chat mode
+                completions = [
+                    refine_text(
+                        f"{prompt['prompt']}\n\n{response_prefix}{completion}{response_suffix}"
+                        if not self.is_chat()
+                        else f"{response_prefix}{completion}{response_suffix}"
+                    )
+                    for completion in output
+                ]
+
+                for completion_id, completion in enumerate(completions):
+                    generation_set.append({
+                        'task_id': prompt['task_id'],
+                        'completion_id': completion_id,
+                        'completion': completion
+                    })
+
+        if len(generation_set) != len(prompt_set) * self.num_samples:
+            logger.error("Number of generations does not match the expected number.")
+
+        self.release_memory()
+
+        return generation_set
\ No newline at end of file
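
A usage sketch (not part of the diff) for SglangGenerator: it launches an in-process sglang Engine and, as the batching code above assumes, regroups the engine's flat output list into num_samples completions per prompt. The checkpoint name and stop token are placeholders.

# Sketch only; assumes a local GPU and an sglang-compatible checkpoint.
from OpenCodeEval.backend.sglang import SglangGenerator

generator = SglangGenerator(
    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct",  # placeholder checkpoint
    model_type = "Chat",
    temperature = 0.8,
    num_samples = 10,
    num_gpus = 1,
)
generator.set_stop(["<|endoftext|>"])  # forwarded as the `stop` sampling parameter
generations = generator.generate(
    [{"task_id": "demo/0", "prompt": "Write a Python function that adds two numbers."}],
)
# 10 entries, each with task_id, completion_id, and the refined completion text.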
diff --git a/code_eval/OpenCodeEval/backend/vllm.py b/code_eval/OpenCodeEval/backend/vllm.py
new file mode 100644
index 0000000..3082dfe
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/vllm.py
@@ -0,0 +1,144 @@
+import os
+
+from loguru import logger
+
+from tqdm import tqdm
+from typing import List, Dict
+
+from OpenCodeEval.backend.base import Generator, make_chat_template
+
+from OpenCodeEval.utils import refine_text
+
+class VllmGenerator(Generator):
+
+    def __init__(
+        self,
+        model_name: str,
+        tokenizer_name: str = None,
+        model_type: str = "Instruction",
+        dtype: str = "bfloat16",
+        batch_size : int = 1,
+        temperature : float = 0.0,
+        top_p : float = 1.0,
+        max_tokens : int = 1024,
+        num_samples: int = 200,
+        num_gpus: int = 1,
+        seed: int = None,
+        trust_remote_code: bool = True,
+    ) -> None:
+
+        super().__init__(model_name)
+
+        print("Initializing a decoder model in vllm: {} ...".format(model_name))
+        self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name
+        self.model_type = model_type
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+        self.num_samples = num_samples
+        self.dtype = dtype
+        self.num_gpus = num_gpus
+        self.seed = seed
+        self.trust_remote_code = trust_remote_code
+
+        from vllm import LLM, SamplingParams
+
+        self.model = LLM(
+            model = self.model_name,
+            tokenizer = self.tokenizer_name,
+            dtype = self.dtype,
+            tensor_parallel_size = self.num_gpus,
+            trust_remote_code = self.trust_remote_code
+        )
+
+        self.tokenizer = self.model.get_tokenizer()
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        self.sampling_params = SamplingParams(
+            n = self.num_samples,
+            temperature = self.temperature,
+            max_tokens = self.max_tokens,
+            top_p = self.top_p,
+            seed = int(self.seed) if self.seed is not None else None  # keep the seed optional
+        )
+
+    def is_chat(self) -> bool:
+        if self.model_type == "Chat":
+            if self.tokenizer.chat_template is None:
+                logger.error("When the model type is Chat, the chat template must be set for the tokenizer.")
+            else:
+                return True
+        else:
+            return False
+
+    def set_stop(self, eos: List[str]):
+        self.sampling_params.stop = eos
+
+    def release_memory(self):
+
+        import gc
+        import torch
+
+        from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+        destroy_model_parallel()
+        destroy_distributed_environment()
+        del self.model.llm_engine.model_executor
+        del self.model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def generate(
+        self,
+        prompt_set: List[Dict],
+        response_prefix: str = "",
+        response_suffix: str = ""
+    ) -> List[Dict]:
+
+        if self.is_chat():
+            for prompt in prompt_set:
+                prompt['prompt'] = make_chat_template(
+                    prompt = prompt['prompt'],
+                    response_prefix = response_prefix,
+                    is_chat = self.is_chat(),
+                    tokenizer = self.tokenizer
+                )
+
+        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])
+
+        generation_set = []
+
+        for batch_start in tqdm(range(0, len(prompt_set), self.batch_size)):
+            batch_prompt = prompt_set[batch_start : batch_start + self.batch_size]
+            batch_outputs = self.model.generate(
+                [prompt['prompt'] for prompt in batch_prompt],
+                self.sampling_params,
+                use_tqdm = False,
+            )
+
+            for prompt, output in zip(batch_prompt, batch_outputs):
+
+                # Process completions with prefix/suffix and refine text based on chat mode
+                completions = [
+                    refine_text(
+                        f"{prompt['prompt']}\n\n{response_prefix}{completion.text}{response_suffix}"
+                        if not self.is_chat()
+                        else f"{response_prefix}{completion.text}{response_suffix}"
+                    )
+                    for completion in output.outputs
+                ]
+
+                for completion_id, completion in enumerate(completions):
+                    generation_set.append({
+                        'task_id': prompt['task_id'],
+                        'completion_id': completion_id,
+                        'completion': completion
+                    })
+
+        if len(generation_set) != len(prompt_set) * self.num_samples:
+            logger.error("Number of generations does not match the expected number.")
+
+        self.release_memory()
+
+        return generation_set
\ No newline at end of file
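
A usage sketch (not part of the diff) for VllmGenerator; the checkpoint name, sampling settings, and stop token are placeholders.

# Sketch only; assumes a local GPU and a vLLM-compatible checkpoint.
from OpenCodeEval.backend.vllm import VllmGenerator

generator = VllmGenerator(
    model_name = "Qwen/Qwen2.5-Coder-7B-Instruct",  # placeholder checkpoint
    model_type = "Chat",
    temperature = 0.8,
    num_samples = 10,
    seed = 42,
)
generator.set_stop(["<|endoftext|>"])
generations = generator.generate(
    [{"task_id": "demo/0", "prompt": "Write a Python function that adds two numbers."}],
)
# Note: generate() ends by calling release_memory(), which tears the engine down,
# so a new generator must be constructed for another run.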