summaryrefslogtreecommitdiff
path: root/Qwen2.5-Eval/evaluation/grader.py
diff options
context:
space:
mode:
authorzitian-gao <zitian.gao@outlook.com>2025-05-27 16:45:31 +0800
committerzitian-gao <zitian.gao@outlook.com>2025-05-27 16:45:31 +0800
commit7c792461c8e4e4f1f8734fed143630c74e76b27f (patch)
treecf6341ff9f2727424751da7a11a3bea6c39015bb /Qwen2.5-Eval/evaluation/grader.py
parent16815c8c5ec263c4bd1a0af60030c1c0efa1421e (diff)
init eval
Diffstat (limited to 'Qwen2.5-Eval/evaluation/grader.py')
-rwxr-xr-xQwen2.5-Eval/evaluation/grader.py395
1 files changed, 395 insertions, 0 deletions
diff --git a/Qwen2.5-Eval/evaluation/grader.py b/Qwen2.5-Eval/evaluation/grader.py
new file mode 100755
index 0000000..1e18b37
--- /dev/null
+++ b/Qwen2.5-Eval/evaluation/grader.py
@@ -0,0 +1,395 @@
+"""
+This logic is largely copied from Hendrycks' MATH release (math_equivalence), with parts borrowed from:
+- https://github.com/microsoft/ProphetNet/tree/master/CRITIC
+- https://github.com/openai/prm800k
+- https://github.com/microsoft/ToRA/blob/main/src/eval/grader.py
+- https://github.com/deepseek-ai/DeepSeek-Math/blob/main/evaluation/eval/eval_utils.py
+"""
+
+import re
+import regex
+import multiprocessing
+from math import isclose
+from typing import Union
+from collections import defaultdict
+
+from sympy import simplify, N
+from sympy.parsing.sympy_parser import parse_expr
+from sympy.parsing.latex import parse_latex
+from latex2sympy2 import latex2sympy
+
+# from .parser import choice_answer_clean, strip_string
+# from parser import choice_answer_clean
+
+
def choice_answer_clean(pred: str):
    """Reduce a free-form multiple-choice reply to its final option letter.

    The text is first trimmed of surrounding newlines, trailing ``.`` and
    ``/`` characters, surrounding spaces and leading colons.  If the
    upper-cased result contains any stand-alone letter ``A``-``E``, the last
    such letter is returned.  Otherwise the trimmed text is returned with
    surrounding whitespace and periods (and any trailing ``/``) removed.
    """
    trimmed = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":")
    option_letters = re.findall(r"\b(A|B|C|D|E)\b", trimmed.upper())
    if option_letters:
        # Several letters may be mentioned; the last one is taken as the answer.
        answer = option_letters[-1]
    else:
        answer = trimmed.strip().strip(".")
    # Strip trailing punctuation once more, mirroring the initial clean-up.
    return answer.rstrip(".").rstrip("/")
+
+
def parse_digits(num):
    """Convert *num* to a float, tolerating thousands separators and percents.

    Args:
        num: Any object; it is converted with ``str()`` and every comma is
            removed before parsing.

    Returns:
        The parsed float.  A value ending in ``%`` (optionally written as
        the LaTeX-escaped ``\\%``) is divided by 100.  ``None`` is returned
        when the text is not a number.
    """
    # Plain substring removal: no regular expression is needed for a literal.
    text = str(num).replace(",", "")
    try:
        return float(text)
    except ValueError:
        # Not a plain number; fall through to the percentage form.
        pass

    if text.endswith("%"):
        text = text[:-1]
        # LaTeX writes a literal percent sign as "\%"; drop the backslash too.
        if text.endswith("\\"):
            text = text[:-1]
        try:
            return float(text) / 100
        except ValueError:
            pass
    return None
+
+
def is_digit(num):
    """Report whether ``parse_digits`` can turn *num* into a number."""
    parsed = parse_digits(num)
    return parsed is not None
+
+
def str_to_pmatrix(input_str):
    r"""Rewrite brace-delimited, comma-separated groups as LaTeX ``pmatrix`` blocks.

    Every substring matching ``\{.*,.*\}`` is stripped of its outer braces,
    its commas are turned into the LaTeX row separator ``\\`` and the result
    is wrapped in ``\begin{pmatrix}...\end{pmatrix}``.  For example
    ``{1,2,3}`` becomes ``\begin{pmatrix}1\\2\\3\end{pmatrix}``.

    The separator must be the two-character ``\\``: ``math_equal`` splits
    matrix bodies on exactly that sequence, so a single backslash would
    leave the converted matrix as one unsplittable row.

    Args:
        input_str: Text that may contain ``{a, b, ...}`` groups.

    Returns:
        The converted blocks joined with ``", "``, or an empty string when
        no group is found.  The pattern is greedy, so text containing
        several brace groups is captured as a single match.
    """
    groups = re.findall(r"\{.*,.*\}", input_str.strip())
    return ", ".join(
        r"\begin{pmatrix}" + group.strip("{}").replace(",", r"\\") + r"\end{pmatrix}"
        for group in groups
    )
