from collections import Counter, defaultdict

import timeout_decorator

from grader import math_equal
from parser import strip_string
from utils import load_jsonl
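

# math_equal can hang on hard symbolic comparisons, so it is wrapped with a
# 5-second timeout and any failure or timeout is treated as "not equal".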
@timeout_decorator.timeout(5)
def math_equal_timeout(pred, gt):
    try:
        return math_equal(pred, gt)
    except Exception as e:
        print("Timeout error:", e)
        return False
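

# Group the k sampled predictions into equivalence classes (exact string match,
# or symbolic equality via math_equal_timeout when use_symbol=True) and return
# the groups plus the majority-vote answer in its original form.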
def group_pred(preds, strip=True, use_symbol=False):
    original_preds = preds
    if not use_symbol:
        # Fast path: group by exact string match (optionally after strip_string).
        if strip:
            preds = [strip_string(pred) for pred in preds]
        cnt = Counter(preds)
        majority = cnt.most_common(1)[0][0]
        groups = defaultdict(list)
        for idx, pred in enumerate(preds):
            groups[pred].append(idx)
        return groups, original_preds[groups[majority][0]]

    # Symbolic path: merge each prediction into the first group whose
    # representative is mathematically equal to it; otherwise start a new group.
    groups = defaultdict(list)
    for idx, pred in enumerate(preds):
        found_group = False
        if strip:
            pred = strip_string(pred)
        for group_key in groups:
            try:
                if math_equal_timeout(pred, group_key):
                    groups[group_key].append(idx)
                    found_group = True
                    break
            except Exception:
                continue
        if not found_group:
            groups[pred].append(idx)

    # The majority answer is the representative of the largest group,
    # mapped back to its original (unstripped) form.
    majority = max(groups.items(), key=lambda item: len(item[1]))[0]
    majority = original_preds[groups[majority][0]]
    return groups, majority
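

# rm@k: for each problem, pick the candidate with the highest reward-model score
# among the first k samples and report how often that candidate is correct.
# Each JSONL sample is expected to carry 'pred_score' (one score list per
# candidate), 'score' (per-candidate correctness), and 'data_source'.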
def eval_rm_k_metrics(data_path, k=8):
    print(f"evaluating rm@{k}")
    data_list = load_jsonl(data_path)
    count, right_count = 0, 0
    for sample in data_list:
        assert len(sample['pred_score']) >= k, sample['data_source']
        pred_score = sample['pred_score'][:k]
        correctness = sample['score'][:k]
        assert len(pred_score) == len(correctness), f"{len(pred_score)}, {len(correctness)}"
        # pred_score holds one single-element score list per candidate; flatten it.
        rm_score = [inner_score for score in pred_score for inner_score in score]
        assert len(rm_score) == len(correctness), f"{len(rm_score)}, {len(correctness)}"
        # Take the candidate the reward model ranks highest and count whether it is correct.
        max_index = rm_score.index(max(rm_score))
        right_count += correctness[max_index]
        count += 1
    print(count)  # number of evaluated samples
    task_acc = right_count / count * 100
    print(f"acc: {task_acc:.1f}")
    return task_acc
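

# maj@k: majority-vote accuracy. Group the first k predictions, take the
# largest group's answer, and check whether that prediction was scored correct.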
def eval_maj_k_metrics(data_path, k=8):
    print(f"evaluating maj@{k}")
    data_list = load_jsonl(data_path)
    count, right_count = 0, 0
    for sample in data_list:
        assert len(sample['score']) >= k, sample
        groups, majority_pred = group_pred(sample['pred'][:k], strip=False, use_symbol=False)
        # Score the first prediction belonging to the majority group.
        idx = groups[majority_pred][0]
        right_count += sample['score'][idx]
        count += 1
    task_acc = right_count / count * 100
    print(f"acc: {task_acc:.1f}")
    return task_acc


if __name__ == "__main__":
    data_path = "./data/eval_rm_maj_example/math_cot_100.jsonl"
    candidate = 8
    all_result = {}
    all_result[f'maj@{candidate}'] = eval_maj_k_metrics(data_path, k=candidate)
    all_result[f'rm@{candidate}'] = eval_rm_k_metrics(data_path, k=candidate)
    print(all_result)