path: root/code_eval/OpenCodeEval/benchmark/MBPP.py
Diffstat (limited to 'code_eval/OpenCodeEval/benchmark/MBPP.py')
-rw-r--r--  code_eval/OpenCodeEval/benchmark/MBPP.py  126
1 file changed, 126 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/benchmark/MBPP.py b/code_eval/OpenCodeEval/benchmark/MBPP.py
new file mode 100644
index 0000000..0242893
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/MBPP.py
@@ -0,0 +1,126 @@
+import os
+import sys
+
+ROOT = os.path.dirname(os.path.abspath(__file__))
+
+from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP, PYTHON_IMPORTS
+from OpenCodeEval.utils import refine_text, stream_jsonl, program_extract
+from OpenCodeEval.eval.func_eval import check_correctness
+from OpenCodeEval.eval.sanitize import sanitize
+
+class MBPP(Benchmark):
+
+    name: str = "MBPP"
+
+    imports_code = PYTHON_IMPORTS
+    chat_stop = PYTHON_STOP
+    base_stop = ['\n"""', "\nassert"]
+    # TODO: add more stop words, e.g. ["\nif __name__", "\ndef main(", "\nprint(", '\n```\n']
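+    # imports_code / chat_stop reuse the shared Python defaults from benchmark.base;
+    # base_stop presumably truncates completion-style generations at the next
+    # docstring or assert statement.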
+
+    def __init__(
+        self,
+        split: str = "base",
+        time_out: float = 3.0,
+        prompt_type: str = "Instruction"
+    ):
+
+        super().__init__()
+
+        self.split = split
+        self.time_out = time_out
+        self.prompt_type = prompt_type
+
+        self.path = os.path.join(self.path, f"{self.name}/data.jsonl")
+
+        self.tasks = self.get_task()
+
+    def get_task(self):
+        """
+        Load the task data from the jsonl file into a dictionary keyed by task_id.
+        """
+
+        tasks = {}
+
+        for task_data in stream_jsonl(filename=self.path):
+
+            task_id = int(task_data["task_id"])
+
+            tasks[task_id] = task_data
+
+        return tasks
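+    # Each task record is expected to carry at least "text", "test_list",
+    # "test_imports", "test" and "entry_point" fields, as consumed by
+    # get_prompt / postprocess_generation / process_results below.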
+
+    def format_prompt(self,
+                      problem: str,
+                      test: str,
+                      ) -> str:
+        problem = f"You are an expert Python programmer, and here is your task:\n{problem}"
+        test = f"Your code should pass the test:\n{test}"
+        prompt = problem + "\n" + test
+        return prompt
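+    # Shape of the resulting prompt (illustrative placeholders, not dataset values):
+    #
+    #   You are an expert Python programmer, and here is your task:
+    #   <task description from task_data["text"]>
+    #   Your code should pass the test:
+    #   assert <entry_point>(<args>) == <expected>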
+
+    def get_prompt(self):
+        """
+        Builds the prompts for the LM to generate from.
+        """
+
+        assert self.prompt_type == "Instruction", "Prompt type must be Instruction for MBPP"
+
+        prompts = []
+        for task_id, task_data in self.tasks.items():
+            prompts.append(
+                dict(
+                    task_id = task_id,
+                    prompt = refine_text(self.format_prompt(task_data["text"], task_data["test_list"][0]))
+                )
+            )
+        return prompts
+
+    def postprocess_generation(self, generation):
+        """
+        Postprocess a single generation into an executable solution.
+        """
+
+        entry_point = self.tasks[generation['task_id']]["entry_point"]
+
+        try:
+            # generation['completion'] = program_extract(generation['completion'], program="python", mode="last")
+            solution = sanitize(generation['completion'], entry_point)
+            # solution = solution.replace("func0", entry_point)
+        except Exception:
+            solution = program_extract(generation['completion'], program="python", mode="all")
+
+        return dict(
+            task_id = generation['task_id'],
+            completion_id = generation['completion_id'],
+            solution = solution
+        )
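+    # Note (assumption): sanitize() reduces the raw completion to the code reachable
+    # from `entry_point`; if it raises, the fallback keeps whatever Python blocks
+    # program_extract(mode="all") pulls from the completion.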
+
+
+    def process_results(self, solution):
+        """
+        Evaluate a single postprocessed solution against the ground-truth tests
+        and return the correctness result.
+
+        :param solution: dict with "task_id", "completion_id" and "solution" keys,
+            as produced by postprocess_generation.
+        """
+
+        task_data = self.tasks[solution['task_id']]
+
+        if self.split == "base":
+            test_code = "\n".join(task_data['test_imports']) + "\n\n" + "\n".join(task_data['test_list'])
+        elif self.split == "plus":
+            test_code = "\n".join(task_data['test_imports']) + "\n\n" + task_data['test']
+        else:
+            raise ValueError(f"Unknown split for MBPP: {self.split}")
+
+        code = (
+            "\n".join(self.imports_code) + "\n"
+            + solution['solution'] + "\n"
+            + test_code
+        )
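+        # Assembled test program layout (illustrative):
+        #   <shared Python imports>
+        #   <candidate solution>
+        #   <test imports and assert statements>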
+
+        result = check_correctness(solution['task_id'],
+                                   solution['completion_id'],
+                                   code,
+                                   self.time_out)
+
+        return result
\ No newline at end of file