From fc6d57ffb8d5ddb5820fcc00b5491a585c259ebc Mon Sep 17 00:00:00 2001 From: Yuren Hao Date: Thu, 4 Sep 2025 22:16:22 -0500 Subject: Initial commit --- code_eval/OpenCodeEval/benchmark/mbpp.py | 153 +++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 code_eval/OpenCodeEval/benchmark/mbpp.py (limited to 'code_eval/OpenCodeEval/benchmark/mbpp.py') diff --git a/code_eval/OpenCodeEval/benchmark/mbpp.py b/code_eval/OpenCodeEval/benchmark/mbpp.py new file mode 100644 index 0000000..ee522a2 --- /dev/null +++ b/code_eval/OpenCodeEval/benchmark/mbpp.py @@ -0,0 +1,153 @@ +import os + +from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP, PYTHON_IMPORTS +from OpenCodeEval.utils import refine_text, stream_jsonl +from OpenCodeEval.eval.func_eval import check_correctness +from OpenCodeEval.eval.sanitize import sanitize + +from typing import List, Literal + +class mbpp(Benchmark): + + name: str = "mbpp" + + imports_code = PYTHON_IMPORTS + chat_stop = PYTHON_STOP + base_stop = ['\n"""', "\nassert"] + + def __init__( + self, + split: Literal["full", "sanitized"] = "full", + time_out: float = 3.0, + prompt_type: str = "Instruction" + ): + + super().__init__() + + self.split = split + self.time_out = time_out + self.prompt_type = prompt_type + + self.path = os.path.join(self.path, f"{self.name}/{self.split}.jsonl") + self.tasks = self.get_task() + + self.few_shots_prompt = self.get_few_shots_prompts() if split == "full" else "" + + def get_task(self): + """ + Get the task data from the jsonl file into a dictionary. + """ + + tasks = {} + + for task_data in stream_jsonl(filename = self.path): + task_id = int(task_data["task_id"]) + + task_data['text'] = refine_text(task_data['text']) + + tasks[task_id] = task_data + + return tasks + + def fewshot_examples(self): + + few_shots_start = 1 + few_shots_end = 4 + + few_shots = [] + + for task_id, task_data in self.tasks.items(): + if task_id >= few_shots_start and task_id < few_shots_end: + few_shots.append(task_data) + + return few_shots + + def format_prompt(self, + promblem: str, + tests: List[str], + code: str = None + ) -> str: + promblem = f"You are an expert Python programmer, and here is your task:\n{promblem}" + test = "\n".join(tests) + test = f"Your code should pass these tests:\n{test}\n" + prompt = promblem + test + if code: + code = refine_text(code) + code = f"\n```python\n{code}\n```\n" + prompt = prompt + code + else: + prompt = prompt + "\n```python\n" + return prompt + + def get_few_shots_prompts(self): + + few_shots_prompts = [] + for few_shot in self.fewshot_examples(): + few_shots_prompts.append(self.format_prompt(few_shot["text"], few_shot["test_list"], few_shot["code"])) + + return '\n'.join(few_shots_prompts) + + def get_prompt(self): + """ + Builds the prompt for the LM to generate from. + """ + + assert self.prompt_type == "Instruction", "Prompt type must be Instruction for mbpp" + + if self.split == "full": + test_start = 10 + test_end = 510 + elif self.split == "sanitized": + test_start = 0 + test_end = 974 + + prompts = [] + + for task_id, task_data in self.tasks.items(): + if task_id >= test_start and task_id < test_end: + + prompt = self.few_shots_prompt + '\n' + self.format_prompt(task_data["text"], task_data["test_list"]) + prompts.append({ + 'task_id': task_id, + 'prompt': prompt + }) + + return prompts + + def postprocess_generation(self, generation): + """ + Postprocess the generations. + """ + + if generation['completion'].startswith(self.few_shots_prompt): + generation['completion'] = generation['completion'][len(self.few_shots_prompt):] + + # if "```python" not in generation['completion']: + # generation['completion'] = "" + + return dict( + task_id = generation['task_id'], + completion_id = generation['completion_id'], + solution = sanitize(generation['completion']) + ) + + def process_results(self, solution): + """ + Takes the list of LM generations and evaluates them against the test cases + """ + + task_data = self.tasks[solution['task_id']] + + code = ( + "\n".join(self.imports_code) + "\n" + + task_data['test_setup_code'] + "\n" + + solution['solution'] + "\n" + + "\n".join(task_data['test_list']) + ) + + result = check_correctness(solution['task_id'], + solution['completion_id'], + code, + self.time_out) + + return result -- cgit v1.2.3