diff options
Diffstat (limited to 'Qwen2.5-Eval/evaluation/python_executor.py')
| -rwxr-xr-x | Qwen2.5-Eval/evaluation/python_executor.py | 230 |
1 files changed, 230 insertions, 0 deletions
"""Run generated Python snippets in worker processes and collect their results.

Module path: Qwen2.5-Eval/evaluation/python_executor.py

A *runtime* object holds the globals a snippet is executed against, and
``PythonExecutor`` feeds batches of snippets to a ``pebble.ProcessPool`` so
that a snippet which hangs can be timed out without blocking the caller.
"""
import os
import io
import regex
import pickle
import traceback
import copy
import datetime
import dateutil.relativedelta
import multiprocess
from multiprocess import Pool
from typing import Any, Dict, List, Optional, Tuple
from pebble import ProcessPool
from tqdm import tqdm
from concurrent.futures import TimeoutError
from functools import partial
from timeout_decorator import timeout
from contextlib import redirect_stdout


class GenericRuntime:
    """Namespace that snippets are ``exec``'d / ``eval``'d against.

    Subclasses customise the starting environment through the class
    attributes below.
    """

    # Names pre-loaded into the globals of every new runtime (shallow-copied).
    GLOBAL_DICT = {}
    # Optional locals template; copied into ``_local_vars`` but not otherwise
    # used by this class.
    LOCAL_DICT = None
    # Code pieces executed once, in order, when the runtime is created.
    HEADERS = []

    def __init__(self):
        self._global_vars = copy.copy(self.GLOBAL_DICT)
        self._local_vars = copy.copy(self.LOCAL_DICT) if self.LOCAL_DICT else None

        for c in self.HEADERS:
            self.exec_code(c)

    def exec_code(self, code_piece: str) -> None:
        """Execute *code_piece* in this runtime's globals.

        Raises:
            RuntimeError: if the code contains a call to ``input(``, which
                would block waiting on stdin.
        """
        # SECURITY: this runs arbitrary code with no sandbox; the ``input(``
        # check below is the only filter applied.
        # TODO: use: https://github.com/shroominic/codebox-api
        if regex.search(r'(\s|^)?input\(', code_piece):
            raise RuntimeError()
        exec(code_piece, self._global_vars)

    def eval_code(self, expr: str) -> Any:
        """Evaluate *expr* in this runtime's globals and return its value."""
        # SECURITY: unsandboxed ``eval`` -- see the note in ``exec_code``.
        return eval(expr, self._global_vars)

    def inject(self, var_dict: Dict[str, Any]) -> None:
        """Add (or overwrite) the given names in this runtime's globals."""
        self._global_vars.update(var_dict)

    @property
    def answer(self):
        """Value bound to the global name ``answer`` (KeyError if unset)."""
        return self._global_vars['answer']


class DateRuntime(GenericRuntime):
    """Runtime with date helpers pre-loaded.

    Note that ``timedelta`` is bound to ``relativedelta`` here, not to
    ``datetime.timedelta``.
    """

    GLOBAL_DICT = {
        'datetime': datetime.datetime,
        'timedelta': dateutil.relativedelta.relativedelta,
        'relativedelta': dateutil.relativedelta.relativedelta
    }


class CustomDict(dict):
    """``dict`` whose iteration walks a snapshot of the keys."""

    def __iter__(self):
        # Materialise the keys first so the iterator is independent of later
        # changes to the dict.
        return iter(list(super().__iter__()))


class ColorObjectRuntime(GenericRuntime):
    """Runtime in which the name ``dict`` resolves to ``CustomDict``."""

    GLOBAL_DICT = {'dict': CustomDict}


def _run_capturing_stdout(runtime, code, timeout_length):
    """Execute every line of *code* and return the text it wrote to stdout."""
    program_io = io.StringIO()
    with redirect_stdout(program_io):
        timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
    return program_io.getvalue()


def _run_then_eval_last_line(runtime, code, timeout_length):
    """Execute all but the last line of *code*, then evaluate the last line."""
    timeout(timeout_length)(runtime.exec_code)('\n'.join(code[:-1]))
    return timeout(timeout_length)(runtime.eval_code)(code[-1])


