Diffstat (limited to 'Qwen2.5-Eval/evaluation/math_eval.py')
-rwxr-xr-x  Qwen2.5-Eval/evaluation/math_eval.py  437
1 file changed, 437 insertions(+), 0 deletions(-)
diff --git a/Qwen2.5-Eval/evaluation/math_eval.py b/Qwen2.5-Eval/evaluation/math_eval.py
new file mode 100755
index 0000000..0a35929
--- /dev/null
+++ b/Qwen2.5-Eval/evaluation/math_eval.py
@@ -0,0 +1,437 @@
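+"""Evaluate math benchmarks (e.g. gsm8k, math) with vLLM or HF generation.
+
+Illustrative invocation (model name and GPU list are examples only):
+
+    CUDA_VISIBLE_DEVICES=0,1,2,3 python math_eval.py \
+        --model_name_or_path Qwen/Qwen2.5-Math-7B-Instruct \
+        --data_names gsm8k,math --prompt_type cot --use_vllm --save_outputs
+"""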
+import random
+import os
+import json  # needed for the metrics json.dump below
+import argparse
+import time
+from vllm import LLM, SamplingParams
+from datetime import datetime
+from tqdm import tqdm
+
+from transformers import AutoTokenizer
+
+from evaluate import evaluate
+from utils import set_seed, load_jsonl, save_jsonl, construct_prompt
+from parser import *
+from trajectory import *
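+# (the star imports above provide parse_question, parse_ground_truth,
+# extract_program, run_execute and choice_answer_clean used below)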
+from data_loader import load_data
+from python_executor import PythonExecutor
+from model_utils import load_hf_lm_and_tokenizer, generate_completions
+
+
+def parse_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--data_names", default="gsm8k,math", type=str)
+ parser.add_argument("--data_dir", default="./data", type=str)
+ parser.add_argument("--model_name_or_path", default="gpt-4", type=str)
+ parser.add_argument("--output_dir", default="./output", type=str)
+ parser.add_argument("--prompt_type", default="tool-integrated", type=str)
+ parser.add_argument("--split", default="test", type=str)
+ parser.add_argument("--num_test_sample", default=-1, type=int) # -1 for full data
+ parser.add_argument("--seed", default=0, type=int)
+ parser.add_argument("--start", default=0, type=int)
+ parser.add_argument("--end", default=-1, type=int)
+ parser.add_argument("--temperature", default=0, type=float)
+ parser.add_argument("--n_sampling", default=1, type=int)
+ parser.add_argument("--top_p", default=1, type=float)
+ parser.add_argument("--max_tokens_per_call", default=8192, type=int)
+ parser.add_argument("--shuffle", action="store_true")
+ parser.add_argument("--use_vllm", action="store_true")
+ parser.add_argument("--vllm_batch_size",default=0, type=int)
+ parser.add_argument("--save_outputs", action="store_true")
+ parser.add_argument("--overwrite", action="store_true")
+ parser.add_argument("--use_safetensors", action="store_true")
+ parser.add_argument("--num_shots", type=int, default=0)
+ parser.add_argument(
+ "--apply_chat_template",
+ action="store_true",
+ help="Apply chat template to prompt.",
+ )
+ parser.add_argument("--pipeline_parallel_size", type=int, default=1)
+ parser.add_argument(
+ "--adapt_few_shot",
+ action="store_true",
+ help="Few shot for multiple-choice questions, zero shot for others.",
+ )
+ args = parser.parse_args()
+ args.top_p = (
+ 1 if args.temperature == 0 else args.top_p
+ ) # top_p must be 1 when using greedy sampling (vllm)
+ return args
+
+
+def prepare_data(data_name, args):
+ examples = load_data(data_name, args.split, args.data_dir)
+
+ # sample `num_test_sample` from dataset
+ if args.num_test_sample > 0:
+ # examples = random.sample(examples, min(args.num_test_sample, len(examples)))
+ examples = examples[: args.num_test_sample]
+
+ # shuffle
+ if args.shuffle:
+ random.seed(datetime.now().timestamp())
+ random.shuffle(examples)
+
+ # select start and end
+ examples = examples[args.start : len(examples) if args.end == -1 else args.end]
+
+ # get out_file name
+ dt_string = datetime.now().strftime("%m-%d_%H-%M")
+ model_name = "/".join(args.model_name_or_path.split("/")[-2:])
+ out_file_prefix = f"{args.split}_{args.prompt_type}_{args.num_test_sample}_seed{args.seed}_t{args.temperature}"
+ output_dir = args.output_dir
+ out_file = f"{output_dir}/{data_name}/{out_file_prefix}_s{args.start}_e{args.end}.jsonl"
+ os.makedirs(f"{output_dir}/{data_name}", exist_ok=True)
+
+ # load all processed samples
+ processed_samples = []
+ if not args.overwrite:
+ processed_files = [
+ f
+ for f in os.listdir(f"{output_dir}/{data_name}/")
+ if f.endswith(".jsonl") and f.startswith(out_file_prefix)
+ ]
+ for f in processed_files:
+ processed_samples.extend(
+ list(load_jsonl(f"{output_dir}/{data_name}/{f}"))
+ )
+
+    # deduplicate by idx so interrupted runs can resume cleanly
+ processed_samples = {sample["idx"]: sample for sample in processed_samples}
+ processed_idxs = list(processed_samples.keys())
+ processed_samples = list(processed_samples.values())
+ examples = [example for example in examples if example["idx"] not in processed_idxs]
+ return examples, processed_samples, out_file
+
+
+def setup(args):
+ # load model
+    # CUDA_VISIBLE_DEVICES may be unset; fall back to a single GPU
+    available_gpus = os.environ.get("CUDA_VISIBLE_DEVICES", "0").split(",")
+ if args.use_vllm:
+        llm = LLM(
+            model=args.model_name_or_path,
+            # match tensor parallelism to the visible GPUs instead of hardcoding it
+            tensor_parallel_size=len(available_gpus) // args.pipeline_parallel_size,
+            pipeline_parallel_size=args.pipeline_parallel_size,
+            trust_remote_code=True,
+        )
+ tokenizer = None
+ if args.apply_chat_template:
+ tokenizer = AutoTokenizer.from_pretrained(
+ args.model_name_or_path, trust_remote_code=True
+ )
+ else:
+ llm, tokenizer = load_hf_lm_and_tokenizer(
+ model_name_or_path=args.model_name_or_path,
+ load_in_half=True,
+ use_fast_tokenizer=True,
+ use_safetensors=args.use_safetensors,
+ )
+
+ # infer & eval
+ data_list = args.data_names.split(",")
+ results = []
+ for data_name in data_list:
+ results.append(main(llm, tokenizer, data_name, args))
+
+ # add "avg" result to data_list and results
+ data_list.append("avg")
+ results.append(
+ {
+ "acc": sum([result["acc"] for result in results]) / len(results),
+ }
+ )
+
+ # print all results
+ pad = max([len(data_name) for data_name in data_list])
+ print("\t".join(data_name.ljust(pad, " ") for data_name in data_list))
+ print("\t".join([f"{result['acc']:.1f}".ljust(pad, " ") for result in results]))
+
+
+def is_multi_choice(answer):
+    # true when the answer consists solely of choice letters A-E
+    return all(c in ["A", "B", "C", "D", "E"] for c in answer)
+
+
+def main(llm, tokenizer, data_name, args):
+ examples, processed_samples, out_file = prepare_data(data_name, args)
+ print("=" * 50)
+ print("data:", data_name, " ,remain samples:", len(examples))
+ if len(examples) > 0:
+ print(examples[0])
+
+    # init python executor: "pal" prompts define solution(); others print to stdout
+ if "pal" in args.prompt_type:
+ executor = PythonExecutor(get_answer_expr="solution()")
+ else:
+ executor = PythonExecutor(get_answer_from_stdout=True)
+
+ samples = []
+ for example in tqdm(examples, total=len(examples)):
+ idx = example["idx"]
+
+ # parse question and answer
+ example["question"] = parse_question(example, data_name)
+ if example["question"] == "":
+ continue
+ gt_cot, gt_ans = parse_ground_truth(example, data_name)
+ example["gt_ans"] = gt_ans
+ full_prompt = construct_prompt(example, data_name, args)
+
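+        # print one fully constructed prompt as a sanity check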
+ if idx == args.start:
+ print(full_prompt)
+
+ sample = {
+ "idx": idx,
+ "question": example["question"],
+ "gt_cot": gt_cot,
+ "gt": gt_ans,
+ "prompt": full_prompt,
+ }
+
+        # carry over any remaining metadata fields from the source example
+ for key in [
+ "level",
+ "type",
+ "unit",
+ "solution_type",
+ "choices",
+ "solution",
+ "ques_type",
+ "ans_type",
+ "answer_type",
+ "dataset",
+ "subfield",
+ "filed",
+ "theorem",
+ "answer",
+ "difficulty"
+ ]:
+ if key in example:
+ sample[key] = example[key]
+ samples.append(sample)
+
+    # repeat each prompt n_sampling times
+ input_prompts = [
+ sample["prompt"] for sample in samples for _ in range(args.n_sampling)
+ ]
+ if args.apply_chat_template:
+ input_prompts = [
+ tokenizer.apply_chat_template(
+ [{"role": "user", "content": prompt.strip()}],
+ tokenize=False,
+ add_generation_prompt=True,
+ )
+ for prompt in input_prompts
+ ]
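+    # remain_prompts holds (index, prompt) pairs still generating;
+    # end_prompts collects finished trajectories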
+    remain_prompts = [(i, prompt) for i, prompt in enumerate(input_prompts)]
+ end_prompts = []
+
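+    # cot/pal prompts finish in one pass; tool-integrated prompts may loop
+    # through up to 4 generate -> execute rounds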
+ max_func_call = 1 if args.prompt_type in ["cot", "pal"] else 4
+
+ stop_words = ["</s>", "<|im_end|>", "<|endoftext|>"]
+
+ if args.prompt_type in ["cot"]:
+ stop_words.append("\n\nQuestion:")
+ if args.prompt_type in ["pal", "tool-integrated", "jiuzhang_tora"]:
+ stop_words.extend(["\n\n---", "```output"])
+ elif args.prompt_type in ["wizard_zs", "platypus_fs"]:
+ stop_words.extend(["Instruction", "Response"])
+ elif "jiuzhang" in args.prompt_type:
+ stop_words.append("\n\n## Question")
+ elif "numina" in args.prompt_type:
+ stop_words.append("\n### Problem")
+ elif "pure" in args.prompt_type:
+ stop_words.append("\n\n\n")
+
+ # start inference
+ # measure time use
+ start_time = time.time()
+ result_prompts = []
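+    # each round: generate until a stop word, extract any trailing code block,
+    # run it, append its ```output``` block, then continue next round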
+ for epoch in range(max_func_call):
+ print("-" * 20, "Epoch", epoch)
+ current_prompts = remain_prompts
+ if len(current_prompts) == 0:
+ break
+
+ # get all outputs
+ prompts = [item[1] for item in current_prompts]
+ result_prompts.extend(prompts)
+ if args.use_vllm:
+ if args.vllm_batch_size:
+ outputs = []
+ for i in range(0, len(prompts), args.vllm_batch_size):
+ batch_prompts = prompts[i:i+args.vllm_batch_size]
+ print(f"args.temperature: {args.temperature}")
+ batch_outputs = llm.generate(
+ batch_prompts,
+ SamplingParams(
+                            temperature=args.temperature,
+                            top_p=args.top_p,  # forced to 1 in parse_args when greedy
+                            max_tokens=args.max_tokens_per_call,
+ stop=stop_words,
+ stop_token_ids=(
+ [151645, 151643]
+ if "qwen2" in args.model_name_or_path.lower()
+ else None
+ ),
+ ),
+ )
+
+ batch_outputs = sorted(
+ batch_outputs, key=lambda x: int(x.request_id)
+ )
+ batch_outputs = [output.outputs[0].text for output in batch_outputs]
+ outputs.extend(batch_outputs)
+ else:
+ outputs = llm.generate(
+ prompts,
+ SamplingParams(
+ temperature=args.temperature,
+                        top_p=args.top_p,  # forced to 1 in parse_args when greedy
+                        max_tokens=args.max_tokens_per_call,
+ stop=stop_words,
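+                        # 151645 / 151643 are Qwen2's <|im_end|> / <|endoftext|> ids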
+ stop_token_ids=(
+ [151645, 151643]
+ if "qwen2" in args.model_name_or_path.lower()
+ else None
+ ),
+ ),
+ )
+
+ outputs = sorted(
+ outputs, key=lambda x: int(x.request_id)
+ ) # sort outputs by request_id
+ outputs = [output.outputs[0].text for output in outputs]
+ else:
+ outputs = generate_completions(
+ model=llm,
+ tokenizer=tokenizer,
+ prompts=prompts,
+ max_new_tokens=args.max_tokens_per_call,
+ batch_size=16,
+ stop_id_sequences=stop_words,
+ )
+
+ assert len(outputs) == len(current_prompts)
+
+ # process all outputs
+ remain_prompts = []
+ remain_codes = []
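+        # completions ending in a code block are queued (with their extracted
+        # program) for execution; the rest go to end_prompts as finished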
+ for (i, query), output in zip(current_prompts, outputs):
+ output = output.rstrip()
+ query += output
+ if args.prompt_type == "pal":
+ remain_prompts.append((i, query))
+ if "```python" in output:
+ output = extract_program(query)
+ remain_codes.append(output)
+ elif args.prompt_type == "cot":
+ end_prompts.append((i, query))
+ elif "boxed" not in output and output.endswith("```"):
+ program = extract_program(query)
+ remain_prompts.append((i, query))
+ remain_codes.append(program)
+ else:
+ end_prompts.append((i, query))
+
+        # execute the queued code blocks and append each result to its trajectory
+ remain_results = executor.batch_apply(remain_codes)
+ for k in range(len(remain_prompts)):
+ i, query = remain_prompts[k]
+ res, report = remain_results[k]
+ exec_result = res if res else report
+ if "pal" in args.prompt_type:
+ exec_result = "\\boxed{" + exec_result + "}"
+ exec_result = f"\n```output\n{exec_result}\n```\n"
+ query += exec_result
+            # trajectory is unfinished; flag it if the round budget is exhausted
+ if epoch == max_func_call - 1:
+ query += "\nReach max function call limit."
+ remain_prompts[k] = (i, query)
+
+ # unsolved samples
+ print("Unsolved samples:", len(remain_prompts))
+ end_prompts.extend(remain_prompts)
+ # sort by idx
+ end_prompts = sorted(end_prompts, key=lambda x: x[0])
+
+    # strip the input prompt from each trajectory, keeping only the model output
+ codes = []
+ assert len(input_prompts) == len(end_prompts)
+ for i in range(len(input_prompts)):
+ _, end_prompt = end_prompts[i]
+ code = end_prompt.split(input_prompts[i])[-1].strip()
+ for stop_word in stop_words:
+ if stop_word in code:
+ code = code.split(stop_word)[0].strip()
+ codes.append(code)
+
+ # extract preds
+ results = [
+ run_execute(executor, code, args.prompt_type, data_name) for code in codes
+ ]
+ time_use = time.time() - start_time
+
+ # put results back to examples
+ all_samples = []
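+    # each sample owns a contiguous slice of n_sampling completions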
+ for i, sample in enumerate(samples):
+ code = codes[i * args.n_sampling : (i + 1) * args.n_sampling]
+ result = results[i * args.n_sampling : (i + 1) * args.n_sampling]
+ preds = [item[0] for item in result]
+ reports = [item[1] for item in result]
+ for j in range(len(preds)):
+ if sample["gt"] in ["A", "B", "C", "D", "E"] and preds[j] not in [
+ "A",
+ "B",
+ "C",
+ "D",
+ "E",
+ ]:
+ preds[j] = choice_answer_clean(code[j])
+ elif is_multi_choice(sample["gt"]) and not is_multi_choice(preds[j]):
+ # remove any non-choice char
+ preds[j] = "".join(
+ [c for c in preds[j] if c in ["A", "B", "C", "D", "E"]]
+ )
+
+ sample.pop("prompt")
+ sample.update({"code": code, "pred": preds, "report": reports})
+ all_samples.append(sample)
+
+ # add processed samples
+ all_samples.extend(processed_samples)
+ all_samples, result_json = evaluate(
+ samples=all_samples,
+ data_name=data_name,
+ prompt_type=args.prompt_type,
+ execute=True,
+ )
+
+ # save outputs
+ if len(processed_samples) < len(all_samples) and args.save_outputs:
+ save_jsonl(all_samples, out_file)
+
+ result_json["time_use_in_second"] = time_use
+ result_json["time_use_in_minite"] = (
+ f"{int(time_use // 60)}:{int(time_use % 60):02d}"
+ )
+ result_json["prompts"] = result_prompts
+
+ with open(
+ out_file.replace(".jsonl", f"_{args.prompt_type}_metrics.json"), "w"
+ ) as f:
+ json.dump(result_json, f, indent=4)
+ return result_json
+
+
+if __name__ == "__main__":
+ args = parse_args()
+ set_seed(args.seed)
+ setup(args)