import os
import sys

ROOT = os.path.dirname(os.path.abspath(__file__))

from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP, PYTHON_IMPORTS
from OpenCodeEval.utils import refine_text, stream_jsonl, program_extract
from OpenCodeEval.eval.func_eval import check_correctness
from OpenCodeEval.eval.sanitize import sanitize


class MBPP(Benchmark):

    name: str = "MBPP"

    imports_code = PYTHON_IMPORTS
    chat_stop = PYTHON_STOP
    base_stop = ['\n"""', "\nassert"]
    # TODO: add more stop words, e.g. "\nif __name__", "\ndef main(", "\nprint(", '\n```\n'

def __init__(
self,
split: str = "base",
time_out: float = 3.0,
prompt_type: str = "Instruction"
):
super().__init__()
self.split = split
self.time_out = time_out
self.prompt_type = prompt_type
self.path = os.path.join(self.path, f"{self.name}/data.jsonl")
        self.tasks = self.get_task()

    def get_task(self):
        """
        Load the task data from the JSONL file into a dictionary keyed by task_id.
        """
        tasks = {}
        for task_data in stream_jsonl(filename=self.path):
            task_id = int(task_data["task_id"])
            tasks[task_id] = task_data
        return tasks
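
    # Each JSONL record is expected to provide at least: "task_id", "text",
    # "test_list", "test_imports", and "entry_point"; the "plus" split also
    # needs the extended "test" field.
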
    def format_prompt(self,
                      problem: str,
                      test: str,
                      ) -> str:
        """
        Combine the problem statement and one example test into an instruction prompt.
        """
        problem = f"You are an expert Python programmer, and here is your task:\n{problem}"
        test = f"Your code should pass the test:\n{test}"
        prompt = problem + "\n" + test
        return prompt
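
    # For illustration, with a hypothetical task text and test, the resulting prompt is:
    #   You are an expert Python programmer, and here is your task:
    #   Write a function to add two numbers.
    #   Your code should pass the test:
    #   assert add(1, 2) == 3
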
def get_prompt(self):
"""
        Builds the prompts for the LM to generate from, one per MBPP task.
"""
assert self.prompt_type == "Instruction", "Prompt type must be Instruction for MBPP"
prompts = []
for task_id, task_data in self.tasks.items():
prompts.append(
dict(
task_id = task_id,
prompt = refine_text(self.format_prompt(task_data["text"], task_data["test_list"][0]))
)
)
return prompts
    def postprocess_generation(self, generation):
        """
        Postprocess a single generation: extract the solution code for the task's entry point.
        """
        entry_point = self.tasks[generation['task_id']]["entry_point"]
        try:
            # generation['completion'] = program_extract(generation['completion'], program="python", mode="last")
            solution = sanitize(generation['completion'], entry_point)
            # solution = solution.replace("func0", entry_point)
        except Exception:
            # Fall back to extracting every Python code block if sanitization fails.
            solution = program_extract(generation['completion'], program="python", mode="all")
        return dict(
            task_id = generation['task_id'],
            completion_id = generation['completion_id'],
            solution = solution
        )
    def process_results(self, solution):
        """
        Evaluate a post-processed solution against the ground-truth tests for its task
        and return the correctness result.
        """
        task_data = self.tasks[solution['task_id']]

        # "base" checks against the original MBPP test_list, while "plus" checks
        # against the extended test suite stored in the "test" field.
        if self.split == "base":
            test_code = "\n".join(task_data['test_imports']) + "\n\n" + "\n".join(task_data['test_list'])
        elif self.split == "plus":
            test_code = "\n".join(task_data['test_imports']) + "\n\n" + task_data['test']
        else:
            raise ValueError(f"Unknown split: {self.split}")

        code = (
            "\n".join(self.imports_code) + "\n"
            + solution['solution'] + "\n"
            + test_code
        )

        result = check_correctness(solution['task_id'],
                                   solution['completion_id'],
                                   code,
                                   self.time_out)
        return result
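

# Minimal usage sketch (not part of the original module). It assumes the
# OpenCodeEval data directory contains MBPP/data.jsonl, and the "completion"
# below is a hypothetical stand-in for real model output; it only illustrates
# the intended call order: get_prompt -> postprocess_generation -> process_results.
if __name__ == "__main__":
    bench = MBPP(split="base", time_out=3.0, prompt_type="Instruction")

    prompts = bench.get_prompt()
    print(prompts[0]["prompt"])

    fake_generation = dict(
        task_id=prompts[0]["task_id"],
        completion_id=0,
        # Hypothetical model output; a real run would place the LM's completion here.
        completion="```python\ndef placeholder(x):\n    return x\n```",
    )
    solution = bench.postprocess_generation(fake_generation)
    result = bench.process_results(solution)
    print(result)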