class PythonExecutor:
    """Execute code snippets against a runtime and extract an answer.

    Args:
        runtime: runtime object providing ``exec_code`` / ``eval_code``;
            a fresh ``GenericRuntime`` is used when omitted.
        get_answer_symbol: name of the global to read the answer from.
        get_answer_expr: expression to evaluate to obtain the answer.
        get_answer_from_stdout: take the answer from captured stdout.
        timeout_length: timeout in seconds handed to the pool and to
            ``execute``.
    """

    def __init__(
        self,
        runtime: Optional[Any] = None,
        get_answer_symbol: Optional[str] = None,
        get_answer_expr: Optional[str] = None,
        get_answer_from_stdout: bool = False,
        timeout_length: int = 5,
    ) -> None:
        self.runtime = runtime if runtime else GenericRuntime()
        self.answer_symbol = get_answer_symbol
        self.answer_expr = get_answer_expr
        self.get_answer_from_stdout = get_answer_from_stdout
        # The multiprocess pool is not used anywhere in this module, so it is
        # created lazily (see ``pool``) instead of spawning one worker per CPU
        # for every executor instance.
        self._pool = None
        self.timeout_length = timeout_length

    @property
    def pool(self):
        """``multiprocess.Pool`` sized to the CPU count, created on first use."""
        if self._pool is None:
            self._pool = Pool(multiprocess.cpu_count())
        return self._pool

    @pool.setter
    def pool(self, value) -> None:
        self._pool = value

    def process_generation_to_code(self, gens: List[str]) -> List[List[str]]:
        """Strip each generation and split it into a list of source lines."""
        return [g.strip().split('\n') for g in gens]

    @staticmethod
    def execute(
        code,
        get_answer_from_stdout=None,
        runtime=None,
        answer_symbol=None,
        answer_expr=None,
        timeout_length=10,
        auto_mode=False
    ):
        """Run one snippet and return ``(result, report)``.

        Args:
            code: the snippet as a list of source lines.
            get_answer_from_stdout: use captured stdout as the result.
            runtime: object providing ``exec_code`` / ``eval_code``.
            answer_symbol: global name whose value is the result.
            answer_expr: expression evaluated to produce the result.
            timeout_length: seconds allowed for each exec / eval call.
            auto_mode: ignore the answer options; use stdout when the last
                line contains ``print(``, otherwise evaluate the last line.

        Returns:
            ``(result, "Done")`` on success, or ``('', <last traceback line>)``
            when anything goes wrong (including a timeout or a result that
            cannot be stringified or pickled).
        """
        try:
            if auto_mode:
                if "print(" in code[-1]:
                    result = _run_capturing_stdout(runtime, code, timeout_length)
                else:
                    result = _run_then_eval_last_line(runtime, code, timeout_length)
            elif get_answer_from_stdout:
                result = _run_capturing_stdout(runtime, code, timeout_length)
            elif answer_symbol:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = runtime._global_vars[answer_symbol]
            elif answer_expr:
                timeout(timeout_length)(runtime.exec_code)('\n'.join(code))
                result = timeout(timeout_length)(runtime.eval_code)(answer_expr)
            else:
                result = _run_then_eval_last_line(runtime, code, timeout_length)
            report = "Done"
            str(result)  # stringification check
            pickle.dumps(result)  # serialization check
        except BaseException:
            # Deliberately as broad as the bare ``except`` it replaces: the
            # executed code may raise anything, including SystemExit from
            # ``sys.exit()``, and that must become a report rather than
            # terminate the worker.
            result = ''
            report = traceback.format_exc().split('\n')[-2]
        return result, report

    def apply(self, code):
        """Run a single snippet; returns its ``(result, report)`` pair."""
        return self.batch_apply([code])[0]

    @staticmethod
    def truncate(s, max_length=400):
        """Shorten *s* to its first and last ``max_length // 2`` characters.

        Strings of at most *max_length* characters are returned unchanged;
        longer ones get ``"..."`` inserted between the two halves.
        """
        half = max_length // 2
        if len(s) > max_length:
            s = s[:half] + "..." + s[-half:]
        return s

    def batch_apply(self, batch_code) -> List[Tuple[str, str]]:
        """Run every snippet in *batch_code* and return their results.

        Returns:
            One ``(result, report)`` pair per snippet, both stripped and
            truncated strings. A snippet that exceeds the pool timeout yields
            ``("", "Timeout Error")``.

        Raises:
            Exception: any other error reported by the pool is printed and
                re-raised.
        """
        all_code_snippets = self.process_generation_to_code(batch_code)
        if not all_code_snippets:
            # Nothing to run; also avoids asking the pool for zero workers.
            return []

        all_exec_results = []
        # with ProcessPool(max_workers=min(len(all_code_snippets), os.cpu_count())) as pool:
        # A single worker is used (the original ``min(len(...), 1)`` is always
        # 1 for a non-empty batch).
        with ProcessPool(max_workers=1) as pool:
            executor = partial(
                self.execute,
                get_answer_from_stdout=self.get_answer_from_stdout,
                runtime=self.runtime,
                answer_symbol=self.answer_symbol,
                answer_expr=self.answer_expr,
                # Original author's note: "this timeout not work"; presumably
                # the ``pool.map`` timeout below is the effective limit.
                timeout_length=self.timeout_length,
                auto_mode=True
            )
            future = pool.map(executor, all_code_snippets, timeout=self.timeout_length)
            iterator = future.result()

            # Only show a progress bar for large batches.
            if len(all_code_snippets) > 100:
                progress_bar = tqdm(total=len(all_code_snippets), desc="Execute")
            else:
                progress_bar = None

            try:
                while True:
                    try:
                        result = next(iterator)
                        all_exec_results.append(result)
                    except StopIteration:
                        break
                    except TimeoutError as error:
                        print(error)
                        all_exec_results.append(("", "Timeout Error"))
                    except Exception as error:
                        # Surface the failure instead of calling ``exit()``,
                        # which ended the program with a success status.
                        print(error)
                        raise
                    if progress_bar is not None:
                        progress_bar.update(1)
            finally:
                if progress_bar is not None:
                    progress_bar.close()

        # Post-processing: normalise to stripped, length-limited strings.
        batch_results = []
        for res, report in all_exec_results:
            res, report = str(res).strip(), str(report).strip()
            batch_results.append((self.truncate(res), self.truncate(report)))
        return batch_results


def _test():
    """Smoke test: run one sympy snippet and print its result pair."""
    batch_code = [
        """
from sympy import Matrix

def null_space_basis():
    # Define the matrix
    A = Matrix([[3, 3, -1, -6], [9, -1, -8, -1], [7, 4, -2, -9]])

    # Compute the basis for the null space
    basis = A.nullspace()

    # Round the elements of the basis vectors to three decimal places
    basis_rounded = [v.evalf(3) for v in basis]

    return basis_rounded

result = null_space_basis()
print(result)
    """
    ]

    executor = PythonExecutor(get_answer_from_stdout=True)
    predictions = executor.apply(batch_code[0])
    print(predictions)


if __name__ == '__main__':
    _test()
\ No newline at end of file |
