Diffstat (limited to 'code_eval/OpenCodeEval/benchmark/mbpp.py')
-rw-r--r--  code_eval/OpenCodeEval/benchmark/mbpp.py  153
1 file changed, 153 insertions(+), 0 deletions(-)
diff --git a/code_eval/OpenCodeEval/benchmark/mbpp.py b/code_eval/OpenCodeEval/benchmark/mbpp.py
new file mode 100644
index 0000000..ee522a2
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/mbpp.py
@@ -0,0 +1,153 @@
+import os
+
+from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP, PYTHON_IMPORTS
+from OpenCodeEval.utils import refine_text, stream_jsonl
+from OpenCodeEval.eval.func_eval import check_correctness
+from OpenCodeEval.eval.sanitize import sanitize
+
+from typing import List, Literal, Optional
+
+class mbpp(Benchmark):
+
+    name: str = "mbpp"
+
+    imports_code = PYTHON_IMPORTS
+    chat_stop = PYTHON_STOP
+    base_stop = ['\n"""', "\nassert"]
+
+    def __init__(
+        self,
+        split: Literal["full", "sanitized"] = "full",
+        time_out: float = 3.0,
+        prompt_type: str = "Instruction"
+    ):
+
+        super().__init__()
+
+        self.split = split
+        self.time_out = time_out
+        self.prompt_type = prompt_type
+
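+        # the split-specific data file is resolved under the base data path, e.g. .../mbpp/full.jsonl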
+        self.path = os.path.join(self.path, f"{self.name}/{self.split}.jsonl")
+        self.tasks = self.get_task()
+
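+        # few-shot examples are only prepended for the "full" split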
+        self.few_shots_prompt = self.get_few_shots_prompts() if split == "full" else ""
+
+    def get_task(self):
+        """
+        Get the task data from the jsonl file into a dictionary.
+        """
+
+        tasks = {}
+
+        for task_data in stream_jsonl(filename = self.path):
+            task_id = int(task_data["task_id"])
+
+            task_data['text'] = refine_text(task_data['text'])
+
+            tasks[task_id] = task_data
+
+        return tasks
+
+    def fewshot_examples(self):
+
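+        # MBPP task_ids 1-3 serve as the few-shot examples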
+        few_shots_start = 1
+        few_shots_end = 4
+
+        few_shots = []
+
+        for task_id, task_data in self.tasks.items():
+            if few_shots_start <= task_id < few_shots_end:
+                few_shots.append(task_data)
+
+        return few_shots
+
+    def format_prompt(self,
+                      problem: str,
+                      tests: List[str],
+                      code: Optional[str] = None
+                      ) -> str:
+        problem = f"You are an expert Python programmer, and here is your task:\n{problem}"
+        test = "\n".join(tests)
+        test = f"Your code should pass these tests:\n{test}\n"
+        prompt = problem + test
+        if code:
+            # few-shot examples include the reference solution in a closed code block
+            code = refine_text(code)
+            code = f"\n```python\n{code}\n```\n"
+            prompt = prompt + code
+        else:
+            # for the task being solved, leave an open code block for the model to complete
+            prompt = prompt + "\n```python\n"
+        return prompt
+
+    def get_few_shots_prompts(self):
+
+        few_shots_prompts = []
+        for few_shot in self.fewshot_examples():
+            few_shots_prompts.append(self.format_prompt(few_shot["text"], few_shot["test_list"], few_shot["code"]))
+
+        return '\n'.join(few_shots_prompts)
+
+    def get_prompt(self):
+        """
+        Builds the prompt for the LM to generate from.
+        """
+
+        assert self.prompt_type == "Instruction", "Prompt type must be Instruction for mbpp"
+
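+        # the full split evaluates task_ids 10-509; the sanitized split evaluates task_ids 0-973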
+        if self.split == "full":
+            test_start = 10
+            test_end = 510
+        elif self.split == "sanitized":
+            test_start = 0
+            test_end = 974
+
+        prompts = []
+
+        for task_id, task_data in self.tasks.items():
+            if test_start <= task_id < test_end:
+
+                prompt = self.few_shots_prompt + '\n' + self.format_prompt(task_data["text"], task_data["test_list"])
+                prompts.append({
+                    'task_id': task_id,
+                    'prompt': prompt
+                })
+
+        return prompts
+
+    def postprocess_generation(self, generation):
+        """
+        Postprocess a single generation into a solution record.
+        """
+
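+        # drop the echoed few-shot prompt if the model repeated it verbatim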
+        if generation['completion'].startswith(self.few_shots_prompt):
+            generation['completion'] = generation['completion'][len(self.few_shots_prompt):]
+
+        # if "```python" not in generation['completion']:
+        #     generation['completion'] = ""
+
+        return dict(
+            task_id = generation['task_id'],
+            completion_id = generation['completion_id'],
+            solution = sanitize(generation['completion'])
+        )
+
+    def process_results(self, solution):
+        """
+        Evaluate a single post-processed solution against the MBPP test cases.
+        """
+
+        task_data = self.tasks[solution['task_id']]
+
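+        # assemble a self-contained program: shared imports, test setup code, the solution, then the assert-based tests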
+        code = (
+            "\n".join(self.imports_code) + "\n"
+            + task_data['test_setup_code'] + "\n"
+            + solution['solution'] + "\n"
+            + "\n".join(task_data['test_list'])
+        )
+
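+        # execute the assembled program under the configured timeout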
+        result = check_correctness(solution['task_id'],
+                                   solution['completion_id'],
+                                   code,
+                                   self.time_out)
+
+        return result