diff options
| author | zitian-gao <zitian.gao@outlook.com> | 2025-05-27 16:45:31 +0800 |
|---|---|---|
| committer | zitian-gao <zitian.gao@outlook.com> | 2025-05-27 16:45:31 +0800 |
| commit | 7c792461c8e4e4f1f8734fed143630c74e76b27f (patch) | |
| tree | cf6341ff9f2727424751da7a11a3bea6c39015bb /Qwen2.5-Eval/evaluation/parser.py | |
| parent | 16815c8c5ec263c4bd1a0af60030c1c0efa1421e (diff) | |
init eval
Diffstat (limited to 'Qwen2.5-Eval/evaluation/parser.py')
| -rwxr-xr-x | Qwen2.5-Eval/evaluation/parser.py | 769 |
1 files changed, 769 insertions, 0 deletions
diff --git a/Qwen2.5-Eval/evaluation/parser.py b/Qwen2.5-Eval/evaluation/parser.py new file mode 100755 index 0000000..7b4ad76 --- /dev/null +++ b/Qwen2.5-Eval/evaluation/parser.py @@ -0,0 +1,769 @@ +# rm -rf parser.py; vim parser.py +import random +import regex +import re +import sympy +from latex2sympy2 import latex2sympy +from typing import TypeVar, Iterable, List, Union, Any, Dict +from word2number import w2n +from utils import * + + +def _fix_fracs(string): + substrs = string.split("\\frac") + new_str = substrs[0] + if len(substrs) > 1: + substrs = substrs[1:] + for substr in substrs: + new_str += "\\frac" + if len(substr) > 0 and substr[0] == "{": + new_str += substr + else: + try: + assert len(substr) >= 2 + except: + return string + a = substr[0] + b = substr[1] + if b != "{": + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}{" + b + "}" + post_substr + else: + new_str += "{" + a + "}{" + b + "}" + else: + if len(substr) > 2: + post_substr = substr[2:] + new_str += "{" + a + "}" + b + post_substr + else: + new_str += "{" + a + "}" + b + string = new_str + return string + + +def _fix_a_slash_b(string): + if len(string.split("/")) != 2: + return string + a = string.split("/")[0] + b = string.split("/")[1] + try: + if "sqrt" not in a: + a = int(a) + if "sqrt" not in b: + b = int(b) + assert string == "{}/{}".format(a, b) + new_string = "\\frac{" + str(a) + "}{" + str(b) + "}" + return new_string + except: + return string + + +def _fix_sqrt(string): + _string = re.sub(r"\\sqrt(\w+)", r"\\sqrt{\1}", string) + return _string + + +def convert_word_number(text: str) -> str: + try: + text = str(w2n.word_to_num(text)) + except: + pass + return text + + +# units mainly from MathQA +unit_texts = [ + "east", + "degree", + "mph", + "kmph", + "ft", + "m sqaure", + " m east", + "sq m", + "deg", + "mile", + "q .", + "monkey", + "prime", + "ratio", + "profit of rs", + "rd", + "o", + "gm", + "p . m", + "lb", + "tile", + "per", + "dm", + "lt", + "gain", + "ab", + "way", + "west", + "a .", + "b .", + "c .", + "d .", + "e .", + "f .", + "g .", + "h .", + "t", + "a", + "h", + "no change", + "men", + "soldier", + "pie", + "bc", + "excess", + "st", + "inches", + "noon", + "percent", + "by", + "gal", + "kmh", + "c", + "acre", + "rise", + "a . m", + "th", + "π r 2", + "sq", + "mark", + "l", + "toy", + "coin", + "sq . m", + "gallon", + "° f", + "profit", + "minw", + "yr", + "women", + "feet", + "am", + "pm", + "hr", + "cu cm", + "square", + "v â € ™", + "are", + "rupee", + "rounds", + "cubic", + "cc", + "mtr", + "s", + "ohm", + "number", + "kmph", + "day", + "hour", + "minute", + "min", + "second", + "man", + "woman", + "sec", + "cube", + "mt", + "sq inch", + "mp", + "∏ cm ³", + "hectare", + "more", + "sec", + "unit", + "cu . m", + "cm 2", + "rs .", + "rs", + "kg", + "g", + "month", + "km", + "m", + "cm", + "mm", + "apple", + "liter", + "loss", + "yard", + "pure", + "year", + "increase", + "decrease", + "d", + "less", + "Surface", + "litre", + "pi sq m", + "s .", + "metre", + "meter", + "inch", +] + +unit_texts.extend([t + "s" for t in unit_texts]) + + +def strip_string(string, skip_unit=False): + string = str(string).strip() + # linebreaks + string = string.replace("\n", "") + + # right "." + string = string.rstrip(".") + + # remove inverse spaces + # replace \\ with \ + string = string.replace("\\!", "") + # string = string.replace("\\ ", "") + # string = string.replace("\\\\", "\\") + + # matrix + string = re.sub(r"\\begin\{array\}\{.*?\}", r"\\begin{pmatrix}", string) + string = re.sub(r"\\end\{array\}", r"\\end{pmatrix}", string) + string = string.replace("bmatrix", "pmatrix") + + # replace tfrac and dfrac with frac + string = string.replace("tfrac", "frac") + string = string.replace("dfrac", "frac") + string = ( + string.replace("\\neq", "\\ne") + .replace("\\leq", "\\le") + .replace("\\geq", "\\ge") + ) + + # remove \left and \right + string = string.replace("\\left", "") + string = string.replace("\\right", "") + string = string.replace("\\{", "{") + string = string.replace("\\}", "}") + + # Remove unit: miles, dollars if after is not none + _string = re.sub(r"\\text{.*?}$", "", string).strip() + if _string != "" and _string != string: + # print("Warning: unit not removed: '{}' -> '{}'".format(string, _string)) + string = _string + + if not skip_unit: + # Remove unit: texts + for _ in range(2): + for unit_text in unit_texts: + # use regex, the prefix should be either the start of the string or a non-alphanumeric character + # the suffix should be either the end of the string or a non-alphanumeric character + _string = re.sub(r"(^|\W)" + unit_text + r"($|\W)", r"\1\2", string) + if _string != "": + string = _string + + # Remove circ (degrees) + string = string.replace("^{\\circ}", "") + string = string.replace("^\\circ", "") + + # remove dollar signs + string = string.replace("\\$", "") + string = string.replace("$", "") + string = string.replace("\\(", "").replace("\\)", "") + + # convert word number to digit + string = convert_word_number(string) + + # replace "\\text{...}" to "..." + string = re.sub(r"\\text\{(.*?)\}", r"\1", string) + for key in ["x=", "y=", "z=", "x\\in", "y\\in", "z\\in", "x\\to", "y\\to", "z\\to"]: + string = string.replace(key, "") + string = string.replace("\\emptyset", r"{}") + string = string.replace("(-\\infty,\\infty)", "\\mathbb{R}") + + # remove percentage + string = string.replace("\\%", "") + string = string.replace("\%", "") + string = string.replace("%", "") + + # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string + string = string.replace(" .", " 0.") + string = string.replace("{.", "{0.") + + # cdot + # string = string.replace("\\cdot", "") + if ( + string.startswith("{") + and string.endswith("}") + and string.isalnum() + or string.startswith("(") + and string.endswith(")") + and string.isalnum() + or string.startswith("[") + and string.endswith("]") + and string.isalnum() + ): + string = string[1:-1] + + # inf + string = string.replace("infinity", "\\infty") + if "\\infty" not in string: + string = string.replace("inf", "\\infty") + string = string.replace("+\\inity", "\\infty") + + # and + string = string.replace("and", "") + string = string.replace("\\mathbf", "") + + # use regex to remove \mbox{...} + string = re.sub(r"\\mbox{.*?}", "", string) + + # quote + string.replace("'", "") + string.replace('"', "") + + # i, j + if "j" in string and "i" not in string: + string = string.replace("j", "i") + + # replace a.000b where b is not number or b is end, with ab, use regex + string = re.sub(r"(\d+)\.0*([^\d])", r"\1\2", string) + string = re.sub(r"(\d+)\.0*$", r"\1", string) + + # if empty, return empty string + if len(string) == 0: + return string + if string[0] == ".": + string = "0" + string + + # to consider: get rid of e.g. "k = " or "q = " at beginning + if len(string.split("=")) == 2: + if len(string.split("=")[0]) <= 2: + string = string.split("=")[1] + + string = _fix_sqrt(string) + string = string.replace(" ", "") + + # \frac1b or \frac12 --> \frac{1}{b} and \frac{1}{2}, etc. Even works with \frac1{72} (but not \frac{72}1). Also does a/b --> \\frac{a}{b} + string = _fix_fracs(string) + + # NOTE: X/Y changed to \frac{X}{Y} in dataset, but in simple cases fix in case the model output is X/Y + string = _fix_a_slash_b(string) + + return string + + +def extract_multi_choice_answer(pred_str): + # TODO: SFT models + if "Problem:" in pred_str: + pred_str = pred_str.split("Problem:", 1)[0] + pred_str = pred_str.replace("choice is", "answer is") + patt = regex.search(r"answer is \(?(?P<ans>[abcde])\)?", pred_str.lower()) + if patt is not None: + return patt.group("ans").upper() + return "placeholder" + + +direct_answer_trigger_for_fewshot = ("choice is", "answer is") + + +def choice_answer_clean(pred: str): + pred = pred.strip("\n") + + # Determine if this is ICL, if so, use \n\n to split the first chunk. + ICL = False + for trigger in direct_answer_trigger_for_fewshot: + if pred.count(trigger) > 1: + ICL = True + if ICL: + pred = pred.split("\n\n")[0] + + # Split the trigger to find the answer. + preds = re.split("|".join(direct_answer_trigger_for_fewshot), pred) + if len(preds) > 1: + answer_flag = True + pred = preds[-1] + else: + answer_flag = False + + pred = pred.strip("\n").rstrip(".").rstrip("/").strip(" ").lstrip(":") + + # Clean the answer based on the dataset + tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) + if tmp: + pred = tmp + else: + pred = [pred.strip().strip(".")] + + if len(pred) == 0: + pred = "" + else: + if answer_flag: + # choose the first element in list ... + pred = pred[0] + else: + # choose the last e + pred = pred[-1] + + # Remove the period at the end, again! + pred = pred.rstrip(".").rstrip("/") + + return pred + + +def find_box(pred_str: str): + ans = pred_str.split("boxed")[-1] + if not ans: + return "" + if ans[0] == "{": + stack = 1 + a = "" + for c in ans[1:]: + if c == "{": + stack += 1 + a += c + elif c == "}": + stack -= 1 + if stack == 0: + break + a += c + else: + a += c + else: + a = ans.split("$")[0].strip() + return a + + +def clean_units(pred_str: str): + """Clean the units in the number.""" + + def convert_pi_to_number(code_string): + code_string = code_string.replace("\\pi", "π") + # Replace \pi or π not preceded by a digit or } with 3.14 + code_string = re.sub(r"(?<![\d}])\\?π", "3.14", code_string) + # Replace instances where π is preceded by a digit but without a multiplication symbol, e.g., "3π" -> "3*3.14" + code_string = re.sub(r"(\d)(\\?π)", r"\1*3.14", code_string) + # Handle cases where π is within braces or followed by a multiplication symbol + # This replaces "{π}" with "3.14" directly and "3*π" with "3*3.14" + code_string = re.sub(r"\{(\\?π)\}", "3.14", code_string) + code_string = re.sub(r"\*(\\?π)", "*3.14", code_string) + return code_string + + pred_str = convert_pi_to_number(pred_str) + pred_str = pred_str.replace("%", "/100") + pred_str = pred_str.replace("$", "") + pred_str = pred_str.replace("¥", "") + pred_str = pred_str.replace("°C", "") + pred_str = pred_str.replace(" C", "") + pred_str = pred_str.replace("°", "") + return pred_str + + +def extract_theoremqa_answer(pred: str, answer_flag: bool = True): + if any([option in pred.lower() for option in ["yes", "true"]]): + pred = "True" + elif any([option in pred.lower() for option in ["no", "false"]]): + pred = "False" + elif any( + [ + option in pred.lower() + for option in ["(a)", "(b)", "(c)", "(d)", "(e)", "(f)"] + ] + ): + pass + else: + # Some of the models somehow get used to boxed output from pre-training + if "boxed" in pred: + pred = find_box(pred) + + if answer_flag: + # Extract the numbers out of the string + pred = pred.split("=")[-1].strip() + pred = clean_units(pred) + try: + tmp = str(latex2sympy(pred)) + pred = str(eval(tmp)) + except Exception: + if re.match(r"-?[\d\.]+\s\D+$", pred): + pred = pred.split(" ")[0] + elif re.match(r"-?[\d\.]+\s[^\s]+$", pred): + pred = pred.split(" ")[0] + else: + # desparate search over the last number + preds = re.findall(r"-?\d*\.?\d+", pred) + if len(preds) >= 1: + pred = preds[-1] + else: + pred = "" + + return pred + + +def extract_answer(pred_str, data_name, use_last_number=True): + pred_str = pred_str.replace("\u043a\u0438", "") + if data_name in ["mmlu_stem", "sat_math", "aqua", "gaokao2023"]: + # TODO check multiple choice + return choice_answer_clean(pred_str) + + if "final answer is $" in pred_str and "$. I hope" in pred_str: + # minerva_math + tmp = pred_str.split("final answer is $", 1)[1] + pred = tmp.split("$. I hope", 1)[0].strip() + elif "boxed" in pred_str: + ans = pred_str.split("boxed")[-1] + if len(ans) == 0: + return "" + elif ans[0] == "{": + stack = 1 + a = "" + for c in ans[1:]: + if c == "{": + stack += 1 + a += c + elif c == "}": + stack -= 1 + if stack == 0: + break + a += c + else: + a += c + else: + a = ans.split("$")[0].strip() + pred = a + elif "he answer is" in pred_str: + pred = pred_str.split("he answer is")[-1].strip() + elif "final answer is" in pred_str: + pred = pred_str.split("final answer is")[-1].strip() + elif "答案是" in pred_str: + # Handle Chinese few-shot multiple choice problem answer extraction + pred = pred_str.split("答案是")[1].strip().split("\n\n")[0].strip() + else: # use the last number + if use_last_number: + pattern = "-?\d*\.?\d+" + pred = re.findall(pattern, pred_str.replace(",", "")) + if len(pred) >= 1: + pred = pred[-1] + else: + pred = "" + else: + pred = "" + + # choice answer + if ( + data_name in ["sat_math", "aqua"] + or "mmlu" in data_name + ): + tmp = re.findall(r"\b(A|B|C|D|E)\b", pred.upper()) + if tmp: + pred = tmp[-1] + else: + pred = pred.strip().strip(".") + + # multiple line + # pred = pred.split("\n")[0] + pred = re.sub(r"\n\s*", "", pred) + if pred != "" and pred[0] == ":": + pred = pred[1:] + if pred != "" and pred[-1] == ".": + pred = pred[:-1] + if pred != "" and pred[-1] == "/": + pred = pred[:-1] + pred = strip_string(pred, skip_unit=data_name in ["carp_en", "minerva_math"]) + return pred + + +STRIP_EXCEPTIONS = ["carp_en", "minerva_math"] + + +def parse_ground_truth(example: Dict[str, Any], data_name): + if "gt_cot" in example and "gt" in example: + if data_name in ["math", "math500"]: + gt_ans = extract_answer(example["gt_cot"], data_name) + elif data_name in STRIP_EXCEPTIONS: + gt_ans = example["gt"] + else: + gt_ans = strip_string(example["gt"]) + return example["gt_cot"], gt_ans + + # parse ground truth + if data_name in ["math", "math500", "minerva_math"]: + gt_cot = example["solution"] + gt_ans = extract_answer(gt_cot, data_name) + elif data_name == "gsm8k": + gt_cot, gt_ans = example["answer"].split("####") + elif data_name == "svamp": + gt_cot, gt_ans = example["Equation"], example["Answer"] + elif data_name == "asdiv": + gt_cot = example["formula"] + gt_ans = re.sub(r"\(.*?\)", "", example["answer"]) + elif data_name == "mawps": + gt_cot, gt_ans = None, example["target"] + elif data_name == "tabmwp": + gt_cot = example["solution"] + gt_ans = example["answer"] + if example["ans_type"] in ["integer_number", "decimal_number"]: + if "/" in gt_ans: + gt_ans = int(gt_ans.split("/")[0]) / int(gt_ans.split("/")[1]) + elif "," in gt_ans: + gt_ans = float(gt_ans.replace(",", "")) + elif "%" in gt_ans: + gt_ans = float(gt_ans.split("%")[0]) / 100 + else: + gt_ans = float(gt_ans) + elif data_name == "carp_en": + gt_cot, gt_ans = example["steps"], example["answer"] + elif data_name == "mmlu_stem": + abcd = "ABCD" + gt_cot, gt_ans = None, abcd[example["answer"]] + elif data_name == "sat_math": + gt_cot, gt_ans = None, example["Answer"] + elif data_name == "aqua": + gt_cot, gt_ans = None, example["correct"] + elif data_name in ["gaokao2023en", "college_math", "gaokao_math_cloze"]: + gt_cot, gt_ans = None, example["answer"].replace("$", "").strip() + elif data_name == "gaokao_math_qa": + gt_cot, gt_ans = None, example["label"] + elif data_name in ["gaokao2024_mix", "cn_middle_school"]: + if len(example["choice_answer"]) > 0: + gt_cot, gt_ans = None, example["choice_answer"] + else: + gt_cot, gt_ans = None, example["answer"] + elif data_name == "olympiadbench": + gt_cot, gt_ans = None, example["final_answer"][0].strip("$") + elif data_name in [ + "aime24", + "aime24x8", + "aime25", + "aime25x8", + "phy1", + "amc23", + "amc23x8", + "cmath", + "gaokao2024_I", + "gaokao2024_II", + "imo2024", + "deepscaler", + "deepscaler_random3p", + "deepscaler_random3p_noInstruct" + ]: + gt_cot, gt_ans = None, example["answer"] + else: + raise NotImplementedError(f"`{data_name}`") + # post process + gt_cot = str(gt_cot).strip() + if data_name not in STRIP_EXCEPTIONS: + gt_ans = strip_string(gt_ans, skip_unit=data_name == "carp_en") + else: + gt_ans = ( + gt_ans.replace("\\neq", "\\ne") + .replace("\\leq", "\\le") + .replace("\\geq", "\\ge") + ) + return gt_cot, gt_ans + + +def parse_question(example, data_name): + question = "" + if data_name == "asdiv": + question = f"{example['body'].strip()} {example['question'].strip()}" + elif data_name == "svamp": + body = example["Body"].strip() + if not body.endswith("."): + body = body + "." + question = f'{body} {example["Question"].strip()}' + elif data_name == "tabmwp": + title_str = ( + f'regarding "{example["table_title"]}" ' if example["table_title"] else "" + ) + question = f"Read the following table {title_str}and answer a question:\n" + question += f'{example["table"]}\n{example["question"]}' + if example["choices"]: + question += ( + f' Please select from the following options: {example["choices"]}' + ) + elif data_name == "carp_en": + question = example["content"] + elif data_name == "mmlu_stem": + options = example["choices"] + assert len(options) == 4 + for i, (label, option) in enumerate(zip("ABCD", options)): + options[i] = f"({label}) {str(option).strip()}" + options = " ".join(options) + # question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options}" + question = f"{example['question'].strip()}\nAnswer Choices: {options}" + elif data_name == "sat_math": + options = example["options"].strip() + assert "A" == options[0] + options = "(" + options + for ch in "BCD": + if f" {ch}) " in options: + options = regex.sub(f" {ch}\) ", f" ({ch}) ", options) + # question = f"{example['question'].strip()}\nWhat of the following is the right choice? Explain your answer.\n{options.strip()}" + question = f"{example['question'].strip()}\nAnswer Choices: {options}" + elif "aqua" in data_name: + options = example["options"] + choice = "(" + "(".join(options) + choice = choice.replace("(", " (").replace(")", ") ").strip() + choice = "\nAnswer Choices: " + choice + question = example["question"].strip() + choice + elif data_name == "gaokao_math_qa": + options_dict = example["options"] + options = [] + for key in options_dict: + options.append(f"({key}) {options_dict[key]}") + options = " ".join(options) + question = f"{example['question'].strip()}\n选项: {options}" + else: + for key in ["question", "problem", "Question", "input"]: + if key in example: + question = example[key] + break + # assert question != "" + # Yes or No question + _, gt_ans = parse_ground_truth(example, data_name) + if isinstance(gt_ans, str): + gt_lower = gt_ans.lower() + if gt_lower in ["true", "false"]: + question += " (True or False)" + if gt_lower in ["yes", "no"]: + question += " (Yes or No)" + return question.strip() + + +def run_execute(executor, result, prompt_type, data_name, execute=False): + if not result or result == "error": + return None, None + report = None + + if "program_only" in prompt_type: + prediction = extract_program_output(result) + elif prompt_type in ["pot", "pal"] and execute: + code = extract_program(result) + prediction, report = executor.apply(code) + else: + prediction = extract_answer(result, data_name) + + # prediction = strip_string(prediction, skip_unit=data_name == "carp_en") + prediction = strip_string(prediction, skip_unit=data_name in STRIP_EXCEPTIONS) + return prediction, report + + +def _test_extract_answer(): + text = """ +This is still not equal to $0$, so we must have made another mistake. + +When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that: + +\[\frac{386}{64} - 7 = \frac{386}{64} - \frac{7 \cdot 64}{1 \cdot 64} = \frac{386 - 448}{64} = \frac{-62}{64}.\] + +This is still not equal to $0$, so we must have made another mistake. + +When we subtracted $7$ from $\frac{386}{64}$, we should have subtracted $7 \cdot 64$ from $386$, not the other way around. Let's correct that: + +\[\frac{386}{64} +""" + print(extract_answer(text, "math-oai", use_last_number=False)) + print(choice_answer_clean("\mathrm{(D)\}1,008,016")) + # should output a dict + + +if __name__ == "__main__": + _test_extract_answer() |
