import numpy as np
from kk_prompt import system_instruction, demonstration_2char, system_instruction_no_reason, demonstration_2char_no_reason
from compute_score import extract_solution, parse_solution_text_format, parse_model_answer, validate_response_structure
def num_tokens_from_string(string):
    """Return the number of tokens in *string* under the gpt-3.5-turbo encoding.

    tiktoken is imported lazily so the rest of the module works even when
    tiktoken is not installed and token counting is never needed.
    """
    # Bug fixed: the docstring used to sit AFTER this import, making it a
    # dead string expression rather than the function's docstring.
    import tiktoken
    encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
    return len(encoding.encode(string))
def parse_cot_eval(pred_str, ans,
                   conclusion_patterns=('CONCLUSION:',),
                   verbose=False,
                   finish_patterns=("### Reason", "Let's think step by step again", "let's go back and check", "###"),
                   reformat_gold_conditions=None):
    """Grade a free-form chain-of-thought prediction against the gold identities.

    Args:
        pred_str: raw model output.
        ans: gold answer sentence such as "Alice is a knight, and Bob is a knave.";
            only consulted when reformat_gold_conditions is None.
        conclusion_patterns: markers whose trailing text is treated as the
            model's final answer (defaults made immutable tuples — they were
            mutable list defaults).
        verbose: when True, print diagnostics for incorrect predictions.
        finish_patterns: markers that truncate trailing chatter from the answer.
        reformat_gold_conditions: pre-parsed list of gold condition strings;
            parsed from `ans` when None.

    Returns:
        (is_correct, pred_answer, wrong_reason, correct_ratio,
         reformat_gold_conditions). Note wrong_reason keeps its last value
        ("no_conclusion_matched") even when the prediction is correct.
    """
    def judge_string(input_str, gold_conditions, wrong_reason, finish_patterns):
        """Check that input_str names every gold condition and nothing beyond."""
        correct_count = 0
        is_correct = False
        # "(N+1)" appearing in the answer means the model listed an extra person.
        beyond_id_pattern = f"({len(gold_conditions) + 1})"
        # Cut off trailing chatter after any finish marker.
        for finish_pattern in finish_patterns:
            if finish_pattern in input_str:
                input_str = input_str.split(finish_pattern)[0]
        if beyond_id_pattern in input_str:
            wrong_reason = "beyond_list"
        elif "if" in input_str:
            # NOTE: plain substring test — words like "verify" would also
            # trip this; preserved as-is from the original heuristic.
            wrong_reason = "contain_if"
        else:
            is_correct = True
            for gold_condition in gold_conditions:
                if gold_condition not in input_str:
                    is_correct = False
                    wrong_reason = "wrong_identity"
                else:
                    correct_count += 1
        correct_ratio = correct_count / len(gold_conditions)
        return is_correct, wrong_reason, correct_ratio

    def check_numbers_in_string(s, n):
        """True when s contains every marker "(1)" .. "(n)"."""
        return all(f"({i})" in s for i in range(1, n + 1))

    original_str = pred_str
    # Keep only the answer to the first question (strip any echoed follow-ups).
    pred_str = pred_str.split("### Question")[0]
    pred_answer = pred_str
    is_correct = False
    correct_ratio = 0
    if reformat_gold_conditions is None:
        # "A is a knight, and B is a knave." -> ["A is a knight", "B is a knave"]
        gold = ans.replace(" and ", "").replace(".", "")
        reformat_gold_conditions = [cond.strip() for cond in gold.split(",")]
    wrong_reason = "no_conclusion_matched"
    for pattern in conclusion_patterns:
        pred = pred_str.split(pattern)
        # Marker found and some text follows it.
        if len(pred) > 1 and pred[1]:
            pred_answer = pred[1]
            is_correct, wrong_reason, correct_ratio = judge_string(
                pred_answer, reformat_gold_conditions, wrong_reason, finish_patterns)
            break
    if not is_correct and wrong_reason == "no_conclusion_matched":
        # No conclusion marker: fall back to the whole text when it
        # enumerates every (1)..(n) identity.
        if check_numbers_in_string(pred_str, len(reformat_gold_conditions)):
            is_correct, wrong_reason, correct_ratio = judge_string(
                pred_str, reformat_gold_conditions, wrong_reason, finish_patterns)
    if not is_correct and verbose:
        print("wrong_reason:", wrong_reason)
        print("********* \nprediction before parse:\n", original_str)
        print("********* \nprediction after parse:\n", pred_answer)
    return is_correct, pred_answer, wrong_reason, correct_ratio, reformat_gold_conditions
