1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
|
import random
import re
import sqlite3
from collections import Counter
from collections import defaultdict
from itertools import product
from typing import Tuple, List, Set

from func_timeout import func_timeout, FunctionTimedOut
from loguru import logger
def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    """Reorder ``element`` so that position ``i`` holds ``element[perm[i]]``.

    ``perm`` must have the same length as ``element`` (enforced with an
    ``assert``). Its indices are not checked for distinctness.
    """
    assert len(perm) == len(element)
    return tuple(element[idx] for idx in perm)
def unorder_row(row: Tuple) -> Tuple:
    """Return the row's values in a canonical, column-order-independent order.

    Values are sorted by the string ``str(value) + str(type(value))``. Mixed-type
    rows (ints, strings, None, ...) can therefore always be ordered without
    comparing incompatible types directly. Note that the ordering is
    lexicographic on that string, e.g. 10 sorts before 9.
    """
    def sort_key(value):
        return str(value) + str(type(value))

    return tuple(sorted(row, key=sort_key))
def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Cheap filter that is a necessary condition for two results being equivalent.

    Every row is canonicalised with ``unorder_row`` so that column order no
    longer matters. The canonical rows are then compared:

    * ``order_matters`` true: as lists (row by row, in order);
    * otherwise: as sets of distinct rows (multiplicities are ignored here).

    Despite the name, the function returns True when the results *pass* the
    filter (they might be equivalent) and False when they certainly are not.
    """
    normalized1 = list(map(unorder_row, result1))
    normalized2 = list(map(unorder_row, result2))
    if not order_matters:
        return set(normalized1) == set(normalized2)
    return normalized1 == normalized2
def multiset_eq(l1: List, l2: List) -> bool:
    """Return whether ``l1`` and ``l2`` hold the same elements with the same counts.

    Order is ignored and elements must be hashable. This is bag (multiset)
    equality, e.g. ``[a, a, b]`` equals ``[b, a, a]`` but not ``[a, b, b]``.
    """
    # Different sizes can never be equal bags; skip building the counters.
    if len(l1) != len(l2):
        return False
    return Counter(l1) == Counter(l2)
def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    """Return an iterator of candidate column mappings from result2 onto result1.

    Each candidate is a tuple ``perm`` where ``perm[i]`` is the column of
    ``result2`` proposed to line up with column ``i`` of result1. The candidates
    come from ``itertools.product`` and may repeat a column; callers must
    discard non-permutations.

    With three or fewer columns every combination is returned. Otherwise 20
    rows are drawn from ``result2`` with ``random.choice`` (with replacement).
    A mapping ``i -> k`` is dropped whenever a sampled row's value in column
    ``k`` does not occur in ``tab1_sets_by_columns[i]``. Such a mapping could
    never make the tables equal, so pruning never removes a valid permutation.

    Args:
        tab1_sets_by_columns: for each column of result1, the set of its values.
        result2: the rows of the second result; must be non-empty.
    """
    width = len(result2[0])
    candidates = [set(range(width)) for _ in range(width)]
    if width > 3:
        for _ in range(20):
            sampled = random.choice(result2)
            for col, allowed in enumerate(candidates):
                rejected = [
                    k for k in allowed
                    if sampled[k] not in tab1_sets_by_columns[col]
                ]
                allowed.difference_update(rejected)
    return product(*candidates)
def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Return whether two query results denote the same table.

    The results are equivalent when some permutation of ``result2``'s columns
    makes the two tables equal. With ``order_matters`` true they must be equal
    row for row; otherwise they must be equal as multisets (bags) of rows.

    Args:
        result1: rows of the first result (``execute_sql`` passes the prediction).
        result2: rows of the second result.
        order_matters: whether the row order must also agree.

    Returns:
        True if such a column permutation exists, False otherwise.
    """
    if len(result1) == 0 and len(result2) == 0:
        return True
    # Different row counts can never be equal lists or equal bags.
    if len(result1) != len(result2):
        return False
    num_cols = len(result1[0])
    if len(result2[0]) != num_cols:
        return False
    # Column-order-insensitive necessary condition; rejects most
    # non-equivalent pairs without searching permutations.
    if not quick_rej(result1, result2, order_matters):
        return False

    # What remains: find a column permutation (and, when order does not
    # matter, a row permutation) under which result2 equals result1.
    tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]
    # Loop-invariant: build result1's row set once rather than per candidate.
    result1_set = None if order_matters else set(result1)

    for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
        # product() may assign the same column twice; only true permutations count.
        if len(perm) != len(set(perm)):
            continue
        if num_cols == 1:
            result2_perm = result2
        else:
            result2_perm = [permute_tuple(element, perm) for element in result2]
        if order_matters:
            if result1 == result2_perm:
                return True
        elif result1_set == set(result2_perm) and multiset_eq(result1, result2_perm):
            # Set equality is implied by bag equality but is much cheaper,
            # so it is checked first to discard impossible candidates.
            return True
    return False
