import os

from loguru import logger

from tqdm import tqdm
from typing import Dict, List, Optional

from OpenCodeEval.backend.base import Generator, make_chat_template

from OpenCodeEval.utils import refine_text

class SglangGenerator(Generator):
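    """
    Code-generation backend built on the sglang offline Engine.

    Prompts are passed as dicts with 'task_id' and 'prompt' keys; the generator
    returns one dict per completion with 'task_id', 'completion_id', and
    'completion' keys.

    Illustrative sketch (model path and stop strings are placeholders):

        generator = SglangGenerator("path/to/model", model_type = "Chat", num_samples = 1)
        generator.set_stop(["<|endoftext|>"])
        results = generator.generate([{"task_id": 0, "prompt": "Write a hello-world function."}])
    """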

    def __init__(
        self,
        model_name: str,
        tokenizer_name: Optional[str] = None,
        model_type: str = "Instruction",
        dtype: str = "bfloat16",
        batch_size: int = 1,
        temperature: float = 0.0,
        top_p: float = 1.0,
        max_tokens: int = 1024,
        num_samples: int = 200,
        num_gpus: int = 1,
        trust_remote_code: bool = True,
    ) -> None:
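        """
        Store generation settings and start an offline sglang Engine
        (tensor-parallel across num_gpus) for the given model.
        """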

        super().__init__(model_name)

        print("Initializing a decoder model in sglang: {} ...".format(model_name))
        self.tokenizer_name = tokenizer_name if tokenizer_name is not None else model_name
        self.model_type = model_type
        self.batch_size = batch_size
        self.temperature = temperature
        self.top_p = top_p
        self.max_tokens = max_tokens
        self.num_samples = num_samples
        self.dtype = dtype
        self.num_gpus = num_gpus
        self.trust_remote_code = trust_remote_code

        from sglang import Engine

        self.model = Engine(
            model_path = self.model_name,
            tokenizer_path = self.tokenizer_name,
            dtype = self.dtype,
            tp_size = self.num_gpus,
            trust_remote_code = self.trust_remote_code
        )

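        # Reuse the tokenizer already loaded by the engine (needed for chat templating).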
        self.tokenizer = self.model.tokenizer_manager.tokenizer
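        # Disable HF tokenizers parallelism to avoid fork-related warnings/deadlocks.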
        os.environ["TOKENIZERS_PARALLELISM"] = "false"

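        # Sampling configuration shared by every request; the stop strings are
        # filled in later via set_stop().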
        self.sampling_params = dict(
            n = self.num_samples,
            temperature = self.temperature,
            max_new_tokens = self.max_tokens,
            top_p = self.top_p,
        )
    
    def is_chat(self) -> bool:
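        """Return True when prompts should go through the tokenizer's chat template."""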
        if self.model_type == "Chat":
            if self.tokenizer.chat_template is None:
                logger.error("When the model type is Chat, the tokenizer must provide a chat template.")
                return False
            return True
        return False
        
    def set_stop(self, eos: List[str]):
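        """Set the stop strings that terminate each generation."""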
        self.sampling_params['stop'] = eos
        
    def release_memory(self):
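        """Shut down the sglang engine and free CUDA memory."""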

        import gc
        import torch

        from sglang.srt.distributed.parallel_state import destroy_distributed_environment, destroy_model_parallel

        self.model.shutdown()

        # Optional deeper cleanup, currently left disabled:
        # destroy_model_parallel()
        # destroy_distributed_environment()
        # del self.model
        gc.collect()
        torch.cuda.empty_cache()

    def generate(
            self,
            prompt_set: List[Dict],
            response_prefix: str = "",
            response_suffix: str = ""
        ) -> List[Dict]:
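        """
        Generate num_samples completions for every prompt in prompt_set.

        Returns one dict per completion with 'task_id', 'completion_id',
        and 'completion' keys.
        """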

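        # For chat models, rewrite each prompt with the tokenizer's chat template.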
        if self.is_chat():
            for prompt in prompt_set:
                prompt['prompt'] = make_chat_template(
                    prompt = prompt['prompt'],
                    response_prefix = response_prefix,
                    is_chat = self.is_chat(),
                    tokenizer = self.tokenizer
                )

        logger.info("Example Prompt:\n{}", prompt_set[0]['prompt'])

        generation_set = []

        for batch_start in tqdm(range(0, len(prompt_set), self.batch_size)):

            batch_prompt = prompt_set[batch_start : batch_start + self.batch_size]
            batch_outputs = self.model.generate(
                [prompt['prompt'] for prompt in batch_prompt],
                self.sampling_params
            )

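            # Each prompt yields num_samples consecutive entries in the flat output
            # list; regroup them into one list of generated texts per prompt.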
            batch_outputs = [
                [item["text"] for item in batch_outputs[i:i + self.num_samples]]
                for i in range(0, len(batch_outputs), self.num_samples)
            ]

            for prompt, output in zip(batch_prompt, batch_outputs):

                # Process completions with prefix/suffix and refine text based on chat mode
                completions = [
                    refine_text(
                        f"{prompt['prompt']}\n\n{response_prefix}{completion}{response_suffix}"
                        if not self.is_chat()
                        else f"{response_prefix}{completion}{response_suffix}"
                    )
                    for completion in output
                ]

                for completion_id, completion in enumerate(completions):
                    generation_set.append({
                        'task_id': prompt['task_id'],
                        'completion_id': completion_id,
                        'completion': completion
                    })

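        # Sanity check: every prompt should contribute exactly num_samples completions.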
        if len(generation_set) != len(prompt_set) * self.num_samples:
            logger.error("Expected {} generations, but got {}.", len(prompt_set) * self.num_samples, len(generation_set))

        self.release_memory()

        return generation_set