def parse_cot_eval_instruct(pred_str, ans,
                            conclusion_patterns=('',),
                            verbose=False,
                            finish_patterns=("",),
                            reformat_gold_conditions=None,
                            expected_names=None,
                            solution_text_format=None):
    """Grade an instruct-style prediction against the structured gold solution.

    Unlike parse_cot_eval, the heavy lifting is delegated to the
    compute_score helpers: the gold identities are parsed from
    `solution_text_format`, the model answer is extracted from `pred_str`,
    and the two are compared as dicts.

    Args:
        pred_str: raw model output.
        ans: unused; kept for signature parity with parse_cot_eval.
        conclusion_patterns, finish_patterns: unused; kept for signature
            parity (defaults made immutable tuples — they were list defaults).
        verbose: when True, print diagnostics for incorrect predictions.
        reformat_gold_conditions: passed through unchanged in the return.
        expected_names: ignored — always recomputed from the parsed gold
            solution (the original silently overwrote this parameter).
        solution_text_format: gold solution text understood by
            parse_solution_text_format.

    Returns:
        (is_correct, answer_text, wrong_reason, correct_ratio,
         reformat_gold_conditions).
    """
    print("\n" + "="*80)
    print(" Processing New Sample ".center(80, '='))

    # Parse ground-truth identities from the gold solution text.
    gt_status = parse_solution_text_format(solution_text_format)
    expected_names = list(gt_status.keys())
    print(f"[Ground Truth] Final identities: {gt_status}")

    # Extract the model's final answer section.
    answer_text, processed_str = extract_solution(pred_str)
    print(f"\n[Model Response]\n{processed_str}")

    # Validate overall response structure (delegated to compute_score).
    format_correct = validate_response_structure(processed_str)
    print(f"\n Format validation: {'PASS' if format_correct else 'FAIL'}")

    # answer_score mirrors the RL reward values but is not returned here.
    answer_score = 0
    is_correct = False
    correct_ratio = 0
    wrong_reason = "no_conclusion_matched"
    if format_correct and answer_text:
        pred_status = parse_model_answer(answer_text, expected_names)
        if pred_status:
            # Fixed: was an f-string with no placeholders.
            print("\n[Content Validation]")
            print(f" Expected: {gt_status}")
            print(f" Predicted: {pred_status}")
            if pred_status == gt_status:
                answer_score = 2
                is_correct = True
                correct_ratio = 1
                print(" Content validation: FULL MATCH")
            else:
                answer_score = -1.5
                correct_ratio = 0
                wrong_reason = "wrong_identity"
                print(" Content validation: MISMATCH")
        else:
            answer_score = -2
            correct_ratio = 0
            wrong_reason = "no_conclusion_matched"
            print("Fail to parse answer")
    else:
        print("\n[Content Validation] Skipped due to format errors or missing answer")

    if not is_correct and verbose:
        print("wrong_reason:", wrong_reason)
        print("********* \nprediction before parse:\n", pred_str)
        print("********* \nprediction after parse:\n", answer_text)
    return is_correct, answer_text, wrong_reason, correct_ratio, reformat_gold_conditions
class KKProcessor:
    """Builds Knights-and-Knaves prompts and grades model predictions.

    Attributes:
        cot: when True, prompts request step-by-step reasoning.
        no_linebreak: when True (and cot is False), "### Answer:" is not
            followed by a newline.
    """

    # Models that expect the \boxed{} answer convention instead of CONCLUSION:.
    _BOXED_MODELS = ("deepseek-ai/deepseek-math-7b-instruct", "AI-MO/NuminaMath-7B-CoT")

    def __init__(self, cot=True, no_linebreak=True):
        self.cot = cot
        self.no_linebreak = no_linebreak

    def format_example(self, test_records, idx, model_name=None):
        """Return (prompt, gold_answer) for the record at position idx."""
        record = test_records[idx]
        prompt = "### Question: " + record["quiz"] + "\n"
        if self.cot:
            if model_name in self._BOXED_MODELS:
                prompt += "Please reason step by step, and put your final answer within \\boxed{}."
            else:
                prompt += "### Answer: Let's think step by step. "
        elif self.no_linebreak:
            prompt += "### Answer:"
        else:
            prompt += "### Answer:\n"
        return prompt, record["solution_text"]

    def gen_test_prompt(self, ntrain, test_records, idx, model_name=None):
        """Assemble system instruction (+ optional 1-shot demo) + the test question."""
        if ntrain > 1:
            raise NotImplementedError
        header = system_instruction if self.cot else system_instruction_no_reason
        if ntrain == 1:
            demo = demonstration_2char if self.cot else demonstration_2char_no_reason
            header += "\n\n" + demo
        question, answer = self.format_example(test_records, idx, model_name)
        return header + "\n\n" + question, answer

    def _parse_cot_eval(self, pred_str, ans, model_name=None):
        """Grade a free-form CoT answer via parse_cot_eval."""
        patterns = ['CONCLUSION:', 'Conclusion:', 'conclusion:']
        if model_name in self._BOXED_MODELS:
            patterns = ['boxed{'] + patterns
        is_correct, pred_answer, _, _, gold_conditions = parse_cot_eval(
            pred_str, ans, conclusion_patterns=patterns, verbose=False)
        return is_correct, pred_answer, gold_conditions

    def _parse_cot_eval_instruct(self, pred_str, ans, model_name=None, expected_names=None, solution_text_format=None):
        """Grade an instruct-style answer via parse_cot_eval_instruct."""
        is_correct, pred_answer, _, _, gold_conditions = parse_cot_eval_instruct(
            pred_str, ans, conclusion_patterns=[''], verbose=True,
            expected_names=expected_names, solution_text_format=solution_text_format)
        return is_correct, pred_answer, gold_conditions