Diffstat (limited to 'code_eval/OpenCodeEval/benchmark/MBPP.py')
| -rw-r--r-- | code_eval/OpenCodeEval/benchmark/MBPP.py | 126 |
1 file changed, 126 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/benchmark/MBPP.py b/code_eval/OpenCodeEval/benchmark/MBPP.py
new file mode 100644
index 0000000..0242893
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/MBPP.py
@@ -0,0 +1,126 @@
+import os
+import sys
+
+ROOT = os.path.dirname(os.path.abspath(__file__))
+
+from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP, PYTHON_IMPORTS
+from OpenCodeEval.utils import refine_text, stream_jsonl, program_extract
+from OpenCodeEval.eval.func_eval import check_correctness
+from OpenCodeEval.eval.sanitize import sanitize
+
+class MBPP(Benchmark):
+
+    name: str = "MBPP"
+
+    imports_code = PYTHON_IMPORTS
+    chat_stop = PYTHON_STOP
+    base_stop = ['\n"""', "\nassert"]
+    # TODO: add more stop words, e.g. "\nif __name__", "\ndef main(", "\nprint(", '\n```\n'
+
+    def __init__(
+        self,
+        split: str = "base",
+        time_out: float = 3.0,
+        prompt_type: str = "Instruction"
+    ):
+
+        super().__init__()
+
+        self.split = split
+        self.time_out = time_out
+        self.prompt_type = prompt_type
+
+        self.path = os.path.join(self.path, f"{self.name}/data.jsonl")
+
+        self.tasks = self.get_task()
+
+    def get_task(self):
+        """
+        Load the task data from the jsonl file into a dictionary keyed by task id.
+        """
+
+        tasks = {}
+
+        for task_data in stream_jsonl(filename=self.path):
+
+            task_id = int(task_data["task_id"])
+
+            tasks[task_id] = task_data
+
+        return tasks
+
+    def format_prompt(self,
+                      problem: str,
+                      test: str,
+                      ) -> str:
+        problem = f"You are an expert Python programmer, and here is your task:\n{problem}"
+        test = f"Your code should pass the test:\n{test}"
+        prompt = problem + "\n" + test
+        return prompt
+
+    def get_prompt(self):
+        """
+        Build the prompts for the LM to generate from.
+        """
+
+        assert self.prompt_type == "Instruction", "Prompt type must be Instruction for MBPP"
+
+        prompts = []
+        for task_id, task_data in self.tasks.items():
+            prompts.append(
+                dict(
+                    task_id = task_id,
+                    prompt = refine_text(self.format_prompt(task_data["text"], task_data["test_list"][0]))
+                )
+            )
+        return prompts
+
+    def postprocess_generation(self, generation):
+        """
+        Postprocess a single generation into an executable solution.
+        """
+
+        entry_point = self.tasks[generation['task_id']]["entry_point"]
+
+        try:
+            # generation['completion'] = program_extract(generation['completion'], program="python", mode="last")
+            solution = sanitize(generation['completion'], entry_point)
+            # solution = solution.replace("func0", entry_point)
+        except Exception:
+            solution = program_extract(generation['completion'], program="python", mode="all")
+
+        return dict(
+            task_id = generation['task_id'],
+            completion_id = generation['completion_id'],
+            solution = solution
+        )
+
+
+    def process_results(self, solution):
+        """Evaluate a postprocessed solution against the ground-truth tests
+        and return the correctness result.
+        :param solution: dict
+            dict with 'task_id', 'completion_id', and 'solution' (the code to check)
+        :return: dict
+            result returned by check_correctness for this solution
+        """
+
+        task_data = self.tasks[solution['task_id']]
+
+        if self.split == "base":
+            test_code = "\n".join(task_data['test_imports']) + "\n\n" + "\n".join(task_data['test_list'])
+        elif self.split == "plus":
+            test_code = "\n".join(task_data['test_imports']) + "\n\n" + task_data['test']
+
+        code = (
+            "\n".join(self.imports_code) + "\n"
+            + solution['solution'] + "\n"
+            + test_code
+        )
+
+        result = check_correctness(solution['task_id'],
+                                   solution['completion_id'],
+                                   code,
+                                   self.time_out)
+
+        return result
\ No newline at end of file
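For orientation, below is a minimal sketch of how this MBPP class might be driven end to end. Only the MBPP methods and the 'task_id', 'prompt', 'completion', 'completion_id', and 'solution' keys are taken from the diff above; the generate_with_your_model helper, the single-sample loop, and the assumption that OpenCodeEval is installed with the MBPP data file in place are illustrative and not part of this commit.

# Hypothetical driver for the MBPP benchmark class above (sketch, assumptions noted in comments).
from OpenCodeEval.benchmark.MBPP import MBPP

def generate_with_your_model(prompt: str) -> str:
    # Placeholder for a real LM call; returns a canned completion here.
    return "def add(a, b):\n    return a + b"

# Constructor arguments mirror the defaults in the diff; the base class is
# assumed to provide self.path pointing at the benchmark data directory.
benchmark = MBPP(split="base", time_out=3.0, prompt_type="Instruction")

results = []
for item in benchmark.get_prompt():                  # [{'task_id': ..., 'prompt': ...}, ...]
    completion = generate_with_your_model(item["prompt"])
    generation = {
        "task_id": item["task_id"],
        "completion_id": 0,                          # index of this sample for the task
        "completion": completion,
    }
    solution = benchmark.postprocess_generation(generation)   # sanitize / extract the code
    results.append(benchmark.process_results(solution))       # run the tests with a timeout

# The exact result schema depends on check_correctness; inspect a few entries.
print(results[:3])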
