import os

from typing import Literal

from loguru import logger

from OpenCodeEval.benchmark.base import Benchmark
from OpenCodeEval.utils import refine_text, program_extract, stream_jsonl
from OpenCodeEval.eval.sql_test import check_correctness


class Bird(Benchmark):

    name: str = "Bird"

    def __init__(
        self,
        split: Literal["train", "dev"] = "dev",
        time_out: float = 30.0,
        prompt_type: str = "Instruction"
    ):
        super().__init__()

        self.split = split
        self.time_out = time_out
        self.prompt_type = prompt_type

        if self.prompt_type == "Completion":
            logger.error("Completion prompt type not supported for Bird")

        # Resolve the database directory before self.path is overwritten with
        # the data file path; both are derived from the same benchmark root.
        self.database = os.path.join(self.path, f"{self.name}/{self.split}/database")
        self.path = os.path.join(self.path, f"{self.name}/{self.split}/data.jsonl")

        self.tasks = self.get_task()
    def get_task(self):
        """
        Load the task data from the JSONL file into a dictionary keyed by task id.
        """
        tasks = {}
        for task_data in stream_jsonl(filename = self.path):
            tasks[int(task_data['id'])] = task_data
        return tasks
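
    # Illustrative shape of one task record, inferred from the fields this
    # class reads below; real Bird records may carry additional keys:
    #
    #   {
    #       "id": 0,
    #       "db_id": "<database name>",
    #       "o_schema": "<CREATE TABLE statements>",
    #       "question": "<natural-language question>",
    #       "evidence": "<external knowledge hint>",
    #       "sql": "<gold SQL query>"
    #   }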
    def get_prompt(self):
        """
        Build one prompt per task for the LM to generate from.
        """
        def construct_prompt(data):
            instruction = data['o_schema']
            instruction += f"\n\n-- External Knowledge: {data['evidence']}\n\n"
            instruction += "-- Using valid SQLite and understanding External Knowledge, answer the following questions for the tables provided above.\n\n"
            instruction += f"Question: {data['question']}\n"
            return instruction

        prompts = []
        for task_id, task_data in self.tasks.items():
            prompt = construct_prompt(task_data)
            prompts.append(
                dict(
                    task_id = task_id,
                    prompt = refine_text(prompt)
                )
            )
        return prompts
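
    # Layout of one assembled prompt (values elided):
    #
    #   <o_schema>
    #
    #   -- External Knowledge: <evidence>
    #
    #   -- Using valid SQLite and understanding External Knowledge, answer the following questions for the tables provided above.
    #
    #   Question: <question>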
    def postprocess_generation(self, generation):
        """
        Postprocess a single model generation into a one-line SQL solution.
        """
        def one_line_sql(response):
            response = program_extract(response, "sql", "last").strip()
            lines = response.splitlines()
            lines = [l.strip() for l in lines if l.strip()]
            sql = " ".join(lines)
            return sql

        result = dict(
            task_id = generation['task_id'],
            completion_id = generation['completion_id'],
            solution = one_line_sql(generation['completion'])
        )
        return result
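
    # Example of one_line_sql, assuming program_extract returns the last
    # ```sql fenced block of the response:
    #
    #   input:  "Sure:\n```sql\nSELECT name\nFROM users\nWHERE id = 1;\n```"
    #   output: "SELECT name FROM users WHERE id = 1;"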
    def process_results(self, solution):
        """
        Evaluate a single postprocessed solution against the ground-truth SQL.
        """
        task_data = self.tasks[solution['task_id']]
        db_path = os.path.join(self.database, task_data['db_id'], f"{task_data['db_id']}.sqlite")

        result, passed, sql_return = check_correctness(
            solution['solution'],
            task_data['sql'],
            db_path,
            self.time_out,
            "set_match"
        )

        return dict(
            task_id = solution['task_id'],
            completion_id = solution['completion_id'],
            passed = passed,
            result = result,
            solution = solution['solution'],
            sql_return = sql_return
        )
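

if __name__ == "__main__":
    # Minimal usage sketch, assuming the Bird dev split has been placed under
    # the data root that Benchmark.__init__ assigns to self.path.
    bench = Bird(split = "dev")

    prompts = bench.get_prompt()
    print(prompts[0]['prompt'])

    # Fake a single model generation to exercise the full pipeline.
    generation = dict(
        task_id = prompts[0]['task_id'],
        completion_id = 0,
        completion = "```sql\nSELECT 1;\n```"
    )
    solution = bench.postprocess_generation(generation)
    print(bench.process_results(solution))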