summaryrefslogtreecommitdiff
path: root/code_eval/OpenCodeEval/benchmark/Bird.py
blob: b4359fbb645c5d6742bfccfcd440d296f5fbc05a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
import os
import sys

from loguru import logger
from typing import Literal

from OpenCodeEval.benchmark.base import Benchmark
from OpenCodeEval.utils import refine_text, program_extract, stream_jsonl
from OpenCodeEval.eval.sql_test import check_correctness

class Bird(Benchmark):
    """
    BIRD text-to-SQL benchmark.

    Tasks are read from ``<root>/Bird/<split>/data.jsonl`` and each candidate
    query is executed against
    ``<root>/Bird/<split>/database/<db_id>/<db_id>.sqlite``, where ``<root>``
    is the ``path`` attribute inherited from ``Benchmark`` (presumably the
    benchmark data root — confirm in the base class).
    """

    name: str = "Bird"

    def __init__(
        self,
        split: Literal["train", "dev"] = "dev",
        time_out: float = 30.0,
        prompt_type: str = "Instruction"
    ):
        """
        Args:
            split: Dataset split to load ("train" or "dev").
            time_out: Timeout value forwarded to ``check_correctness`` for
                every evaluated query.
            prompt_type: Prompt style. Only instruction-style prompts are
                built; "Completion" is reported as unsupported.
        """
        super().__init__()

        self.split = split
        self.time_out = time_out
        self.prompt_type = prompt_type

        if self.prompt_type == "Completion":
            # Only logged, not raised: prompts are still built in the
            # instruction format by get_prompt().
            logger.error("Completion prompt type not supported for Bird")

        # ``self.path`` holds the root inherited from the base class and is
        # overwritten below with the data-file path, so the database
        # directory must be derived from the root first.
        root = self.path
        self.database = os.path.join(root, self.name, self.split, "database")
        self.path = os.path.join(root, self.name, self.split, "data.jsonl")

        self.tasks = self.get_task()

    def get_task(self):
        """
        Load the tasks from the JSONL file at ``self.path``.

        Returns:
            dict mapping each record's ``id`` (converted to ``int``) to the
            raw task record. A later record with a duplicate id replaces an
            earlier one.
        """
        return {
            int(task_data['id']): task_data
            for task_data in stream_jsonl(filename = self.path)
        }

    def get_prompt(self):
        """
        Build one instruction prompt per task.

        Each prompt is the task's ``o_schema`` followed by its external
        knowledge (``evidence``) and its ``question``, passed through
        ``refine_text``.

        Returns:
            list of ``dict(task_id=..., prompt=...)`` in task-loading order.
        """

        def construct_prompt(data):
            instruction = data['o_schema']
            instruction += f"\n\n-- External Knowledge: {data['evidence']}\n\n"
            instruction += "-- Using valid SQLite and understanding External Knowledge, answer the following questions for the tables provided above.\n\n"
            instruction += f"Question: {data['question']}\n"
            return instruction

        return [
            dict(
                task_id = task_id,
                prompt = refine_text(construct_prompt(task_data))
            )
            for task_id, task_data in self.tasks.items()
        ]

    @staticmethod
    def _strip_sql_line_comments(sql):
        """
        Remove ``--`` line comments from ``sql`` while keeping line breaks.

        Quoted strings/identifiers ('...', "...", `...`, [...]) and
        ``/* ... */`` block comments are copied verbatim, so a ``--`` inside
        them is not treated as a comment start. Input without real ``--``
        comments is returned unchanged.
        """
        closers = {"'": "'", '"': '"', "`": "`", "[": "]"}
        pieces = []
        i, n = 0, len(sql)
        while i < n:
            ch = sql[i]
            if ch in closers:
                # Copy through the closing delimiter; a doubled quote
                # ('it''s') just closes and reopens, which is equivalent.
                end = sql.find(closers[ch], i + 1)
                end = n if end == -1 else end + 1
                pieces.append(sql[i:end])
                i = end
            elif sql.startswith("/*", i):
                end = sql.find("*/", i + 2)
                end = n if end == -1 else end + 2
                pieces.append(sql[i:end])
                i = end
            elif sql.startswith("--", i):
                # Drop the comment text, keep the newline that ends it.
                end = sql.find("\n", i)
                i = n if end == -1 else end
            else:
                pieces.append(ch)
                i += 1
        return "".join(pieces)

    def postprocess_generation(self, generation):
        """
        Extract the last ```sql``` program from a completion and flatten it
        to a single line.

        Args:
            generation: dict with ``task_id``, ``completion_id`` and
                ``completion`` (the raw model output).

        Returns:
            dict with ``task_id``, ``completion_id`` and ``solution`` (the
            one-line SQL).
        """

        def one_line_sql(response):
            sql = program_extract(response, "sql", "last").strip()
            # Joining lines with spaces would let a "--" comment swallow
            # everything after it, so line comments are removed first.
            sql = self._strip_sql_line_comments(sql)
            lines = [line.strip() for line in sql.splitlines() if line.strip()]
            return " ".join(lines)

        return dict(
            task_id = generation['task_id'],
            completion_id = generation['completion_id'],
            solution = one_line_sql(generation['completion'])
        )

    def process_results(self, solution):
        """
        Execute a post-processed solution and compare it with the task's
        reference SQL.

        Args:
            solution: dict with ``task_id``, ``completion_id`` and
                ``solution``, as produced by ``postprocess_generation``.

        Returns:
            dict with ``task_id``, ``completion_id``, ``passed``, ``result``,
            ``solution`` and ``sql_return``; the ``passed``/``result``/
            ``sql_return`` values are those returned by ``check_correctness``.
        """
        task_data = self.tasks[solution['task_id']]

        db_id = task_data['db_id']
        db_path = os.path.join(self.database, db_id, f"{db_id}.sqlite")

        # "set_match" is the comparison mode handed to check_correctness
        # (presumably an order-insensitive row comparison — confirm in
        # OpenCodeEval.eval.sql_test).
        result, passed, sql_return = check_correctness(
            solution['solution'],
            task_data['sql'],
            db_path,
            self.time_out,
            "set_match"
        )

        return dict(
            task_id = solution['task_id'],
            completion_id = solution['completion_id'],
            passed = passed,
            result = result,
            solution = solution['solution'],
            sql_return = sql_return
        )