# kk_eval/main_eval_instruct.py
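"""Batched evaluation of instruction-tuned chat models on Knights-and-Knaves
(K&K) logic puzzles.

For each subject (puzzle size), this script builds ChatML-style prompts that
request <think> ... </think><answer> ... </answer> formatted output, runs
batched inference with vLLM, parses predictions with KKProcessor, and writes
per-sample results (JSONL) plus aggregate accuracy (JSON).
"""
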
import argparse
import json
import os
import numpy as np
import random
import torch
import time
from vllm import LLM, SamplingParams
from kk_processor import KKProcessor
import datasets
from concurrent.futures import ThreadPoolExecutor

def load_jsonl(file_path):
    """Load data from a JSONL file."""
    if not os.path.exists(file_path):
        return []
    with open(file_path, "r") as file:
        return [json.loads(line) for line in file]

def write_jsonl(output_file, data):
    """Write data to a JSONL file."""
    with open(output_file, "w+") as file:
        for item in data:
            file.write(json.dumps(item) + "\n")

def init_seed(seed=42):
    """Initialize random seeds for reproducibility."""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)  # also seeds torch.random; a separate call is redundant
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)  # seed every visible GPU, not just the current one

def load_eval_records(args, subject):
    """Load a subject's records from the Hugging Face hub (clean or perturbed variant)."""
    if args.problem_type != "clean":
        records = datasets.load_dataset('K-and-K/perturbed-knights-and-knaves', 
                                        data_files=f"{args.split}/{args.problem_type}/{subject}.jsonl")["train"]
    else:
        records = datasets.load_dataset('K-and-K/knights-and-knaves', 
                                        data_files=f"{args.split}/{subject}.jsonl")["train"]
    return records

def eval_subject(args, subject, llm, test_records, kk_proc, exist_result_records):
    """Evaluate one subject."""
    cors = []
    start_index = len(exist_result_records)
    print(f"Processing {subject} starting from index {start_index}")

    # Prepare all prompts using multi-threading
    prompts, labels = [], []
    new_prompts = []
    with ThreadPoolExecutor() as executor:
        futures = [
            executor.submit(kk_proc.gen_test_prompt, args.ntrain, test_records, i, args.model)
            for i in range(start_index, len(test_records))
        ]
        for future in futures:
            prompt, label = future.result()
            prompts.append(prompt)
            labels.append(label)
            # Extract the bare question from the few-shot prompt template.
            question = prompt.split("### Question:")[1].strip().split("### Answer:")[0].strip()
            # Re-wrap the question in a ChatML-style template that asks for
            # <think> ... </think><answer> ... </answer> formatted output.
            new_prompt = (
                "<|im_start|>system\nYou are a helpful assistant. The assistant first thinks "
                "about the reasoning process in the mind and then provides the user with the "
                "answer. The reasoning process and answer are enclosed within <think> </think> "
                "and <answer> </answer> tags, respectively, i.e., <think> reasoning process "
                "here </think><answer> answer here </answer>. Now the user asks you to solve a "
                "logical reasoning problem. After thinking, when you finally reach a "
                "conclusion, clearly state the identity of each character within "
                "<answer> </answer> tags, i.e., <answer> (1) ...\n(2) ... </answer>.\n"
                "<|im_end|>\n"
                f"<|im_start|>user\n{question}\n<|im_end|>\n"
                "<|im_start|>assistant\n<think>"
            )
            new_prompts.append(new_prompt)
    # Batch inference using vLLM
    sampling_params = SamplingParams(
        temperature=args.temperature,
        top_p=args.top_p,
        max_tokens=args.max_token,
        seed=args.seed
    )
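    # Fixing the seed in SamplingParams makes sampled decoding reproducible
    # across runs for the same prompt batch.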
    outputs = llm.generate(new_prompts, sampling_params)
    responses = [output.outputs[0].text for output in outputs]

    # Process results and save to exist_result_records
    for i, (new_prompt, label, response) in enumerate(zip(new_prompts, labels, responses), start=start_index):
        cor, parsed_pred, reformat_gold_conditions = kk_proc._parse_cot_eval_instruct(
            response, label, args.model,
            test_records[i]['names'], test_records[i]['solution_text_format'])
        cors.append(cor)
        new_item = {
            'quiz': test_records[i]['quiz'],
            'names': test_records[i]['names'],
            'solution': test_records[i]['solution'],
            'solution_text': test_records[i]['solution_text'],
            'solution_text_format': test_records[i]['solution_text_format'],
            'index': test_records[i]['index'],
            'predicts': parsed_pred,
            'labels': reformat_gold_conditions,
            'correct': cor,
            'response': response,
            'prompts': new_prompt,
        }
        exist_result_records.append(new_item)

    acc = float(np.mean(cors))  # cast from np.float64 so the value is JSON-serializable
    print(f"Subject: {subject}, Accuracy: {acc:.3f}")
    return cors, acc, exist_result_records

def load_limited_test_records(args, subject, exist_result_records):
    """Load test records, truncated to --limit; return None if nothing is left to do."""
    test_records = load_eval_records(args, subject)
    if args.limit is not None:
        test_records = test_records.select(range(min(args.limit, len(test_records))))
    if len(exist_result_records) >= len(test_records):
        return None  # All requested samples were already processed
    return test_records

def save_final_acc_results(all_cors, results, fname):
    """Save final accuracy results to a file."""
    if all_cors:
        # Micro-average over all evaluated samples across subjects; cast to
        # float so json.dump can serialize the value
        weighted_acc = float(np.mean(np.concatenate(all_cors)))
        results["weighted_accuracy"] = weighted_acc
        print(f"Overall Average Accuracy: {weighted_acc:.3f}")
        with open(fname, "w") as f:
            json.dump(results, f)

def load_previous_acc_results(fname):
    """Load previous accuracy results."""
    acc_results = {"subject": {}}
    if os.path.isfile(fname):
        with open(fname, 'r', encoding='utf-8') as file:
            acc_results = json.load(file)
    return acc_results

def get_subjects_to_eval(args):
    """Get subjects to evaluate."""
    subjects = []
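    # Subject names mirror the dataset's file naming, e.g. "people3_num100"
    # means 3-character puzzles with 100 samples.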
    if args.split == "test":
        if args.eval_nppl == 0:
            subjects = [f"people{nppl}_num100" for nppl in range(2, 9)]
        else:
            subjects = [f"people{args.eval_nppl}_num100"]
    elif args.split == "train":
        if args.eval_nppl == 2:
            subjects = ["people2_num200"]
        elif args.eval_nppl > 2:
            subjects = [f"people{args.eval_nppl}_num1000"]
    return subjects

def main(args):
    model_short_name = "/".join(args.model.split("/")[-2:])
    prefix = os.path.join(args.save_dir, f"{model_short_name}_{args.ntrain}shot")
    args.config += f"_token{args.max_token}{('_cot' if args.cot else '')}" \
                   f"_{args.split}{('_' + args.problem_type if args.problem_type != 'clean' else '')}"
    output_folder = os.path.join(prefix, args.config)
    acc_fname = args.output_file
    os.makedirs(output_folder, exist_ok=True)

    print("Configuration:", args.config)
    print("Output Folder:", output_folder)

    kk_proc = KKProcessor(cot=args.cot, no_linebreak=args.no_linebreak)
    subjects = get_subjects_to_eval(args)
    acc_results = load_previous_acc_results(acc_fname)

    # Initialize the vLLM engine. max_model_len caps prompt + generated tokens,
    # so --max_token must leave room for the prompt itself.
    llm = LLM(
        model=args.model,
        tensor_parallel_size=args.ngpus,
        max_model_len=args.max_token,
        gpu_memory_utilization=0.85,
        enforce_eager=False
    )

    all_cors = []
    for subject in subjects:
        result_outfile = os.path.join(output_folder, f"{subject}.jsonl")
        exist_result_records = load_jsonl(result_outfile)
        test_records = load_limited_test_records(args, subject, exist_result_records)
        if test_records is None:
            print(f"Skipping {subject}: all requested samples already evaluated")
            continue
        print(f"{subject}: resuming with {len(exist_result_records)} existing results "
              f"out of {len(test_records)} records")

        cors, acc, result_records = eval_subject(args, subject, llm, test_records, kk_proc, exist_result_records)
        write_jsonl(result_outfile, result_records)
        all_cors.append(cors)
        acc_results["subject"][subject] = acc

    save_final_acc_results(all_cors, acc_results, acc_fname)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Evaluation script for KK dataset")
    parser.add_argument("--ntrain", "-k", type=int, default=0, help="Number of training examples")
    parser.add_argument("--data_dir", "-d", type=str, default="data", help="Data directory")
    parser.add_argument("--save_dir", "-s", type=str, default="result_qa", help="Save directory")
    parser.add_argument("--model", "-m", type=str, required=True, help="Model name or path")
    parser.add_argument("--max_token", type=int, default=1024, help="Maximum number of tokens")
    parser.add_argument("--limit", type=int, default=None, help="Limit the number of examples")
    parser.add_argument("--cot", action="store_true", help="Use chain-of-thought prompting")
    parser.add_argument("--no_linebreak", action="store_true", help="Remove line breaks")
    parser.add_argument("--batch_size", type=int, default=8, help="Batch size for VLLM")
    parser.add_argument("--split", type=str, default="test", choices=["test", "train"], help="Data split to use")
    parser.add_argument("--eval_nppl", type=int, default=0, help="Number of people to evaluate")
    parser.add_argument("--problem_type", type=str, default="clean", help="Problem perturbation type")
    parser.add_argument("--temperature", type=float, default=1.0, help="Temperature for sampling")
    parser.add_argument("--top_p", type=float, default=1.0, help="Top-p (nucleus) sampling")
    parser.add_argument("--config", type=str, default="default_config", help="Configuration string for saving results")
    parser.add_argument("--ngpus", type=int, default=1, help="Number of GPUs for parallel inference")
    parser.add_argument("--output_file", type=str, default="output_file", help="Output File Path")
    parser.add_argument("--seed", type=int, default=42, help="Seed")
    args = parser.parse_args()

    init_seed(seed=args.seed)
    main(args)
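
# Example invocation (model name and output path are illustrative):
#   python main_eval_instruct.py --model Qwen/Qwen2.5-7B-Instruct \
#       --split test --eval_nppl 2 --max_token 2048 --ngpus 1 \
#       --output_file result_qa/acc_2ppl.json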