+
+
+def math_equal(
+ prediction: Union[bool, float, str],
+ reference: Union[float, str],
+ include_percentage: bool = True,
+ is_close: bool = True,
+ timeout: bool = False,
+) -> bool:
+ """
+ Exact match of math if and only if:
+ 1. numerical equal: both can convert to float and are equal
+ 2. symbolic equal: both can convert to sympy expression and are equal
+ """
+ # print("Judge:", prediction, reference)
+ if prediction is None or reference is None:
+ return False
+ if str(prediction.strip().lower()) == str(reference.strip().lower()):
+ return True
+ if (
+ reference in ["A", "B", "C", "D", "E"]
+ and choice_answer_clean(prediction) == reference
+ ):
+ return True
+
+ try: # 1. numerical equal
+ if is_digit(prediction) and is_digit(reference):
+ prediction = parse_digits(prediction)
+ reference = parse_digits(reference)
+ # number questions
+ if include_percentage:
+ gt_result = [reference / 100, reference, reference * 100]
+ else:
+ gt_result = [reference]
+ for item in gt_result:
+ try:
+ if is_close:
+ if numeric_equal(prediction, item):
+ return True
+ else:
+ if item == prediction:
+ return True
+ except Exception:
+ continue
+ return False
+ except:
+ pass
+
+ if not prediction and prediction not in [0, False]:
+ return False
+
+ # 2. symbolic equal
+ reference = str(reference).strip()
+ prediction = str(prediction).strip()
+
+ ## pmatrix (amps)
+ if "pmatrix" in prediction and not "pmatrix" in reference:
+ reference = str_to_pmatrix(reference)
+
+ ## deal with [], (), {}
+ pred_str, ref_str = prediction, reference
+ if (
+ prediction.startswith("[")
+ and prediction.endswith("]")
+ and not reference.startswith("(")
+ ) or (
+ prediction.startswith("(")
+ and prediction.endswith(")")
+ and not reference.startswith("[")
+ ):
+ pred_str = pred_str.strip("[]()")
+ ref_str = ref_str.strip("[]()")
+ for s in ["{", "}", "(", ")"]:
+ ref_str = ref_str.replace(s, "")
+ pred_str = pred_str.replace(s, "")
+ if pred_str.lower() == ref_str.lower():
+ return True
+
+ ## [a, b] vs. [c, d], return a==c and b==d
+ if (
+ regex.match(r"(\(|\[).+(\)|\])", prediction) is not None
+ and regex.match(r"(\(|\[).+(\)|\])", reference) is not None
+ ):
+ pred_parts = prediction[1:-1].split(",")
+ ref_parts = reference[1:-1].split(",")
+ if len(pred_parts) == len(ref_parts):
+ if all(
+ [
+ math_equal(
+ pred_parts[i], ref_parts[i], include_percentage, is_close
+ )
+ for i in range(len(pred_parts))
+ ]
+ ):
+ return True
+ if (
+ (
+ prediction.startswith("\\begin{pmatrix}")
+ or prediction.startswith("\\begin{bmatrix}")
+ )
+ and (
+ prediction.endswith("\\end{pmatrix}")
+ or prediction.endswith("\\end{bmatrix}")
+ )
+ and (
+ reference.startswith("\\begin{pmatrix}")
+ or reference.startswith("\\begin{bmatrix}")
+ )
+ and (
+ reference.endswith("\\end{pmatrix}") or reference.endswith("\\end{bmatrix}")
+ )
+ ):
+ pred_lines = [
+ line.strip()
+ for line in prediction[
+ len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
+ ].split("\\\\")
+ if line.strip()
+ ]
+ ref_lines = [
+ line.strip()
+ for line in reference[
+ len("\\begin{pmatrix}") : -len("\\end{pmatrix}")
+ ].split("\\\\")
+ if line.strip()
+ ]
+ matched = True
+ if len(pred_lines) == len(ref_lines):
+ for pred_line, ref_line in zip(pred_lines, ref_lines):
+ pred_parts = pred_line.split("&")
+ ref_parts = ref_line.split("&")
+ if len(pred_parts) == len(ref_parts):
+ if not all(
+ [
+ math_equal(
+ pred_parts[i],
+ ref_parts[i],
+ include_percentage,
+ is_close,
+ )
+ for i in range(len(pred_parts))
+ ]
+ ):
+ matched = False
+ break
+ else:
+ matched = False
+ if not matched:
+ break
+ else:
+ matched = False
+ if matched:
+ return True
+
+ if prediction.count("=") == 1 and reference.count("=") == 1:
+ pred = prediction.split("=")
+ pred = f"{pred[0].strip()} - ({pred[1].strip()})"
+ ref = reference.split("=")
+ ref = f"{ref[0].strip()} - ({ref[1].strip()})"
+ if symbolic_equal(pred, ref) or symbolic_equal(f"-({pred})", ref):
+ return True
+ elif (
+ prediction.count("=") == 1
+ and len(prediction.split("=")[0].strip()) <= 2
+ and "=" not in reference
+ ):
+ if math_equal(
+ prediction.split("=")[1], reference, include_percentage, is_close
+ ):
+ return True
+ elif (
+ reference.count("=") == 1
+ and len(reference.split("=")[0].strip()) <= 2
+ and "=" not in prediction
+ ):
+ if math_equal(
+ prediction, reference.split("=")[1], include_percentage, is_close
+ ):
+ return True
+
+ # symbolic equal with sympy
+ if timeout:
+ if call_with_timeout(symbolic_equal_process, prediction, reference):
+ return True
+ else:
+ if symbolic_equal(prediction, reference):
+ return True
+
+ return False
+
+
def math_equal_process(param):
    """Grade a packed work item whose last two entries are (prediction, reference)."""
    prediction, reference = param[-2], param[-1]
    return math_equal(prediction, reference)
