diff options
| author | Yuren Hao <yurenh2@timan108.cs.illinois.edu> | 2025-09-04 22:16:22 -0500 |
|---|---|---|
| committer | Yuren Hao <yurenh2@timan108.cs.illinois.edu> | 2025-09-04 22:16:22 -0500 |
| commit | fc6d57ffb8d5ddb5820fcc00b5491a585c259ebc (patch) | |
| tree | e9841f93a353e2107225cfc721d1ce57c0e594dc /kk_eval | |
Initial commit
Diffstat (limited to 'kk_eval')
| -rw-r--r-- | kk_eval/README.md | 16 | ||||
| -rw-r--r-- | kk_eval/compute_score.py | 132 | ||||
| -rw-r--r-- | kk_eval/eval_kk-ckpt_list.sh | 50 | ||||
| -rw-r--r-- | kk_eval/kk_processor.py | 208 | ||||
| -rw-r--r-- | kk_eval/kk_prompt.py | 38 | ||||
| -rw-r--r-- | kk_eval/main_eval_instruct.py | 209 |
6 files changed, 653 insertions, 0 deletions
diff --git a/kk_eval/README.md b/kk_eval/README.md new file mode 100644 index 0000000..49d9040 --- /dev/null +++ b/kk_eval/README.md @@ -0,0 +1,16 @@ +## Quick Start + + +1. Install dependencies: + +```bash +pip install vllm +``` + +2. Execute the evaluation script: + +```bash +bash eval_kk-ckpt_list.sh +``` + + > **Note**: Ensure that all configurations are correctly set before running the script to avoid errors. diff --git a/kk_eval/compute_score.py b/kk_eval/compute_score.py new file mode 100644 index 0000000..133e9ab --- /dev/null +++ b/kk_eval/compute_score.py @@ -0,0 +1,132 @@ +import re +from typing import Dict, Tuple, Optional + +def extract_solution(solution_str: str) -> Tuple[Optional[str], str]: + """Extracts the final answer from the model's response string. + + Args: + solution_str: Raw response string from the language model + + Returns: + Tuple containing (extracted_answer, processed_string) + """ + # Split response to isolate assistant output + # if "Assistant:" in solution_str: + # processed_str = solution_str.split("Assistant:", 1)[1] + # elif "<|im_start|>assistant" in solution_str: + # processed_str = solution_str.split("<|im_start|>assistant", 1)[1] + # else: + # print("[Error] Failed to locate model response header") + # return None, solution_str + processed_str = solution_str + + # Extract final answer using XML-style tags + answer_pattern = r'<answer>(.*?)</answer>' + matches = list(re.finditer(answer_pattern, processed_str, re.DOTALL)) + + if not matches: + print("[Error] No valid answer tags found") + return None, processed_str + + final_answer = matches[-1].group(1).strip() + return final_answer, processed_str + +def parse_solution_text_format(solution_text: str) -> Dict[str, str]: + """Parses ground truth solution text into status dictionary. + + Args: + solution_text: Formatted solution text from dataset + + Returns: + Dictionary mapping character names to their roles (knight/knave) + """ + status_dict = {} + print("\n[Ground Truth Parsing]") + + for line in solution_text.split('\n'): + line = line.strip() + if not line: + continue + + match = re.search(r'\b([A-Za-z]+)\b.*?\b(knight|knave)\b', line, re.IGNORECASE) + if match: + name, role = match.groups() + status_dict[name] = role.lower() + print(f" Found: {name} → {role}") + else: + print(f" [Warning] Unparseable line: '{line}'") + + return status_dict + +def parse_model_answer(answer_text: str, expected_names: list) -> Optional[Dict[str, str]]: + """Parses model's answer text into status dictionary. + + Args: + answer_text: Text extracted from model's <answer> tags + expected_names: List of character names requiring identification + + Returns: + Dictionary mapping character names to predicted roles, or None if incomplete + """ + status_dict = {} + print("\n[Model Answer Parsing]") + print(f" Expected characters: {expected_names}") + + for name in expected_names: + pattern = re.compile( + rf'\b{re.escape(name)}\b.*?\b(knight|knave)\b', + re.IGNORECASE + ) + match = pattern.search(answer_text) + + if match: + role = match.group(1).lower() + status_dict[name] = role + print(f" Found: {name} → {role}") + else: + print(f" [Error] Missing identification for {name}") + return None + + return status_dict + +def validate_response_structure(processed_str: str) -> bool: + """Performs comprehensive validation of response structure. + + Args: + processed_str: Processed response string from the model + + Returns: + Boolean indicating whether all formatting requirements are met + """ + print("\n[Structure Validation]") + validation_passed = True + return validation_passed + + # Check required tags + tags = { + 'think_end': ('</think>', 1), + 'answer_start': ('<answer>', 1), + 'answer_end': ('</answer>', 1) + } + + positions = {} + for tag_name, (tag_str, expected_count) in tags.items(): + count = processed_str.count(tag_str) + positions[tag_name] = pos = processed_str.find(tag_str) + + print(f" {tag_str}: count={count}, position={pos}") + + if count != expected_count: + print(f" [Error] {tag_str} appears {count} times (expected {expected_count})") + validation_passed = False + + # Verify tag order + if ( + positions['think_end'] > positions['answer_start'] or + positions['answer_start'] > positions['answer_end']): + print(" [Error] Incorrect tag order: Expected ...</think><answer>...</answer>") + validation_passed = False + else: + print(" Tag sequence validation passed") + + return validation_passed diff --git a/kk_eval/eval_kk-ckpt_list.sh b/kk_eval/eval_kk-ckpt_list.sh new file mode 100644 index 0000000..c3315fa --- /dev/null +++ b/kk_eval/eval_kk-ckpt_list.sh @@ -0,0 +1,50 @@ +# 定义基础变量 +export VLLM_ENABLE_V1_MULTIPROCESSING=0 +config="vllm" +num_limit=100 +max_token=3072 +ntrain=0 +split="test" +log_path="log/test" + +# 创建日志目录 +mkdir -p ${log_path} + +# Declare an associative array to store model mappings +# TODO:Replace with actual model name and path +declare -A model_dict=( + ["model_name_1"]="/path/to/model1" + ["model_name_2"]="/path/to/model2" +) + + +# Outer Loop:遍历model_dict +for exp_name in "${!model_dict[@]}"; do + model="${model_dict[$exp_name]}" + + # Inner Loop:遍历不同的 eval_nppl 值 + for eval_nppl in 2 3 4 5 6 7 8; do + log_file="${log_path}/${exp_name}_nppl${eval_nppl}.log" # 日志文件名包含模型名称和 eval_nppl + echo "Starting job for model: $model, eval_nppl: $eval_nppl, logging to $log_file" + + # 启动评估任务 + CUDA_VISIBLE_DEVICES=1 PYTHONUNBUFFERED=1 python main_eval_instruct.py \ + --batch_size 100 \ + --model ${model} \ + --max_token ${max_token} \ + --ntrain ${ntrain} \ + --config "${config}_${exp_name}_nppl${eval_nppl}" \ + --limit ${num_limit} \ + --split ${split} \ + --temperature 0.0 \ + --top_p 1.0 \ + --seed 0 \ + --problem_type "clean" \ + --output_file "${log_path}/${exp_name}_nppl${eval_nppl}.json" \ + --eval_nppl ${eval_nppl} > "$log_file" 2>&1 + done +done + + +# 等待所有后台任务完成 +wait
\ No newline at end of file diff --git a/kk_eval/kk_processor.py b/kk_eval/kk_processor.py new file mode 100644 index 0000000..626b265 --- /dev/null +++ b/kk_eval/kk_processor.py @@ -0,0 +1,208 @@ +import numpy as np +from kk_prompt import system_instruction, demonstration_2char, system_instruction_no_reason, demonstration_2char_no_reason + +from compute_score import extract_solution, parse_solution_text_format, parse_model_answer, validate_response_structure + +def num_tokens_from_string(string): + import tiktoken + """Returns the number of tokens in a text string.""" + encoding = tiktoken.encoding_for_model("gpt-3.5-turbo") + num_tokens = len(encoding.encode(string)) + return num_tokens + + +def parse_cot_eval(pred_str, ans, + conclusion_patterns=['CONCLUSION:'], + verbose=False, + finish_patterns=["### Reason", "Let's think step by step again", "let's go back and check", "###"], + reformat_gold_conditions=None): + + def judge_string(input_str, reformat_gold_conditions, wrong_reason, finish_patterns): + correct_count = 0 + is_correct = False + beyond_id = len(reformat_gold_conditions)+1 + beyond_id_pattern = f"({beyond_id})" + + for finish_pattern in finish_patterns: + if finish_pattern in input_str: + input_str = input_str.split(finish_pattern)[0] + + if beyond_id_pattern in input_str: + is_correct = False + wrong_reason = "beyond_list" + elif "if" in input_str: + is_correct = False + wrong_reason = "contain_if" + else: + is_correct = True + for gold_condition in reformat_gold_conditions: + if gold_condition not in input_str: + is_correct = False + wrong_reason = "wrong_identity" + else: + correct_count += 1 + correct_ratio = correct_count/len(reformat_gold_conditions) + + return is_correct, wrong_reason, correct_ratio + + def check_numbers_in_string(s, N): + for i in range(1, N + 1): + if f"({i})" not in s: + return False + return True + + original_str = pred_str + pred_str = pred_str.split("### Question")[0] + pred_answer = pred_str + is_correct = False + correct_ratio = 0 + if reformat_gold_conditions is None: + gold = ans.replace(" and ", "").replace(".", "") + gold_conditions = gold.split(",") + reformat_gold_conditions = [] + for condition in gold_conditions: + gold_condition = condition.strip() # Remove leading and trailing spaces + reformat_gold_conditions.append(gold_condition) + + wrong_reason = "no_conclusion_matched" + for pattern in conclusion_patterns: + pred = pred_str.split(pattern) + if len(pred) > 1: + if len(pred[1]) > 0: # if the matched the answer is not empty + pred_answer = pred[1] + is_correct, wrong_reason, correct_ratio = judge_string( + pred_answer, reformat_gold_conditions, wrong_reason, finish_patterns) + break + if is_correct == False and wrong_reason == "no_conclusion_matched": + if check_numbers_in_string(pred_str, len(reformat_gold_conditions)): # the answer contains (1)..(2).. + is_correct, wrong_reason, correct_ratio = judge_string( + pred_str, reformat_gold_conditions, wrong_reason, finish_patterns) + if is_correct == False and verbose == True: + print("wrong_reason:",wrong_reason) + print("********* \nprediction before parse:\n", original_str) + print("********* \nprediction after parse:\n", pred_answer) + + return is_correct, pred_answer, wrong_reason, correct_ratio, reformat_gold_conditions + + +def parse_cot_eval_instruct(pred_str, ans, + conclusion_patterns=['<answer>'], + verbose=False, + finish_patterns=["</answer>"], + reformat_gold_conditions=None, + expected_names=None, + solution_text_format=None): + print("\n" + "="*80) + print(" Processing New Sample ".center(80, '=')) + + # Parse ground truth data + gt_status = parse_solution_text_format(solution_text_format) + expected_names = list(gt_status.keys()) + print(f"[Ground Truth] Final identities: {gt_status}") + + # Extract model answer + answer_text, processed_str = extract_solution(pred_str) + print(f"\n[Model Response]\n{processed_str}") + + # Validate response structure + format_correct = validate_response_structure(processed_str) + print(f"\n Format validation: {'PASS' if format_correct else 'FAIL'}") + + # Validate answer content + answer_score = 0 + is_correct = False + correct_ratio = 0 + wrong_reason = "no_conclusion_matched" + if format_correct and answer_text: + pred_status = parse_model_answer(answer_text, expected_names) + if pred_status: + print(f"\n[Content Validation]") + print(f" Expected: {gt_status}") + print(f" Predicted: {pred_status}") + + if pred_status == gt_status: + answer_score = 2 + is_correct = True + correct_ratio = 1 + print(" Content validation: FULL MATCH") + else: + answer_score = -1.5 + correct_ratio = 0 + wrong_reason = "wrong_identity" + print(" Content validation: MISMATCH") + else: + answer_score = -2 + correct_ratio = 0 + wrong_reason = "no_conclusion_matched" + print( "Fail to parse answer") + else: + print("\n[Content Validation] Skipped due to format errors or missing answer") + + if is_correct == False and verbose == True: + print("wrong_reason:",wrong_reason) + print("********* \nprediction before parse:\n", pred_str) + print("********* \nprediction after parse:\n", answer_text) + + return is_correct, answer_text, wrong_reason, correct_ratio, reformat_gold_conditions + + +class KKProcessor: + def __init__(self, cot=True, no_linebreak=True): + self.cot = cot + self.no_linebreak = no_linebreak + + def format_example(self, test_records, idx, model_name=None): + + item = test_records[idx] + + prompt = "### Question: "+item["quiz"] + "\n" + if self.cot: + if model_name in ["deepseek-ai/deepseek-math-7b-instruct", "AI-MO/NuminaMath-7B-CoT"]: + prompt += "Please reason step by step, and put your final answer within \\boxed{}." + else: + prompt += "### Answer: Let's think step by step. " + else: + if self.no_linebreak: + prompt += "### Answer:" + else: + prompt += "### Answer:\n" + answer = item["solution_text"] + return prompt, answer + + def gen_test_prompt(self, ntrain, test_records, idx, model_name=None): + if self.cot: + train_prompt = system_instruction + else: + train_prompt = system_instruction_no_reason + + if ntrain == 1: + if self.cot: + train_prompt += "\n\n"+demonstration_2char + else: + train_prompt += "\n\n"+demonstration_2char_no_reason + elif ntrain > 1: + raise NotImplementedError + + prompt_end, answer = self.format_example(test_records, idx, model_name) + prompt = train_prompt + "\n\n" + prompt_end + + return prompt, answer + + def _parse_cot_eval(self, pred_str, ans, model_name=None): + conclusion_patterns = ['CONCLUSION:', 'Conclusion:', 'conclusion:'] + + if model_name in ["deepseek-ai/deepseek-math-7b-instruct", "AI-MO/NuminaMath-7B-CoT"]: + conclusion_patterns = ['boxed{', 'CONCLUSION:', 'Conclusion:', 'conclusion:'] + + is_correct, pred_answer, wrong_reason, correct_ratio, reformat_gold_conditions = parse_cot_eval( + pred_str, ans, conclusion_patterns=conclusion_patterns, verbose=False) + + return is_correct, pred_answer, reformat_gold_conditions + + def _parse_cot_eval_instruct(self, pred_str, ans, model_name=None, expected_names=None, solution_text_format=None): + conclusion_patterns = ['<answer>'] + + is_correct, pred_answer, wrong_reason, correct_ratio, reformat_gold_conditions = parse_cot_eval_instruct( + pred_str, ans, conclusion_patterns=conclusion_patterns, verbose=True, expected_names=expected_names, solution_text_format=solution_text_format) + + return is_correct, pred_answer, reformat_gold_conditions diff --git a/kk_eval/kk_prompt.py b/kk_eval/kk_prompt.py new file mode 100644 index 0000000..cb9e2d8 --- /dev/null +++ b/kk_eval/kk_prompt.py @@ -0,0 +1,38 @@ +system_instruction='''Your task is to solve a logical reasoning problem. You are given set of statements from which you must logically deduce the identity of a set of characters. + +You must infer the identity of each character. First, explain your reasoning. At the end of your answer, you must clearly state the identity of each character by following the format: + +CONCLUSION: +(1) ... +(2) ... +(3) ... +''' + + +system_instruction_no_reason='''Your task is to solve a logical reasoning problem. You are given set of statements from which you must logically deduce the identity of a set of characters. + +You must infer the identity of each character. At the end of your answer, you must clearly state the identity of each character by following the format: + +CONCLUSION: +(1) ... +(2) ... +(3) ... +''' + +demonstration_2char_no_reason='''### Question: A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 2 inhabitants: Jack, and Sophia. Jack tells you that Sophia is not a knave. Sophia says that If Jack is a knight then Sophia is a knight. So who is a knight and who is a knave? +### Answer: +CONCLUSION: +(1) Jack is a knight +(2) Sophia is a knight +''' + + + +demonstration_2char='''### Question: A very special island is inhabited only by knights and knaves. Knights always tell the truth, and knaves always lie. You meet 2 inhabitants: Ella, and Penelope. In a statement by Ella: \"Ella is a knight or Penelope is a knight\". According to Penelope, \"Ella is a knave if and only if Penelope is a knight\". So who is a knight and who is a knave? +### Answer: Let's think step by step, by considering whether each person is lying and if that leads to contradiction. Assume Ella is a knight. Penelope cannot be a knight, because this would contradict the claim of their own. Penelope cannot be a knave, because this would contradict the false claim of their own. We have exhausted all possibilities for Penelope, so let us go back and reconsider Ella. Assume Ella is a knave. Penelope cannot be a knight, because this would contradict the false claim of Ella. Assume Penelope is a knave. This leads to a feasible solution. +CONCLUSION: +(1) Ella is a knave +(2) Penelope is a knave +''' + + diff --git a/kk_eval/main_eval_instruct.py b/kk_eval/main_eval_instruct.py new file mode 100644 index 0000000..7eb1322 --- /dev/null +++ b/kk_eval/main_eval_instruct.py @@ -0,0 +1,209 @@ + +import argparse +import json +import os +import numpy as np +import random +import torch +import time +from vllm import LLM, SamplingParams +from kk_processor import KKProcessor +import datasets +from concurrent.futures import ThreadPoolExecutor + +def load_jsonl(file_path): + """Load data from a JSONL file.""" + if not os.path.exists(file_path): + return [] + with open(file_path, "r") as file: + return [json.loads(line) for line in file] + +def write_jsonl(output_file, data): + """Write data to a JSONL file.""" + with open(output_file, "w+") as file: + for item in data: + file.write(json.dumps(item) + "\n") + +def init_seed(seed=42): + """Initialize random seeds for reproducibility.""" + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.random.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + +def load_eval_records(args, subject): + """Load evaluation records based on arguments.""" + if args.problem_type != "clean": + records = datasets.load_dataset('K-and-K/perturbed-knights-and-knaves', + data_files=f"{args.split}/{args.problem_type}/{subject}.jsonl")["train"] + else: + records = datasets.load_dataset('K-and-K/knights-and-knaves', + data_files=f"{args.split}/{subject}.jsonl")["train"] + return records + +def eval_subject(args, subject, llm, test_records, kk_proc, exist_result_records): + """Evaluate one subject.""" + cors = [] + start_index = len(exist_result_records) + print(f"Processing {subject} starting from index {start_index}") + + # Prepare all prompts using multi-threading + prompts, labels = [], [] + new_prompts = [] + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(kk_proc.gen_test_prompt, args.ntrain, test_records, i, args.model) + for i in range(start_index, len(test_records)) + ] + for future in futures: + prompt, label = future.result() + prompts.append(prompt) + labels.append(label) + question = prompt.split("### Question:")[1].strip().split("### Answer:")[0].strip() + new_prompt = f"""<|im_start|>system\nYou are a helpful assistant. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and<answer> </answer> tags, respectively, i.e., <think> reasoning process here </think><answer> answer here </answer>. Now the user asks you to solve a logical reasoning problem. After thinking, when you finally reach a conclusion, clearly state the identity of each character within <answer> </answer> tags. i.e., <answer> (1) ...\n(2) ... </answer>.\n<|im_end|>\n<|im_start|>user\n{question}\n<|im_end|>\n<|im_start|>assistant\n<think>""" + new_prompts.append(new_prompt) + # Batch inference using vLLM + sampling_params = SamplingParams( + temperature=args.temperature, + top_p=args.top_p, + max_tokens=args.max_token, + seed=args.seed + ) + outputs = llm.generate(new_prompts, sampling_params) + responses = [output.outputs[0].text for output in outputs] + + # Process results and save to exist_result_records + for i, (new_prompt, label, response) in enumerate(zip(new_prompts, labels, responses), start=start_index): + cor, parsed_pred, reformat_gold_conditions = kk_proc._parse_cot_eval_instruct(response, label, args.model, test_records[i]['names'], test_records[i]['solution_text_format']) + cors.append(cor) + new_item = { + 'quiz': test_records[i]['quiz'], + 'names': test_records[i]['names'], + 'solution': test_records[i]['solution'], + 'solution_text': test_records[i]['solution_text'], + 'solution_text_format': test_records[i]['solution_text_format'], + 'index': test_records[i]['index'], + 'predicts': parsed_pred, + 'labels': reformat_gold_conditions, + 'correct': cor, + 'response': response, + 'prompts': new_prompt, + } + exist_result_records.append(new_item) + + acc = np.mean(cors) + print(f"Subject: {subject}, Accuracy: {acc:.3f}") + return cors, acc, exist_result_records + +def load_limited_test_records(args, subject, exist_result_records): + """Load limited test records based on given arguments.""" + test_records = load_eval_records(args, subject) + print("test_records:", test_records) + if args.limit is not None: + test_records = test_records.select(range(min(args.limit, len(test_records)))) + if args.limit <= len(exist_result_records): + return None # Already processed all samples + return test_records + +def save_final_acc_results(all_cors, results, fname): + """Save final accuracy results to a file.""" + if all_cors: + weighted_acc = np.mean(np.concatenate(all_cors)) + results["weighted_accuracy"] = weighted_acc + print(f"Overall Average Accuracy: {weighted_acc:.3f}") + with open(fname, "w") as f: + json.dump(results, f) + +def load_previous_acc_results(fname): + """Load previous accuracy results.""" + acc_results = {"subject": {}} + if os.path.isfile(fname): + with open(fname, 'r', encoding='utf-8') as file: + acc_results = json.load(file) + return acc_results + +def get_subjects_to_eval(args): + """Get subjects to evaluate.""" + subjects = [] + if args.split == "test": + if args.eval_nppl == 0: + subjects = [f"people{nppl}_num100" for nppl in range(2, 9)] + else: + subjects = [f"people{args.eval_nppl}_num100"] + elif args.split == "train": + if args.eval_nppl == 2: + subjects = ["people2_num200"] + elif args.eval_nppl > 2: + subjects = [f"people{args.eval_nppl}_num1000"] + return subjects + +def main(args): + model_short_name = "/".join(args.model.split("/")[-2:]) + prefix = os.path.join(args.save_dir, f"{model_short_name}_{args.ntrain}shot") + args.config += f"_token{args.max_token}{('_cot' if args.cot else '')}" \ + f"_{args.split}{('_' + args.problem_type if args.problem_type != 'clean' else '')}" + output_folder = os.path.join(prefix, args.config) + acc_fname = args.output_file # os.path.join(prefix, f"result_{args.config}.json") + os.makedirs(output_folder, exist_ok=True) + + print("Configuration:", args.config) + print("Output Folder:", output_folder) + + kk_proc = KKProcessor(cot=args.cot, no_linebreak=args.no_linebreak) + subjects = get_subjects_to_eval(args) + acc_results = load_previous_acc_results(acc_fname) + + # Initialize vLLM model + llm = LLM( + model=args.model, + tensor_parallel_size=args.ngpus, + max_model_len=args.max_token, + gpu_memory_utilization=0.85, + enforce_eager=False + ) + + all_cors = [] + for subject in subjects: + result_outfile = os.path.join(output_folder, f"{subject}.jsonl") + exist_result_records = load_jsonl(result_outfile) + test_records = load_limited_test_records(args, subject, exist_result_records) + print("exist_result_records:", exist_result_records) + print("test_records:", test_records) + if test_records is None: + continue + + cors, acc, result_records = eval_subject(args, subject, llm, test_records, kk_proc, exist_result_records) + write_jsonl(result_outfile, result_records) + all_cors.append(cors) + acc_results["subject"][subject] = acc + + save_final_acc_results(all_cors, acc_results, acc_fname) + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Evaluation script for KK dataset") + parser.add_argument("--ntrain", "-k", type=int, default=0, help="Number of training examples") + parser.add_argument("--data_dir", "-d", type=str, default="data", help="Data directory") + parser.add_argument("--save_dir", "-s", type=str, default="result_qa", help="Save directory") + parser.add_argument("--model", "-m", type=str, required=True, help="Model name or path") + parser.add_argument("--max_token", type=int, default=1024, help="Maximum number of tokens") + parser.add_argument("--limit", type=int, default=None, help="Limit the number of examples") + parser.add_argument("--cot", action="store_true", help="Use chain-of-thought prompting") + parser.add_argument("--no_linebreak", action="store_true", help="Remove line breaks") + parser.add_argument("--batch_size", type=int, default=8, help="Batch size for VLLM") + parser.add_argument("--split", type=str, default="test", choices=["test", "train"], help="Data split to use") + parser.add_argument("--eval_nppl", type=int, default=0, help="Number of people to evaluate") + parser.add_argument("--problem_type", type=str, default="clean", help="Problem perturbation type") + parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling") + parser.add_argument("--top_p", type=float, default=1.0, help="Top-p (nucleus) sampling") + parser.add_argument("--config", type=str, default="default_config", help="Configuration string for saving results") + parser.add_argument("--ngpus", type=int, default=1, help="Number of GPUs for parallel inference") + parser.add_argument("--output_file", type=str, default="output_file", help="Output File Path") + parser.add_argument("--seed", type=int, default=42, help="Seed") + args = parser.parse_args() + + init_seed(seed=args.seed) + main(args) + +
\ No newline at end of file |
