"""Execution-based comparison of a predicted SQL query against a reference.

Both queries are run against the same SQLite database and their result
tables ("denotations") are compared, either as plain sets of rows
(``set_match``) or up to a permutation of columns and, unless the reference
query contains ORDER BY, of rows (``exact_match``).
"""
import random
import re
import sqlite3
from collections import defaultdict
from contextlib import closing
from itertools import product
from typing import List, Set, Tuple

from func_timeout import FunctionTimedOut, func_timeout
from loguru import logger


def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    """Return ``element`` reordered so that position ``k`` holds ``element[perm[k]]``."""
    assert len(element) == len(perm)
    return tuple(element[i] for i in perm)


def unorder_row(row: Tuple) -> Tuple:
    """Return the values of ``row`` in a canonical order independent of column order.

    Values are sorted by their string form followed by their type's string
    form, so values of different types can be ordered without a TypeError.
    """
    return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))


def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Cheap necessary condition for two result tables to be equivalent.

    Each row is unordered first; the two tables must then contain the same
    unordered rows (as sequences when ``order_matters``, as sets otherwise).
    Returns False when the tables certainly differ, True when they may match.
    """
    s1 = [unorder_row(row) for row in result1]
    s2 = [unorder_row(row) for row in result2]
    if order_matters:
        return s1 == s2
    return set(s1) == set(s2)


def multiset_eq(l1: List, l2: List) -> bool:
    """Return whether ``l1`` and ``l2`` contain the same elements with the same counts."""
    if len(l1) != len(l2):
        return False
    counts = defaultdict(int)
    for e in l1:
        counts[e] += 1
    for e in l2:
        counts[e] -= 1
        # l2 holds more copies of e than l1 does
        if counts[e] < 0:
            return False
    return True


def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    """Yield candidate column mappings from table 1 columns to table 2 columns.

    Args:
        tab1_sets_by_columns: for each column of table 1, the set of values
            appearing in that column.
        result2: rows of table 2; must be non-empty (its first row is used to
            determine the number of columns).

    Returns:
        An iterator of tuples; entry ``k`` is the table 2 column proposed for
        table 1 column ``k``. The tuples come from a Cartesian product, so they
        may repeat a column and callers must discard non-permutations.
    """
    num_cols = len(result2[0])
    perm_constraints = [set(range(num_cols)) for _ in range(num_cols)]
    # with at most 3 columns the full product is small enough to enumerate
    if num_cols <= 3:
        return product(*perm_constraints)

    # we sample 20 rows and constrain the space of permutations
    for _ in range(20):
        random_tab2_row = random.choice(result2)

        for tab1_col in range(num_cols):
            # iterate over a copy because candidates are removed while looping
            for tab2_col in set(perm_constraints[tab1_col]):
                if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
                    perm_constraints[tab1_col].remove(tab2_col)
    return product(*perm_constraints)


def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Return whether two result tables are equivalent.

    The tables are equivalent when some permutation of the columns of
    ``result2`` makes it equal to ``result1``: row by row when
    ``order_matters`` is True, as multisets of rows otherwise.
    """
    if len(result1) == 0 and len(result2) == 0:
        return True

    # if length is not the same, then they are definitely different bags of rows
    if len(result1) != len(result2):
        return False

    num_cols = len(result1[0])

    # if the results do not have the same number of columns, they are different
    if len(result2[0]) != num_cols:
        return False

    # unorder each row and compare whether the denotation is the same;
    # this already rejects most pairs of denotations that are different
    if not quick_rej(result1, result2, order_matters):
        return False

    # what remains is to find a column permutation (and, when order does not
    # matter, a row permutation) under which result1 equals result2
    tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]

    # enumerate the candidate column mappings, pruned by
    # get_constraint_permutation; one match is enough
    for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
        # skip mappings that use some column twice: they are not permutations
        if len(perm) != len(set(perm)):
            continue
        if num_cols == 1:
            result2_perm = result2
        else:
            result2_perm = [permute_tuple(element, perm) for element in result2]
        if order_matters:
            if result1 == result2_perm:
                return True
        else:
            # set equality is implied by multiset equality, but it is much
            # cheaper and quickly rejects impossible candidates
            if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
                return True
    return False


def postprocess(query: str) -> str:
    """Repair split comparison operators in a predicted query.

    Removes the space inside "> =", "< =" and "! =" so the query does not fail
    to execute because of them.
    """
    query = query.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=")
    return query


def find_detail(reason):
    """Map an error message to the first known error category it contains.

    Returns one of the known pattern strings, or "others" when none occurs in
    ``reason``.
    """
    patterns = [
        "ambiguous column name",
        "no such column",
        "no such table",
        "no such function",
        "syntax error",
    ]

    for p in patterns:
        if p in reason:
            return p
    return "others"


def execute_sql(predicted_sql, ground_truth, db_path, method):
    """Run both queries on the database at ``db_path`` and compare their results.

    Args:
        predicted_sql: the query under evaluation.
        ground_truth: the reference query.
        db_path: path passed to ``sqlite3.connect``.
        method: "set_match" compares the results as sets of rows;
            "exact_match" uses ``result_eq`` and respects row order when the
            reference query contains ORDER BY.

    Returns:
        ``(result, passed, sql_return)`` where ``result`` is "passed" or
        "failed", ``passed`` is the matching boolean and ``sql_return`` is the
        rows returned by the predicted query.

    Raises:
        ValueError: if ``method`` is not a known evaluation method.
        sqlite3.Error: if either query fails to execute.
    """
    # closing() releases the connection even when one of the queries raises;
    # a bare sqlite3 connection used as a context manager would not close it
    with closing(sqlite3.connect(db_path)) as conn:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()

    result = "failed"
    passed = False

    if predicted_res is None:
        sql_return = [[]]
    else:
        sql_return = predicted_res

    if method == "set_match":
        if set(predicted_res) == set(ground_truth_res):
            result = "passed"
            passed = True
        return result, passed, sql_return

    if method == "exact_match":
        # strip all whitespace so "ORDER  BY" / "order\nby" are detected too;
        # re.sub takes (pattern, replacement, string) in that order
        order_matters = "orderby" in re.sub(r"\s+", "", ground_truth.lower())
        if result_eq(predicted_res, ground_truth_res, order_matters=order_matters):
            result = "passed"
            passed = True
        return result, passed, sql_return

    logger.error(f"Unknown evaluation method: {method}")
    # raise instead of implicitly returning None, which callers cannot unpack
    raise ValueError(f"Unknown evaluation method: {method}")


def check_correctness(predicted_sql, ground_truth, db_path, meta_time_out, method):
    """Evaluate a predicted query with a time limit, never raising on failure.

    Calls ``execute_sql`` through ``func_timeout`` with ``meta_time_out`` as
    the limit.

    Returns:
        ``(result, passed, sql_return)``. On success these are the values from
        ``execute_sql``. On timeout ``result`` is "timeout"; on any other
        exception it is ``"error:<message>"``. In both failure cases ``passed``
        is False and ``sql_return`` is ``[[]]``.
    """
    try:
        result, passed, sql_return = func_timeout(
            meta_time_out,
            execute_sql,
            args=(
                predicted_sql,
                ground_truth,
                db_path,
                method,
            ),
        )
    except FunctionTimedOut:
        result = "timeout"
        passed = False
        sql_return = [[]]
    except Exception as e:
        result = f"error:{e}"
        passed = False
        sql_return = [[]]

    return result, passed, sql_return