import re
import random
import sqlite3
from loguru import logger
from itertools import product
from collections import defaultdict
from typing import Tuple, List, Set
from func_timeout import func_timeout, FunctionTimedOut


def permute_tuple(element: Tuple, perm: Tuple) -> Tuple:
    """Return ``element`` reordered so that position ``i`` holds ``element[perm[i]]``."""
    assert len(element) == len(perm)
    return tuple([element[i] for i in perm])


def unorder_row(row: Tuple) -> Tuple:
    """Return the values of ``row`` in a canonical (column-order-independent) order.

    Values are sorted by ``str(value) + str(type(value))`` so that mixed types
    (int, str, None, ...) can be sorted without raising a TypeError.
    """
    return tuple(sorted(row, key=lambda x: str(x) + str(type(x))))


# unorder each row in the table
# [result_1 and result_2 have the same bag of unordered rows]
# is a necessary condition of
# [result_1 and result_2 are equivalent in denotation]
def quick_rej(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Cheap necessary-condition check for denotation equivalence.

    Returns True when the two results *might* be equivalent (i.e. they are not
    rejected) and False when they are definitely different.  Each row is made
    column-order-independent via ``unorder_row``; when ``order_matters`` the row
    sequences must match exactly, otherwise only the sets of rows are compared.
    """
    s1 = [unorder_row(row) for row in result1]
    s2 = [unorder_row(row) for row in result2]
    if order_matters:
        return s1 == s2
    else:
        return set(s1) == set(s2)


# return whether two bags of relations are equivalent
def multiset_eq(l1: List, l2: List) -> bool:
    """Return True iff ``l1`` and ``l2`` contain the same elements with the same multiplicities."""
    if len(l1) != len(l2):
        return False
    d = defaultdict(int)
    for e in l1:
        d[e] = d[e] + 1
    for e in l2:
        d[e] = d[e] - 1
        # Lengths are equal, so any surplus in l2 shows up as a negative count.
        if d[e] < 0:
            return False
    return True


def get_constraint_permutation(tab1_sets_by_columns: List[Set], result2: List[Tuple]):
    """Return candidate column permutations mapping ``result2`` columns onto ``result1`` columns.

    Each candidate is a tuple ``perm`` where ``perm[i]`` is the ``result2``
    column proposed to line up with ``result1`` column ``i`` (see
    ``permute_tuple``).  Candidates come from ``itertools.product`` and may
    contain repeated indices; callers must discard those.

    ``tab1_sets_by_columns[i]`` is the set of values in ``result1`` column ``i``.
    ``result2`` must be non-empty (its first row is used to count columns).
    """
    num_cols = len(result2[0])
    perm_constraints = [{i for i in range(num_cols)} for _ in range(num_cols)]
    # With few columns the full product is small, so no pruning is needed.
    if num_cols <= 3:
        return product(*perm_constraints)

    # we sample 20 rows and constrain the space of permutations:
    # a result2 column can only map to a result1 column if every sampled value
    # of the former also occurs in the latter.
    for _ in range(20):
        random_tab2_row = random.choice(result2)

        for tab1_col in range(num_cols):
            # Iterate over a copy because we remove from the set while looping.
            for tab2_col in set(perm_constraints[tab1_col]):
                if random_tab2_row[tab2_col] not in tab1_sets_by_columns[tab1_col]:
                    perm_constraints[tab1_col].remove(tab2_col)
    return product(*perm_constraints)


# check whether two denotations are equivalent
def result_eq(result1: List[Tuple], result2: List[Tuple], order_matters: bool) -> bool:
    """Return True iff ``result2`` equals ``result1`` up to a column permutation.

    Rows must additionally match in order when ``order_matters`` is True;
    otherwise the two results are compared as multisets of rows.
    """
    if len(result1) == 0 and len(result2) == 0:
        return True

    # if length is not the same, then they are definitely different bags of rows
    if len(result1) != len(result2):
        return False

    num_cols = len(result1[0])

    # if the results do not have the same number of columns, they are different
    if len(result2[0]) != num_cols:
        return False

    # unorder each row and compare whether the denotation is the same
    # this can already find most pairs of denotations that are different
    if not quick_rej(result1, result2, order_matters):
        return False

    # the rest of the problem is in fact more complicated than one might think
    # we want to find a permutation of column order and a permutation of row order,
    # s.t. result_1 is the same as result_2
    # we return true if we can find such column & row permutations
    # and false if we cannot
    tab1_sets_by_columns = [{row[i] for row in result1} for i in range(num_cols)]

    # on a high level, we enumerate all possible column permutations that might make result_1 == result_2
    # we decrease the size of the column permutation space by the function get_constraint_permutation
    # if one of the permutations makes result_1, result_2 equivalent, then they are equivalent
    for perm in get_constraint_permutation(tab1_sets_by_columns, result2):
        # product() yields tuples with repeated indices; those are not permutations
        if len(perm) != len(set(perm)):
            continue
        if num_cols == 1:
            result2_perm = result2
        else:
            result2_perm = [permute_tuple(element, perm) for element in result2]
        if order_matters:
            if result1 == result2_perm:
                return True
        else:
            # in fact the first condition must hold if the second condition holds
            # but the first is way more efficient implementation-wise
            # and we use it to quickly reject impossible candidates
            if set(result1) == set(result2_perm) and multiset_eq(result1, result2_perm):
                return True
    return False


# postprocess the model predictions to avoid execution errors
# e.g. removing spaces between ">" and "="
def postprocess(query: str) -> str:
    """Join comparison operators that were split by a space (``> =`` -> ``>=`` etc.)."""
    query = query.replace("> =", ">=").replace("< =", "<=").replace("! =", "!=")
    return query


def find_detail(reason):
    """Map an error message to the first known error category it contains, else ``"others"``."""
    patterns = [
        "ambiguous column name",
        "no such column",
        "no such table",
        "no such function",
        "syntax error",
    ]
    for p in patterns:
        if p in reason:
            return p
    return "others"


def execute_sql(predicted_sql, ground_truth, db_path, method):
    """Run both queries against the SQLite database at ``db_path`` and compare them.

    Args:
        predicted_sql: SQL produced by the model.
        ground_truth: reference SQL.
        db_path: path passed to ``sqlite3.connect``.
        method: ``"set_match"`` (compare result rows as sets) or
            ``"exact_match"`` (``result_eq``, order-sensitive when the
            ground-truth query contains ``ORDER BY``).

    Returns:
        ``(result, passed, sql_return)`` where ``result`` is ``"passed"`` or
        ``"failed"``, ``passed`` is the matching bool and ``sql_return`` is the
        rows fetched for the predicted query.

    Raises:
        ValueError: if ``method`` is not one of the supported values.
        sqlite3.Error: propagated from query execution.
    """
    conn = sqlite3.connect(db_path)  # Connect to the database
    try:
        cursor = conn.cursor()
        cursor.execute(predicted_sql)
        predicted_res = cursor.fetchall()
        cursor.execute(ground_truth)
        ground_truth_res = cursor.fetchall()
    finally:
        # Release the connection even when a query raises (or the timeout
        # wrapper interrupts this call).
        conn.close()

    result = "failed"
    passed = False

    if predicted_res is None:
        sql_return = [[]]
    else:
        sql_return = predicted_res

    if method == "set_match":
        if set(predicted_res) == set(ground_truth_res):
            result = "passed"
            passed = True
        return result, passed, sql_return

    elif method == "exact_match":
        # Strip all whitespace so "ORDER   BY" / "order\nby" are detected too.
        # (re.sub signature is pattern, replacement, string.)
        normalized_ground_truth = re.sub(r"\s+", "", ground_truth.lower())
        order_matters = "orderby" in normalized_ground_truth
        if result_eq(predicted_res, ground_truth_res, order_matters=order_matters):
            result = "passed"
            passed = True
        else:
            passed = False
        return result, passed, sql_return

    else:
        logger.error(f"Unknown evaluation method: {method}")
        # Fail loudly instead of implicitly returning None, which callers
        # would otherwise fail to unpack with an unrelated error.
        raise ValueError(f"Unknown evaluation method: {method}")


def check_correctness(predicted_sql, ground_truth, db_path, meta_time_out, method):
    """Evaluate ``predicted_sql`` against ``ground_truth`` with a time limit.

    Runs ``execute_sql`` under ``func_timeout`` with ``meta_time_out`` seconds.

    Returns:
        ``(result, passed, sql_return)``.  ``result`` is ``"passed"``,
        ``"failed"``, ``"timeout"`` or ``"error:<message>"``; on timeout or
        error ``passed`` is False and ``sql_return`` is ``[[]]``.
    """
    try:
        result, passed, sql_return = func_timeout(
            meta_time_out,
            execute_sql,
            args=(
                predicted_sql,
                ground_truth,
                db_path,
                method,
            ),
        )
    except FunctionTimedOut:
        result = "timeout"
        passed = False
        sql_return = [[]]
    except Exception as e:
        # Any execution failure (bad SQL, unknown method, ...) is reported as
        # a failed evaluation rather than propagated.
        result = f"error:{e}"
        passed = False
        sql_return = [[]]

    return result, passed, sql_return