summaryrefslogtreecommitdiff
path: root/code_eval/OpenCodeEval/main.py
blob: d78ddec7ec69bd3e52f7f658aa97e54e62d496b3 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import argparse

from OpenCodeEval.args import get_args, check_args
from OpenCodeEval.utils import refine_text, write_jsonl, stream_jsonl, calculate_pass_at_k

from OpenCodeEval.factory import BenchmarkFactory, BackendFactory

from tqdm.contrib.concurrent import process_map, thread_map

def main():

    parser = argparse.ArgumentParser()
    args = get_args(parser)
    args = check_args(args)

    save_path = args.save_path
    os.makedirs(save_path, exist_ok = True)

    task = BenchmarkFactory.get_task(args)

    # get prompts
    prompts = task.get_prompt()

    for prompt in prompts:
        prompt['prompt'] = refine_text(args.prompt_prefix + prompt['prompt'] + args.prompt_suffix)
    prompts = sorted(prompts, key = lambda x: x['task_id'])
    write_jsonl(os.path.join(save_path, "prompts.jsonl"), prompts)

    if args.split  ==  'plus':
        base_path = save_path.replace('plus', 'base')
        if os.path.exists(os.path.join(base_path, "generations.jsonl")):
            generations = list(stream_jsonl(os.path.join(base_path, "generations.jsonl")))
            write_jsonl(os.path.join(save_path, "generations.jsonl"), generations)

    # check if generations exits
    if not os.path.exists(os.path.join(save_path, "generations.jsonl")):
        
        prompts = list(stream_jsonl(os.path.join(save_path, "prompts.jsonl")))
        # get generations
        decoder = BackendFactory.get_backend(args)
        if decoder.is_chat():
            decoder.set_stop(task.chat_stop)
        else:
            decoder.set_stop(task.chat_stop + task.base_stop)

        generations = decoder.generate(
            prompts,
            args.response_prefix,
            args.response_suffix
        )

        generations = sorted([data for data in generations if data['completion']], key = lambda x: (x['task_id'], x['completion_id']))
        write_jsonl(os.path.join(save_path, "generations.jsonl"), generations)

    else:
        
        generated_ids = [data['task_id'] for data in stream_jsonl(os.path.join(save_path, "generations.jsonl"))]
        prompts = [data for data in stream_jsonl(os.path.join(save_path, "prompts.jsonl")) if data['task_id'] not in generated_ids]
        if len(prompts) > 0:

            # get generations
            decoder = BackendFactory.get_backend(args)
            if decoder.is_chat():
                decoder.set_stop(task.chat_stop)
            else:
                decoder.set_stop(task.chat_stop + task.base_stop)

            continue_generations = decoder.generate(
                prompts,
                args.response_prefix,
                args.response_suffix
            )

            generations = sorted([data for data in continue_generations if data['completion']] + list(stream_jsonl(os.path.join(save_path, "generations.jsonl"))), key = lambda x: (x['task_id'], x['completion_id']))
            write_jsonl(os.path.join(save_path, "generations.jsonl"), generations)


    # post-process generations
    # if not os.path.exists(os.path.join(save_path, "solutions.jsonl")):
    if True:

        generations = list(stream_jsonl(os.path.join(save_path, "generations.jsonl")))
        solutions = thread_map(
            task.postprocess_generation,
            generations,
            max_workers = args.num_workers,
            desc = "Post-processing Generations"
        )
        solutions = sorted(solutions, key = lambda x: (x['task_id'], x['completion_id']))
        write_jsonl(os.path.join(save_path, "solutions.jsonl"), solutions)

    # evaluate solutions
    # if not os.path.exists(os.path.join(save_path, "evaluations.jsonl")):
    if True:
        solutions = list(stream_jsonl(os.path.join(save_path, "solutions.jsonl")))
        evaluations = thread_map(
            task.process_results,
            solutions,
            max_workers = args.num_workers,
            desc = "Evaluating Solutions"
        )
        evaluations = sorted(evaluations, key = lambda x: (x['task_id'], x['completion_id']))
        write_jsonl(os.path.join(save_path, "evaluations.jsonl"), evaluations)

    # calculate pass@k
    # if not os.path.exists(os.path.join(save_path, "results.jsonl")):
    if True:
        evaluations = list(stream_jsonl(os.path.join(save_path, "evaluations.jsonl")))
        results = calculate_pass_at_k(evaluations, args.num_samples, args.list_k)
        write_jsonl(os.path.join(save_path, "results.jsonl"), results)


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()