From fc6d57ffb8d5ddb5820fcc00b5491a585c259ebc Mon Sep 17 00:00:00 2001
From: Yuren Hao
Date: Thu, 4 Sep 2025 22:16:22 -0500
Subject: Initial commit

---
 Qwen2.5-Eval/evaluation/rm_maj_eval.py | 100 +++++++++++++++++++++++++++++++++
 1 file changed, 100 insertions(+)
 create mode 100644 Qwen2.5-Eval/evaluation/rm_maj_eval.py

diff --git a/Qwen2.5-Eval/evaluation/rm_maj_eval.py b/Qwen2.5-Eval/evaluation/rm_maj_eval.py
new file mode 100644
index 0000000..b41d5a8
--- /dev/null
+++ b/Qwen2.5-Eval/evaluation/rm_maj_eval.py
@@ -0,0 +1,100 @@
+from grader import math_equal
+from parser import strip_string
+import timeout_decorator
+from collections import defaultdict, Counter
+from utils import load_jsonl
+
+
+@timeout_decorator.timeout(5)
+def math_equal_timeout(pred, gt):
+    try:
+        return math_equal(pred, gt)
+    except Exception as e:
+        print("math_equal error (possibly timeout):", e)
+        return False
+
+
+def group_pred(preds, strip=True, use_symbol=False):
+    original_preds = preds
+    if not use_symbol:
+        if strip:
+            preds = [strip_string(pred) for pred in preds]
+        cnt = Counter(preds)
+        majority = cnt.most_common(1)[0][0]
+        groups = defaultdict(list)
+        for idx, pred in enumerate(preds):
+            groups[pred].append(idx)
+        return groups, original_preds[groups[majority][0]]
+
+    groups = defaultdict(list)
+    for idx, pred in enumerate(preds):
+        found_group = False
+        if strip:
+            pred = strip_string(pred)
+        for group_key in groups:
+            try:
+                if math_equal_timeout(pred, group_key):
+                    groups[group_key].append(idx)
+                    found_group = True
+                    break
+            except Exception:
+                continue
+        if not found_group:
+            groups[pred].append(idx)
+    # get the key of the largest group
+    majority = sorted(groups.items(), key=lambda item: len(item[1]), reverse=True)[0][0]
+    majority = original_preds[groups[majority][0]]
+    return groups, majority
+
+
+def eval_rm_k_metrics(data_path, k=8):
+    print(f"evaluating rm@{k}")
+    data_list = load_jsonl(data_path)
+
+    count, right_count = 0, 0
+    for sample in data_list:
+        assert len(sample['pred_score']) >= k, sample['data_source']
+        pred_score = sample['pred_score'][:k]
+        correctness = sample['score'][:k]
+        assert len(pred_score) == len(correctness), f"{len(pred_score)}, {len(correctness)}"
+
+        rm_score = pred_score
+        rm_score = [inner_score for score in rm_score for inner_score in score]
+        assert len(rm_score) == len(correctness), f"{len(rm_score)}, {len(correctness)}"
+
+        max_index = rm_score.index(max(rm_score))
+        max_correct = correctness[max_index]
+        right_count += max_correct
+        count += 1
+
+    print(count)
+    task_acc = right_count / count * 100
+    print(f"acc: {task_acc:.1f}")
+    return task_acc
+
+
+def eval_maj_k_metrics(data_path, k=8):
+    print(f"evaluating maj@{k}")
+
+    data_list = load_jsonl(data_path)
+    count, right_count = 0, 0
+    for sample in data_list:
+        assert len(sample['score']) >= k, sample
+        groups, majority_pred = group_pred(sample['pred'][:k], strip=False, use_symbol=False)
+        idx = groups[majority_pred][0]
+        right_count += sample['score'][idx]
+        count += 1
+
+    task_acc = right_count / count * 100
+    print(f"acc: {task_acc:.1f}")
+    return task_acc
+
+
+if __name__ == "__main__":
+    data_path = "./data/eval_rm_maj_example/math_cot_100.jsonl"
+
+    candidate = 8
+    all_result = {}
+    all_result[f'maj@{candidate}'] = eval_maj_k_metrics(data_path, k=candidate)
+    all_result[f'rm@{candidate}'] = eval_rm_k_metrics(data_path, k=candidate)
+    print(all_result)
--
cgit v1.2.3
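
Note: a minimal sketch of the JSONL record shape the script above appears to expect, inferred from the field accesses in the two eval functions ('pred', 'score', 'pred_score', 'data_source'); the exact schema and values are assumptions, not confirmed by the patch. 'pred' holds the candidate answer strings, 'score' the 0/1 correctness of each candidate, and 'pred_score' one-element lists of reward-model scores that eval_rm_k_metrics flattens. Two candidates are shown for brevity; the asserts require at least k entries per list (k=8 by default):

    {"data_source": "math", "pred": ["\\frac{1}{2}", "0.5"], "score": [1, 1], "pred_score": [[0.91], [0.42]]}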