From 947d9dfdf16ae37109898111a5caacae7377b96d Mon Sep 17 00:00:00 2001 From: = <=> Date: Wed, 4 Jun 2025 11:49:37 +0800 Subject: update code and kk eval --- kk_eval/compute_score.py | 132 +++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 132 insertions(+) create mode 100644 kk_eval/compute_score.py (limited to 'kk_eval/compute_score.py') diff --git a/kk_eval/compute_score.py b/kk_eval/compute_score.py new file mode 100644 index 0000000..133e9ab --- /dev/null +++ b/kk_eval/compute_score.py @@ -0,0 +1,132 @@ +import re +from typing import Dict, Tuple, Optional + +def extract_solution(solution_str: str) -> Tuple[Optional[str], str]: + """Extracts the final answer from the model's response string. + + Args: + solution_str: Raw response string from the language model + + Returns: + Tuple containing (extracted_answer, processed_string) + """ + # Split response to isolate assistant output + # if "Assistant:" in solution_str: + # processed_str = solution_str.split("Assistant:", 1)[1] + # elif "<|im_start|>assistant" in solution_str: + # processed_str = solution_str.split("<|im_start|>assistant", 1)[1] + # else: + # print("[Error] Failed to locate model response header") + # return None, solution_str + processed_str = solution_str + + # Extract final answer using XML-style tags + answer_pattern = r'(.*?)' + matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL)) + + if not matches: + print("[Error] No valid answer tags found") + return None, processed_str + + final_answer = matches[-1].group(1).strip() + return final_answer, processed_str + +def parse_solution_text_format(solution_text: str) -> Dict[str, str]: + """Parses ground truth solution text into status dictionary. + + Args: + solution_text: Formatted solution text from dataset + + Returns: + Dictionary mapping character names to their roles (knight/knave) + """ + status_dict = {} + print("\n[Ground Truth Parsing]") + + for line in solution_text.split('\n'): + line = line.strip() + if not line: + continue + + match = re.search(r'\b([A-Za-z]+)\b.*?\b(knight|knave)\b', line, re.IGNORECASE) + if match: + name, role = match.groups() + status_dict[name] = role.lower() + print(f" Found: {name} → {role}") + else: + print(f" [Warning] Unparseable line: '{line}'") + + return status_dict + +def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]: + """Parses model's answer text into status dictionary. + + Args: + answer_text: Text extracted from model's tags + expected_names: List of character names requiring identification + + Returns: + Dictionary mapping character names to predicted roles, or None if incomplete + """ + status_dict = {} + print("\n[Model Answer Parsing]") + print(f" Expected characters: {expected_names}") + + for name in expected_names: + pattern = re.compile( + rf'\b{re.escape(name)}\b.*?\b(knight|knave)\b', + re.IGNORECASE + ) + match = pattern.search(answer_text) + + if match: + role = match.group(1).lower() + status_dict[name] = role + print(f" Found: {name} → {role}") + else: + print(f" [Error] Missing identification for {name}") + return None + + return status_dict + +def validate_response_structure(processed_str: str) -> bool: + """Performs comprehensive validation of response structure. + + Args: + processed_str: Processed response string from the model + + Returns: + Boolean indicating whether all formatting requirements are met + """ + print("\n[Structure Validation]") + validation_passed = True + return validation_passed + + # Check required tags + tags = { + 'think_end': ('', 1), + 'answer_start': ('', 1), + 'answer_end': ('', 1) + } + + positions = {} + for tag_name, (tag_str, expected_count) in tags.items(): + count = processed_str.count(tag_str) + positions[tag_name] = pos = processed_str.find(tag_str) + + print(f" {tag_str}: count={count}, position={pos}") + + if count != expected_count: + print(f" [Error] {tag_str} appears {count} times (expected {expected_count})") + validation_passed = False + + # Verify tag order + if ( + positions['think_end'] > positions['answer_start'] or + positions['answer_start'] > positions['answer_end']): + print(" [Error] Incorrect tag order: Expected ......") + validation_passed = False + else: + print(" Tag sequence validation passed") + + return validation_passed -- cgit v1.2.3