summaryrefslogtreecommitdiff
path: root/code_eval/OpenCodeEval/eval/sql_test.py
diff options
context:
space:
mode:
authorYuren Hao <yurenh2@timan108.cs.illinois.edu>2025-09-04 22:16:22 -0500
committerYuren Hao <yurenh2@timan108.cs.illinois.edu>2025-09-04 22:16:22 -0500
commitfc6d57ffb8d5ddb5820fcc00b5491a585c259ebc (patch)
treee9841f93a353e2107225cfc721d1ce57c0e594dc /code_eval/OpenCodeEval/eval/sql_test.py
Initial commit
Diffstat (limited to 'code_eval/OpenCodeEval/eval/sql_test.py')
-rw-r--r--code_eval/OpenCodeEval/eval/sql_test.py187
1 files changed, 187 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/eval/sql_test.py b/code_eval/OpenCodeEval/eval/sql_test.py
new file mode 100644
index 0000000..9b0724e
--- /dev/null
+++ b/code_eval/OpenCodeEval/eval/sql_test.py
@@ -0,0 +1,187 @@
+import re
+import random
+import sqlite3
+
+from loguru import logger
+from itertools import product
+from collections import defaultdict
+from typing import Tuple, List, Set
+from func_timeout import func_timeout, FunctionTimedOut
+
+def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
+ assert len(element) == len(perm)
+ return tuple([element[i] for i in perm])
+
+
+def unorder_row(row: Tuple) -> Tuple:
+ return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))
+
+
+# unorder each row in the table
+# [result_1 and result_2 has the same bag of unordered row]
+# is a necessary condition of
+# [result_1 and result_2 are equivalent in denotation]
+def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
+ s1 = [unorder_row(row) for row in result1]
+ s2 = [unorder_row(row) for row in result2]
+ if order_matters:
+ return s1 == s2
+ else:
+ return set(s1) == set(s2)
+
+# return whether two bag of relations are equivalent
+def multiset_eq(l1: List, l2: List) -> bool:
+ if len(l1) != len(l2):
+ return False
+ d = defaultdict(int)
+ for e in l1:
+ d[e] = d[e] + 1
+ for e in l2:
+ d[e] = d[e] - 1
+ if d[e] < 0:
+ return False
+ return True
+
+def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
+ num_cols = len(result2[0])
+ perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
+ if num_cols <= 3:
+ return product(*perm_constraints)
+
+ # we sample 20 rows and constrain the space of permutations
+ for _ in range(20):
+ random_tab2_row = random.choice(result2)
+
+ for tab1_col in range(num_cols):
+ for tab2_col in set(perm_constraints[tab1_col]):
+ if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
+ perm_constraints[tab1_col].remove(tab2_col)
+ return product(*perm_constraints)
+
+# check whether two denotations are correct
+def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
+ if len(result1) == 0 and len(result2) == 0:
+ return True
+
+ # if length is not the same, then they are definitely different bag of rows
+ if len(result1) != len(result2):
+ return False
+
+ num_cols = len(result1[0])
+
+ # if the results do not have the same number of columns, they are different
+ if len(result2[0]) != num_cols:
+ return False
+
+ # unorder each row and compare whether the denotation is the same
+ # this can already find most pair of denotations that are different
+ if not quick_rej(result1, result2, order_matters):
+ return False
+
+ # the rest of the problem is in fact more complicated than one might think
+ # we want to find a permutation of column order and a permutation of row order,
+ # s.t. result_1 is the same as result_2
+ # we return true if we can find such column & row permutations
+ # and false if we cannot
+ tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]
+
+ # on a high level, we enumerate all possible column permutations that might make result_1 == result_2
+ # we decrease the size of the column permutation space by the function get_constraint_permutation
+ # if one of the permutation make result_1, result_2 equivalent, then they are equivalent
+ for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
+ if len(perm) != len(set(perm)):
+ continue
+ if num_cols == 1:
+ result2_perm = result2
+ else:
+ result2_perm = [permute_tuple(element, perm) for element in result2]
+ if order_matters:
+ if result1 == result2_perm:
+ return True
+ else:
+ # in fact the first condition must hold if the second condition holds
+ # but the first is way more efficient implementation-wise
+ # and we use it to quickly reject impossible candidates
+ if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
+ return True
+ return False
+
+# postprocess the model predictions to avoid execution errors
+# e.g. removing spaces between ">" and "="
+def postprocess(query: str) -> str:
+ query = query.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=")
+ return query
+
+def find_detail(reason):
+ patterns = [
+ "ambiguous column name",
+ "no such column",
+ "no such table",
+ "no such function",
+ "syntax error",
+ ]
+
+ for p in patterns:
+ if p in reason:
+ return p
+ return "others"
+
+def execute_sql(predicted_sql,ground_truth, db_path, method):
+ conn = sqlite3.connect(db_path)
+ # Connect to the database
+ cursor = conn.cursor()
+ cursor.execute(predicted_sql)
+ predicted_res = cursor.fetchall()
+ cursor.execute(ground_truth)
+ ground_truth_res = cursor.fetchall()
+ result = "failed"
+ passed = False
+
+ if predicted_res is None:
+ sql_return = [[]]
+ else:
+ sql_return = predicted_res
+
+ if method == "set_match":
+
+ if set(predicted_res) == set(ground_truth_res):
+ result = "passed"
+ passed = True
+ return result, passed, sql_return
+
+ elif method == "exact_match":
+
+ order_matters = "orderby" in re.sub(r"\s+", ground_truth.lower(), "")
+ if result_eq(predicted_res, ground_truth_res, order_matters = order_matters):
+ result = "passed"
+ passed = True
+ else:
+ passed = False
+
+ return result, passed, sql_return
+
+ else:
+ logger.error(f"Unknown evaluation method: {method}")
+
+def check_correctness(predicted_sql,ground_truth, db_path, meta_time_out, method):
+ try:
+ result, passed, sql_return = func_timeout(
+ meta_time_out,
+ execute_sql,
+ args = (
+ predicted_sql,
+ ground_truth,
+ db_path,
+ method
+ )
+ )
+ except FunctionTimedOut:
+ result = "timeout"
+ passed = False
+ sql_return = [[]]
+ except Exception as e:
+ result = f"error:{e}"
+ passed = False
+ sql_return = [[]]
+
+ return result, passed, sql_return \ No newline at end of file