+
+
def numeric_equal(prediction: float, reference: float):
    """Return True when the two numbers agree within a relative tolerance of 1e-4."""
    # NOTE: the choice of relative tolerance has a significant impact on
    # results for the synthesized GSM-Hard dataset, so change it with care.
    relative_tolerance = 1e-4
    return isclose(reference, prediction, rel_tol=relative_tolerance)
+
+
def symbolic_equal(a, b):
    """Best-effort symbolic comparison of two expressions given as text.

    Each argument is parsed with the first of ``parse_latex``,
    ``parse_expr`` and ``latex2sympy`` that succeeds; if none does, the raw
    string is kept.  The parsed values are then compared with a sequence of
    increasingly expensive checks, and the first one that holds returns
    ``True``:

    1. identical ``str()`` form or ``==``;
    2. ``a.equals(b)`` or ``simplify(a - b) == 0``;
    3. equations: ``|lhs - rhs|`` of both sides are equal;
    4. numeric evaluation agrees under ``numeric_equal``;
    5. matrices of the same shape agree after rounding entries to 3 places.

    A check that raises is treated as "not equal" and the next one is tried.
    Only ``Exception`` subclasses are suppressed, so ``KeyboardInterrupt``
    and ``SystemExit`` still propagate.

    Args:
        a: First expression (LaTeX or sympy-parsable text).
        b: Second expression.

    Returns:
        ``True`` if any check succeeds, otherwise ``False``.
    """

    def _parse(s):
        # NOTE(review): sympy documents that parse_expr relies on eval();
        # the text graded here is presumably model output -- confirm that
        # running it unsandboxed is acceptable for this evaluation setup.
        for parser in (parse_latex, parse_expr, latex2sympy):
            try:
                # First try with doubled backslashes collapsed to single ones.
                return parser(s.replace("\\\\", "\\"))
            except Exception:
                try:
                    return parser(s)
                except Exception:
                    pass
        return s

    a = _parse(a)
    b = _parse(b)

    def _matrices_match():
        # Only meaningful when both values expose .shape / .applyfunc;
        # anything else raises and is handled by the caller.
        if a.shape != b.shape:
            return False
        rounded_a = a.applyfunc(lambda x: round(x, 3))
        rounded_b = b.applyfunc(lambda x: round(x, 3))
        return rounded_a.equals(rounded_b)

    checks = (
        # direct equal
        lambda: str(a) == str(b) or a == b,
        # simplify equal
        lambda: a.equals(b) or simplify(a - b) == 0,
        # equation equal
        lambda: (abs(a.lhs - a.rhs)).equals(abs(b.lhs - b.rhs)),
        # numeric equal
        lambda: numeric_equal(float(N(a)), float(N(b))),
        # matrix
        _matrices_match,
    )
    for check in checks:
        try:
            if check():
                return True
        except Exception:
            # This check does not apply to these operands; try the next one.
            pass

    return False