def postprocess(query: str) -> str:
    """Repair spaced-out comparison operators in a model-generated query.

    This avoids execution errors caused by operators the model split apart.
    Each of ``"> ="``, ``"< ="`` and ``"! ="`` (exactly one space) is collapsed
    to ``">="``, ``"<="`` and ``"!="``, e.g. ``"a > = 1"`` becomes ``"a >= 1"``.
    Other spacings are left untouched.
    """
    for broken, fixed in (("> =", ">="), ("< =", "<="), ("! =", "!=")):
        query = query.replace(broken, fixed)
    return query
def find_detail(reason):
    """Map an error message to a coarse error category.

    Returns the first of a fixed list of error phrases that occurs as a
    substring of ``reason``, checking them in list order. Returns ``"others"``
    when none match.
    """
    known_phrases = (
        "ambiguous column name",
        "no such column",
        "no such table",
        "no such function",
        "syntax error",
    )
    return next((phrase for phrase in known_phrases if phrase in reason), "others")
def execute_sql(predicted_sql, ground_truth, db_path, method):
    """Run the predicted and gold queries on ``db_path`` and compare their results.

    Args:
        predicted_sql: the SQL to evaluate.
        ground_truth: the reference SQL.
        db_path: path of the SQLite database both queries run against.
        method: ``"set_match"`` passes when the two results contain the same
            distinct rows (row order and duplicates ignored). ``"exact_match"``
            uses ``result_eq``, which allows any column permutation and
            enforces row order only when the gold query contains ORDER BY.

    Returns:
        ``(result, passed, sql_return)``: ``result`` is ``"passed"`` or
        ``"failed"``, ``passed`` the matching bool, and ``sql_return`` the rows
        returned by ``predicted_sql``.

    Raises:
        ValueError: ``method`` is not one of the two supported names (it is
            also logged).
        sqlite3.Error: either query fails to execute.
    """
    if method not in ("set_match", "exact_match"):
        logger.error(f"Unknown evaluation method: {method}")
        raise ValueError(f"Unknown evaluation method: {method}")

    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        # Release the database handle even if a query raises.
        conn.close()

    # fetchall() returns a list; the None guard is kept as a defensive fallback.
    sql_return = [[]] if predicted_res is None else predicted_res

    if method == "set_match":
        passed = set(predicted_res) == set(ground_truth_res)
    else:
        # Strip all whitespace so "ORDER BY" is detected regardless of spacing
        # or case. This is a substring heuristic on the gold query text.
        order_matters = "orderby" in re.sub(r"\s+", "", ground_truth.lower())
        passed = result_eq(predicted_res, ground_truth_res, order_matters=order_matters)

    result = "passed" if passed else "failed"
    return result, passed, sql_return
def check_correctness(predicted_sql, ground_truth, db_path, meta_time_out, method):
    """Evaluate one predicted query against its gold query under a time limit.

    Calls ``execute_sql(predicted_sql, ground_truth, db_path, method)`` through
    ``func_timeout`` with ``meta_time_out`` as the limit. Failures are turned
    into result triples instead of propagating.

    Returns:
        ``(result, passed, sql_return)``:

        * on success, the triple unpacked from ``execute_sql``;
        * on ``FunctionTimedOut``, ``("timeout", False, [[]])``;
        * on any other ``Exception`` (including a return value that is not a
          3-item iterable), ``("error:<message>", False, [[]])``.
    """
    call_args = (predicted_sql, ground_truth, db_path, method)
    try:
        outcome, ok, rows = func_timeout(meta_time_out, execute_sql, args=call_args)
    except FunctionTimedOut:
        return "timeout", False, [[]]
    except Exception as err:
        return f"error:{err}", False, [[]]
    return outcome, ok, rows
|