path: root/code_eval/OpenCodeEval/benchmark/LeetCode.py
blob: 97b3489b0a7e28f82812ae8bc7bf838db7c1aaaf (plain)
import os
from typing import Literal
from loguru import logger

from OpenCodeEval.benchmark.base import Benchmark, PYTHON_IMPORTS, LEETCODE_IMPORTS, PYTHON_STOP
from OpenCodeEval.utils import refine_text, stream_jsonl
from OpenCodeEval.eval.func_eval import check_correctness
from OpenCodeEval.eval.sanitize import sanitize


class LeetCode(Benchmark):

    name: str = "LeetCode"

    # Imports prepended to every candidate solution before execution.
    imports_code = PYTHON_IMPORTS + LEETCODE_IMPORTS
    # Stop sequences for chat- and base-style generation, respectively.
    chat_stop = PYTHON_STOP
    base_stop = ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]

    def __init__(
        self,
        split: Literal["contest", "train", "validation", "test"] = "contest",
        time_out: float = 3.0,
        prompt_type: Literal["Completion", "Instruction"] = "Instruction"
    ):

        super().__init__()
        
        # self.name is provided by the class attribute ("LeetCode")
        self.split = split
        self.time_out = time_out

        self.prompt_type = prompt_type
        if self.split != "contest" and self.prompt_type == "Completion":
            logger.error(f"Completion prompt type not support {self.split} split")

        self.path = os.path.join(self.path, f"{self.name}/{self.split}.jsonl")
        self.tasks = self.get_task()

    def get_task(self):
        """
        Load the task data from the JSONL file into a dictionary keyed by task id.
        """

        tasks = {}
        
        for task_data in stream_jsonl(filename=self.path):

            if self.split == "contest":
                task_id = int(task_data["meta"]["questionId"])
            else:
                task_id = int(task_data["meta"]["question_id"])
            tasks[task_id] = task_data
        
        return tasks
        
    def get_prompt(self):
        """
        Build one prompt per task for the LM to generate from.
        """

        prompts = []
        for task_id, task_data in self.tasks.items():

            if self.split == "contest":
                if self.prompt_type == "Completion":
                    prompt = task_data['prompt']
                elif self.prompt_type == "Instruction":
                    prompt = task_data['prompt_sft']
            else:
                prompt = task_data['meta']['query']

            prompts.append(
                dict(
                    task_id = task_id,
                    prompt = refine_text(prompt)
                )
            )

        return prompts

    def postprocess_generation(self, generation):
        """
        Post-process a single generation by sanitizing the completion around the Solution entrypoint.
        """

        return dict(
            task_id = generation['task_id'],
            completion_id = generation['completion_id'],
            solution = sanitize(
                text = generation['completion'],
                entrypoint = "Solution",
            )
        )
    
    def process_results(self, solution):
        """
        Evaluate a single post-processed solution against the task's test cases.
        """

        task_data = self.tasks[solution['task_id']]

        if self.split == "contest":
            # Contest tasks ship a self-contained test snippet that runs right
            # after the shared imports and the generated solution.
            code = (
                "\n".join(self.imports_code) + "\n\n"
                + solution['solution'] + "\n\n"
                + task_data['test']
            )
        else:
            # Other splits provide a starter signature in meta['lang_code'];
            # close the stub with a `pass` body so it parses, append the
            # generated solution (expected to define a `Solution` class),
            # then the test code and the final check on the entry point.
            code = (
                "\n".join(self.imports_code) + "\n\n"
                + task_data['meta']['lang_code'] + "\n"
                + "        pass\n" + "\n"
                + solution['solution'] + "\n"
                + task_data['test'] + "\n"
                + f"check({task_data['entry_point']})"
            )
        
        result = check_correctness(solution['task_id'],
                                   solution['completion_id'],
                                   code,
                                   self.time_out)
        
        return result
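

if __name__ == "__main__":
    # Minimal usage sketch, not part of the benchmark itself. It assumes the
    # LeetCode data file exists under the benchmark path and replaces the real
    # model call with a trivial placeholder completion.
    bench = LeetCode(split="contest", prompt_type="Instruction")

    def generate(prompt: str) -> str:
        # Placeholder for an actual model call.
        return "class Solution:\n    pass\n"

    results = []
    for prompt in bench.get_prompt():
        generation = dict(
            task_id=prompt["task_id"],
            completion_id=0,
            completion=generate(prompt["prompt"]),
        )
        solution = bench.postprocess_generation(generation)
        results.append(bench.process_results(solution))

    logger.info(f"Evaluated {len(results)} generations")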