+
+
def symbolic_equal_process(a, b, output_queue):
    """Worker entry point: put the ``symbolic_equal(a, b)`` verdict on *output_queue*."""
    output_queue.put(symbolic_equal(a, b))
+
+
def call_with_timeout(func, *args, timeout=1, **kwargs):
    """Run *func* in a child process and return its result, or False on failure.

    *func* is called as ``func(*args, output_queue, **kwargs)`` and is
    expected to ``put`` exactly one result on ``output_queue``.

    Args:
        func: Callable executed in the child process.
        *args: Positional arguments passed before the queue.
        timeout: Seconds to wait for the child before terminating it.
        **kwargs: Keyword arguments forwarded to *func*.

    Returns:
        The object the child put on the queue.  ``False`` is returned when
        the child is still running after *timeout* seconds (it is then
        terminated) or when it exits without producing a result.
    """
    from queue import Empty  # local import: raised by Queue.get on timeout

    output_queue = multiprocessing.Queue()
    process = multiprocessing.Process(
        target=func, args=args + (output_queue,), kwargs=kwargs
    )
    process.start()
    process.join(timeout)

    if process.is_alive():
        process.terminate()
        process.join()
        return False

    try:
        # The child has already exited, so its result (if any) is available
        # right away.  A bounded wait keeps the caller from blocking forever
        # when the child crashed or raised before calling put().
        return output_queue.get(timeout=0.1)
    except Empty:
        return False
+
def _test_math_equal():
    """Manual smoke check: grade one hand-picked pair and print the verdict."""
    # Pairs kept from earlier debugging sessions.  Strings are written as
    # Python literals; swap one into `candidate` / `reference` to re-run it.
    #
    #   math_equal("0.0833333333333333", "\\frac{1}{12}")
    #   math_equal("(1,4.5)", "(1,\\frac{9}{2})")
    #   math_equal("\\frac{x}{7}+\\frac{2}{7}", "\\frac{x+2}{7}", timeout=True)
    #   math_equal("\\sec^2(y)", "\\tan^2(y)+1", timeout=True)
    #   math_equal("\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\end{pmatrix}",
    #              "(\\begin{pmatrix}-\\frac{7}{4}&-2\\\\4&\\frac{1}{4}\\\\\\end{pmatrix})",
    #              timeout=True)
    #
    #   pred: '\\begin{pmatrix}\\frac{1}{3x^{2/3}}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\end{pmatrix}'
    #   gt:   '(\\begin{pmatrix}\\frac{1}{3\\sqrt[3]{x}^2}&0&0\\\\0&1&0\\\\-\\sin(x)&0&0\\\\\\end{pmatrix})'
    #
    #   pred: '-\\frac{8x^2}{9(x^2-2)^{5/3}}+\\frac{2}{3(x^2-2)^{2/3}}'
    #   gt:   '-\\frac{2(x^2+6)}{9(x^2-2)\\sqrt[3]{x^2-2}^2}'
    #
    #   pred: '-34x-45y+20z-100=0'
    #   gt:   '34x+45y-20z+100=0'
    #
    #   pred: '\\frac{100}{3}'
    #   gt:   '33.3'
    #
    #   pred: '\\begin{pmatrix}0.290243531202435\\\\0.196008371385084\\\\-0.186381278538813\\end{pmatrix}'
    #   gt:   '(\\begin{pmatrix}0.29\\\\0.196\\\\-0.186\\\\\\end{pmatrix})'
    #
    #   pred: '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{2\\sqrt{33}+15}'
    #   gt:   '\\frac{\\sqrt{\\sqrt{11}+\\sqrt{194}}}{15+2\\sqrt{33}}'
    #
    #   pred: '(+5)(b+2)'
    #   gt:   '(a+5)(b+2)'
    #
    #   pred: '\\frac{1+\\sqrt{5}}{2}'
    #   gt:   '2'
    #
    #   pred: '\\frac{34}{16}+\\frac{\\sqrt{1358}}{16}'
    #   gt:   '4'
    #
    #   pred: '1'
    #   gt:   '1\\\\sqrt{19}'
    #
    #   pred: "(0.6,2.6667]"
    #   gt:   "(\\frac{3}{5},\\frac{8}{3}]"
    reference = "x+2n+1"
    candidate = "x+1"

    verdict = math_equal(candidate, reference, timeout=True)
    print(verdict)


# Running the module directly executes the manual smoke check above.
if __name__ == "__main__":
    _test_math_equal()