Diffstat (limited to 'code_eval/OpenCodeEval/backend/vllm.py')
-rw-r--r--   code_eval/OpenCodeEval/backend/vllm.py   144
1 files changed, 144 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/backend/vllm.py b/code_eval/OpenCodeEval/backend/vllm.py
new file mode 100644
index 0000000..3082dfe
--- /dev/null
+++ b/code_eval/OpenCodeEval/backend/vllm.py
@@ -0,0 +1,144 @@
+import os
+
+from loguru import logger
+
+from tqdm import tqdm
+from typing import List, Dict
+
+from OpenCodeEval.backend.base import Generator, make_chat_template
+
+from OpenCodeEval.utils import refine_text
+
+class VllmGenerator(Generator):
+
+    def __init__(
+        self,
+        model_name: str,
+        tokenizer_name: str = None,
+        model_type: str = "Instruction",
+        dtype: str = "bfloat16",
+        batch_size : int = 1,
+        temperature : float = 0.0,
+        top_p : float = 1.0,
+        max_tokens : int = 1024,
+        num_samples: int = 200,
+        num_gpus: int = 1,
+        seed: int = None,
+        trust_remote_code: bool = True,
+    ) -> None:
+
+        super().__init__(model_name)
+
+        print("Initializing a decoder model in vllm: {} ...".format(model_name))
+        self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name
+        self.model_type = model_type
+        self.batch_size = batch_size
+        self.temperature = temperature
+        self.top_p = top_p
+        self.max_tokens = max_tokens
+        self.num_samples = num_samples
+        self.dtype = dtype
+        self.num_gpus = num_gpus
+        self.seed = seed
+        self.trust_remote_code = trust_remote_code
+
+        from vllm import LLM, SamplingParams
+
+        self.model = LLM(
+            model = self.model_name,
+            tokenizer = self.tokenizer_name,
+            dtype = self.dtype,
+            tensor_parallel_size = self.num_gpus,
+            trust_remote_code = self.trust_remote_code
+        )
+
+        self.tokenizer = self.model.get_tokenizer()
+        os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+        self.sampling_params = SamplingParams(
+            n = self.num_samples,
+            temperature = self.temperature,
+            max_tokens = self.max_tokens,
+            top_p = self.top_p,
+            seed = int(self.seed) if self.seed is not None else None
+        )
+
+    def is_chat(self) -> bool:
+        if self.model_type == "Chat":
+            if self.tokenizer.chat_template is None:
+                logger.error("When the model type is Chat, the chat template must be set for the tokenizer.")
+            else:
+                return True
+        else:
+            return False
+
+    def set_stop(self, eos: List[str]):
+        self.sampling_params.stop = eos
+
+    def release_memory(self):
+
+        import gc
+        import torch
+
+        from vllm.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel
+
+        destroy_model_parallel()
+        destroy_distributed_environment()
+        del self.model.llm_engine.model_executor
+        del self.model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+    def generate(
+        self,
+        prompt_set: List[Dict],
+        response_prefix: str = "",
+        response_suffix: str = ""
+    ) -> List[str]:
+
+        if self.is_chat():
+            for prompt in prompt_set:
+                prompt['prompt'] = make_chat_template(
+                    prompt = prompt['prompt'],
+                    response_prefix = response_prefix,
+                    is_chat = self.is_chat(),
+                    tokenizer = self.tokenizer
+                )
+
+        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])
+
+        generation_set = []
+
+        for batch_start in tqdm(range(0, len(prompt_set), self.batch_size)):
+            batch_prompt = prompt_set[batch_start : batch_start + self.batch_size]
+            batch_outputs = self.model.generate(
+                [prompt['prompt'] for prompt in batch_prompt],
+                self.sampling_params,
+                use_tqdm = False,
+            )
+
+            for prompt, output in zip(batch_prompt, batch_outputs):
+
+                # Process completions with prefix/suffix and refine text based on chat mode
+                completions = [
+                    refine_text(
+                        f"{prompt['prompt']}\n\n{response_prefix}{completion.text}{response_suffix}"
+                        if not self.is_chat()
+                        else f"{response_prefix}{completion.text}{response_suffix}"
+                    )
+                    for completion in output.outputs
+                ]
+
+                for completion_id, completion in enumerate(completions):
+                    generation_set.append({
+                        'task_id': prompt['task_id'],
+                        'completion_id': completion_id,
+                        'completion': completion
+                    })
+
+        if len(generation_set) != len(prompt_set) * self.num_samples:
+            logger.error("Number of generations does not match the expected number.")
+
+        self.release_memory()
+
+        return generation_set
\ No newline at end of file
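
For reference, a minimal usage sketch of the VllmGenerator class added above. This driver script is not part of the commit; the model id, stop string, and prompt are illustrative placeholders, and only the constructor arguments and methods shown in the diff are used.

from OpenCodeEval.backend.vllm import VllmGenerator

# Hypothetical example: construct the generator, register stop strings, and sample one completion.
generator = VllmGenerator(
    model_name = "deepseek-ai/deepseek-coder-1.3b-instruct",  # placeholder HF model id
    model_type = "Chat",       # in Chat mode the tokenizer must provide a chat template
    num_samples = 1,
    max_tokens = 512,
    num_gpus = 1,
    seed = 42,
)
generator.set_stop(["\nif __name__"])  # illustrative stop string

# generate() expects dicts with 'task_id' and 'prompt' keys, normally produced by a benchmark loader.
prompt_set = [{'task_id': 'demo/0', 'prompt': "Write a Python function that reverses a string."}]
generations = generator.generate(prompt_set)
print(generations[0]['completion'])  # GPU memory is freed via release_memory() inside generate()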
