From 947d9dfdf16ae37109898111a5caacae7377b96d Mon Sep 17 00:00:00 2001
From: = <=>
Date: Wed, 4 Jun 2025 11:49:37 +0800
Subject: update code and kk eval
---
kk_eval/compute_score.py | 132 +++++++++++++++++++++++++++++++++++++++++++++++
1 file changed, 132 insertions(+)
create mode 100644 kk_eval/compute_score.py
(limited to 'kk_eval/compute_score.py')
diff --git a/kk_eval/compute_score.py b/kk_eval/compute_score.py
new file mode 100644
index 0000000..133e9ab
--- /dev/null
+++ b/kk_eval/compute_score.py
@@ -0,0 +1,132 @@
+import re
+from typing import Dict, Tuple, Optional
+
def extract_solution(solution_str: str) -> Tuple[Optional[str], str]:
    """Extract the final answer from the model's response string.

    The answer is the content of the last ``<answer>...</answer>`` pair in
    the response, with surrounding whitespace stripped.

    Args:
        solution_str: Raw response string from the language model.

    Returns:
        Tuple ``(extracted_answer, processed_string)``. ``extracted_answer``
        is ``None`` when no complete answer tag pair is present;
        ``processed_string`` is the input string, returned unchanged.
    """
    # No response-header stripping is performed: the whole input is treated
    # as the model's output.
    processed_str = solution_str

    # DOTALL lets an answer span several lines; the non-greedy group keeps
    # every match confined to a single <answer>...</answer> pair.
    answers = re.findall(r'<answer>(.*?)</answer>', processed_str, re.DOTALL)

    if not answers:
        print("[Error] No valid answer tags found")
        return None, processed_str

    # When several answer blocks are present, the last one is the final answer.
    return answers[-1].strip(), processed_str
+
def parse_solution_text_format(solution_text: str) -> Dict[str, str]:
    """Turn the ground-truth solution text into a name -> role mapping.

    Each non-blank line is searched for the first alphabetic word followed,
    later on the same line, by the word "knight" or "knave" (matched
    case-insensitively). Lines without such a pair are reported and skipped.

    Args:
        solution_text: Formatted solution text from the dataset.

    Returns:
        Dictionary mapping each character name (as written in the text) to
        its role in lower case.
    """
    print("\n[Ground Truth Parsing]")
    line_pattern = re.compile(
        r'\b([A-Za-z]+)\b.*?\b(knight|knave)\b',
        re.IGNORECASE,
    )
    roles_by_name: Dict[str, str] = {}

    stripped_lines = (raw.strip() for raw in solution_text.split('\n'))
    # filter(None, ...) drops the lines that are empty after stripping.
    for entry in filter(None, stripped_lines):
        found = line_pattern.search(entry)
        if found is None:
            print(f" [Warning] Unparseable line: '{entry}'")
            continue

        character, role = found.group(1), found.group(2)
        roles_by_name[character] = role.lower()
        # The log line shows the role exactly as it was written in the text.
        print(f" Found: {character} → {role}")

    return roles_by_name
+
def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]:
    """Build a name -> predicted role mapping from the model's answer text.

    For every expected name, the answer is searched (case-insensitively) for
    that name followed on the same line by the word "knight" or "knave".
    Parsing stops at the first name that has no such match.

    Args:
        answer_text: Answer text extracted from the model's response.
        expected_names: Character names that must each be identified.

    Returns:
        Dictionary mapping each expected name to its predicted role in lower
        case, or None as soon as one name has no identifiable role.
    """
    print("\n[Model Answer Parsing]")
    print(f" Expected characters: {expected_names}")

    predictions = {}
    for character in expected_names:
        # re.escape keeps any regex metacharacters in the name literal.
        hit = re.search(
            rf'\b{re.escape(character)}\b.*?\b(knight|knave)\b',
            answer_text,
            flags=re.IGNORECASE,
        )
        if hit is None:
            print(f" [Error] Missing identification for {character}")
            return None

        predicted_role = hit.group(1).lower()
        predictions[character] = predicted_role
        print(f" Found: {character} → {predicted_role}")

    return predictions
+
def validate_response_structure(processed_str: str, *, strict: bool = False) -> bool:
    """Validate the tag structure of a model response.

    By default no structural check is enforced and every response is
    accepted. This preserves the function's existing behaviour, where the
    checks sat behind an unconditional early return. Pass ``strict=True`` to
    enforce them.

    In strict mode the response must contain ``</think>``, ``<answer>`` and
    ``</answer>`` exactly once each, appearing in that order.

    Args:
        processed_str: Processed response string from the model.
        strict: When True, run the tag count and tag order checks.

    Returns:
        True if the response is accepted, False otherwise.
    """
    print("\n[Structure Validation]")
    if not strict:
        return True

    # Tag name -> (literal tag, number of times it must occur).
    required_tags = {
        'think_end': ('</think>', 1),
        'answer_start': ('<answer>', 1),
        'answer_end': ('</answer>', 1),
    }

    validation_passed = True
    positions = {}
    for tag_name, (tag_str, expected_count) in required_tags.items():
        count = processed_str.count(tag_str)
        positions[tag_name] = pos = processed_str.find(tag_str)

        print(f" {tag_str}: count={count}, position={pos}")

        if count != expected_count:
            print(f" [Error] {tag_str} appears {count} times (expected {expected_count})")
            validation_passed = False

    # str.find returns -1 for a missing tag, which would make the order
    # comparison meaningless; the count check above has already failed.
    if any(pos < 0 for pos in positions.values()):
        return False

    # Verify tag order using the first occurrence of each tag.
    if (
        positions['think_end'] > positions['answer_start']
        or positions['answer_start'] > positions['answer_end']
    ):
        print(" [Error] Incorrect tag order: Expected ...</think><answer>...</answer>")
        validation_passed = False
    else:
        print(" Tag sequence validation passed")

    return validation_passed
--
cgit v1.2.3