diff options
Diffstat (limited to 'Qwen2.5-Eval/evaluation/grader.py')
| -rwxr-xr-x | Qwen2.5-Eval/evaluation/grader.py | 395 |
1 files changed, 395 insertions, 0 deletions
diff --git a/Qwen2.5-Eval/evaluation/grader.py b/Qwen2.5-Eval/evaluation/grader.py new file mode 100755 index 0000000..1e18b37 --- /dev/null +++ b/Qwen2.5-Eval/evaluation/grader.py @@ -0,0 +1,395 @@ +""" +This logic is largely copied from the Hendrycks' MATH release (math_equivalence), and borrowed from: +- https://github.com/microsoft/ProphetNet/tree/master/CRITIC +- https://github.com/openai/prm800k +- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py +- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py +""" + +import re +import regex +import multiprocessing +from math import isclose +from typing import Union +from collections import defaultdict + +from sympy import simplify, N +from sympy.parsing.sympy_parser import parse_expr +from sympy.parsing.latex import parse_latex +from latex2sympy2 import latex2sympy + +# from .parser import choice_answer_clean, strip_string +# from parser import choice_answer_clean + + +def choice_answer_clean(pred: str): + pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") + # Clean the answer based on the dataset + tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) + if tmp: + pred = tmp + else: + pred = [pred.strip().strip(".")] + pred = pred[-1] + # Remove the period at the end, again! + pred = pred.rstrip(".").rstrip("/") + return pred + + +def parse_digits(num): + num = regex.sub(",", "", str(num)) + try: + return float(num) + except: + if num.endswith("%"): + num = num[:-1] + if num.endswith("\\"): + num = num[:-1] + try: + return float(num) / 100 + except: + pass + return None + + +def is_digit(num): + # paired with parse_digits + return parse_digits(num) is not None + + +def str_to_pmatrix(input_str): + input_str = input_str.strip() + matrix_str = re.findall(r"\{.*,.*\}", input_str) + pmatrix_list = [] + + for m in matrix_str: + m = m.strip("{}") + pmatrix = r"\begin{pmatrix}" + m.replace(",", "\\") + r"\end{pmatrix}" + pmatrix_list.append(pmatrix) + + return ", ".join(pmatrix_list) + + +def math_equal( + prediction: Union[bool, float, str], + reference: Union[float, str], + include_percentage: bool = True, + is_close: bool = True, + timeout: bool = False, +) -> bool: + """ + Exact match of math if and only if: + 1. numerical equal: both can convert to float and are equal + 2. symbolic equal: both can convert to sympy expression and are equal + """ + # print("Judge:", prediction, reference) + if prediction is None or reference is None: + return False + if str(prediction.strip().lower()) == str(reference.strip().lower()): + return True + if ( + reference in ["A", "B", "C", "D", "E"] + and choice_answer_clean(prediction) == reference + ): + return True + + try: # 1. numerical equal + if is_digit(prediction) and is_digit(reference): + prediction = parse_digits(prediction) + reference = parse_digits(reference) + # number questions + if include_percentage: + gt_result = [reference / 100, reference, reference * 100] + else: + gt_result = [reference] + for item in gt_result: + try: + if is_close: + if numeric_equal(prediction, item): + return True + else: + if item == prediction: + return True + except Exception: + continue + return False + except: + pass + + if not prediction and prediction not in [0, False]: + return False + + # 2. symbolic equal + reference = str(reference).strip() + prediction = str(prediction).strip() + + ## pmatrix (amps) + if "pmatrix" in prediction and not "pmatrix" in reference: + reference = str_to_pmatrix(reference) + + ## deal with [], (), {} + pred_str, ref_str = prediction, reference + if ( + prediction.startswith("[") + and prediction.endswith("]") + and not reference.startswith("(") + ) or ( + prediction.startswith("(") + and prediction.endswith(")") + and not reference.startswith("[") + ): + pred_str = pred_str.strip("[]()") + ref_str = ref_str.strip("[]()") + for s in ["{", "}", "(", ")"]: + ref_str = ref_str.replace(s, "") + pred_str = pred_str.replace(s, "") + if pred_str.lower() == ref_str.lower(): + return True + + ## [a, b] vs. [c, d], return a==c and b==d + if ( + regex.match(r"(\(|\[).+(\)|\])", prediction) is not None + and regex.match(r"(\(|\[).+(\)|\])", reference) is not None + ): + pred_parts = prediction[1:-1].split(",") + ref_parts = reference[1:-1].split(",") + if len(pred_parts) == len(ref_parts): + if all( + [ + math_equal( + pred_parts[i], ref_parts[i], include_percentage, is_close + ) + for i in range(len(pred_parts)) + ] + ): + return True + if ( + ( + prediction.startswith("\\begin{pmatrix}") + or prediction.startswith("\\begin{bmatrix}") + ) + and ( + prediction.endswith("\\end{pmatrix}") + or prediction.endswith("\\end{bmatrix}") + ) + and ( + reference.startswith("\\begin{pmatrix}") + or reference.startswith("\\begin{bmatrix}") + ) + and ( + reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}") + ) + ): + pred_lines = [ + line.strip() + for line in prediction[ + len("\\begin{pmatrix}") : -len("\\end{pmatrix}") + ].split("\\\\") + if line.strip() + ] + ref_lines = [ + line.strip() + for line in reference[ + len("\\begin{pmatrix}") : -len("\\end{pmatrix}") + ].split("\\\\") + if line.strip() + ] + matched = True + if len(pred_lines) == len(ref_lines): + for pred_line, ref_line in zip(pred_lines, ref_lines): + pred_parts = pred_line.split("&") + ref_parts = ref_line.split("&") + if len(pred_parts) == len(ref_parts): + if not all( + [ + math_equal( + pred_parts[i], + ref_parts[i], + include_percentage, + is_close, + ) + for i in range(len(pred_parts)) + ] + ): + matched = False + break + else: + matched = False + if not matched: + break + else: + matched = False + if matched: + return True + + if prediction.count("=") == 1 and reference.count("=") == 1: + pred = prediction.split("=") + pred = f"{pred[0].strip()} - ({pred[1].strip()})" + ref = reference.split("=") + ref = f"{ref[0].strip()} - ({ref[1].strip()})" + if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref): + return True + elif ( + prediction.count("=") == 1 + and len(prediction.split("=")[0].strip()) <= 2 + and "=" not in reference + ): + if math_equal( + prediction.split("=")[1], reference, include_percentage, is_close + ): + return True + elif ( + reference.count("=") == 1 + and len(reference.split("=")[0].strip()) <= 2 + and "=" not in prediction + ): + if math_equal( + prediction, reference.split("=")[1], include_percentage, is_close + ): + return True + + # symbolic equal with sympy + if timeout: + if call_with_timeout(symbolic_equal_process, prediction, reference): + return True + else: + if symbolic_equal(prediction, reference): + return True + + return False + + +def math_equal_process(param): + return math_equal(param[-2], param[-1]) + + +def numeric_equal(prediction: float, reference: float): + # Note that relative tolerance has significant impact + # on the result of the synthesized GSM-Hard dataset + # if reference.is_integer(): + # return isclose(reference, round(prediction), abs_tol=1e-4) + # else: + # prediction = round(prediction, len(str(reference).split(".")[-1])) + return isclose(reference, prediction, rel_tol=1e-4) + + +def symbolic_equal(a, b): + def _parse(s): + for f in [parse_latex, parse_expr, latex2sympy]: + try: + return f(s.replace("\\\\", "\\")) + except: + try: + return f(s) + except: + pass + return s + + a = _parse(a) + b = _parse(b) + + # direct equal + try: + if str(a) == str(b) or a == b: + return True + except: + pass + + # simplify equal + try: + if a.equals(b) or simplify(a - b) == 0: + return True + except: + pass + + # equation equal + try: + if (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)): + return True + except: + pass + + try: + if numeric_equal(float(N(a)), float(N(b))): + return True + except: + pass + + # matrix + try: + # if a and b are matrix + if a.shape == b.shape: + _a = a.applyfunc(lambda x: round(x, 3)) + _b = b.applyfunc(lambda x: round(x, 3)) + if _a.equals(_b): + return True + except: + pass + + return False + + +def symbolic_equal_process(a, b, output_queue): + result = symbolic_equal(a, b) + output_queue.put(result) + + +def call_with_timeout(func, *args, timeout=1, **kwargs): + output_queue = multiprocessing.Queue() + process_args = args + (output_queue,) + process = multiprocessing.Process(target=func, args=process_args, kwargs=kwargs) + process.start() + process.join(timeout) + + if process.is_alive(): + process.terminate() + process.join() + return False + + return output_queue.get() + +def _test_math_equal(): + # print(math_equal("0.0833333333333333", "\\frac{1}{12}")) + # print(math_equal("(1,4.5)", "(1,\\frac{9}{2})")) + # print(math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)) + # print(math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True)) + # print(math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}", "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})", timeout=True)) + + # pred = '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}' + # gt = '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})' + + # pred= '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}' + # gt= '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}' + + # pred = '-34x-45y+20z-100=0' + # gt = '34x+45y-20z+100=0' + + # pred = '\\frac{100}{3}' + # gt = '33.3' + + # pred = '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}' + # gt = '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})' + + # pred = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}' + # gt = '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}' + + # pred = '(+5)(b+2)' + # gt = '(a+5)(b+2)' + + # pred = '\\frac{1+\\sqrt{5}}{2}' + # gt = '2' + + # pred = '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}', gt = '4' + # pred = '1', gt = '1\\\\sqrt{19}' + + # pred = "(0.6,2.6667]" + # gt = "(\\frac{3}{5},\\frac{8}{3}]" + + gt = "x+2n+1" + pred = "x+1" + + print(math_equal(pred, gt, timeout=True)) + + +if __name__ == "__main__": + _test_math_equal() |
