| author | = <=> | 2025-06-04 11:49:37 +0800 |
| committer | = <=> | 2025-06-04 11:49:37 +0800 |
| commit | 947d9dfdf16ae37109898111a5caacae7377b96d (patch) | |
| tree | ff4e884020fb7d968a6192106f370b215647f569 /code_eval/OpenCodeEval/benchmark/Bird.py | |
| parent | 5e163b529a78d528b745b8b57ba794b7b2bba97a (diff) | |
update code and kk eval
Diffstat (limited to 'code_eval/OpenCodeEval/benchmark/Bird.py')
| -rw-r--r-- | code_eval/OpenCodeEval/benchmark/Bird.py | 123 |
1 file changed, 123 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/benchmark/Bird.py b/code_eval/OpenCodeEval/benchmark/Bird.py
new file mode 100644
index 0000000..b4359fb
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/Bird.py
@@ -0,0 +1,123 @@
+import os
+import sys
+
+from loguru import logger
+from typing import Literal
+
+from OpenCodeEval.benchmark.base import Benchmark
+from OpenCodeEval.utils import refine_text, program_extract, stream_jsonl
+from OpenCodeEval.eval.sql_test import check_correctness
+
+class Bird(Benchmark):
+
+    name: str = "Bird"
+
+    def __init__(
+        self,
+        split: Literal["train", "dev"] = "dev",
+        time_out: float = 30.0,
+        prompt_type: str = "Instruction"
+    ):
+
+        super().__init__()
+
+        self.split = split
+        self.time_out = time_out
+        self.prompt_type = prompt_type
+
+        if self.prompt_type == "Completion":
+            logger.error("Completion prompt type not supported for Bird")
+
+        self.database = os.path.join(self.path, f"{self.name}/{self.split}/database")
+        self.path = os.path.join(self.path, f"{self.name}/{self.split}/data.jsonl")
+
+        self.tasks = self.get_task()
+
+    def get_task(self):
+        """
+        Get the task data from the json file into a dictionary.
+        """
+
+        tasks = {}
+
+        for task_data in stream_jsonl(filename = self.path):
+
+            tasks[int(task_data['id'])] = task_data
+
+        return tasks
+
+    def get_prompt(self):
+        """
+        Builds the prompt for the LM to generate from.
+        """
+
+        def construct_prompt(data):
+            instruction = data['o_schema']
+            instruction += f"\n\n-- External Knowledge: {data['evidence']}\n\n"
+            instruction += "-- Using valid SQLite and understanding External Knowledge, answer the following questions for the tables provided above.\n\n"
+            instruction += f"Question: {data['question']}\n"
+            return instruction
+
+        prompts = []
+
+
+        for task_id, task_data in self.tasks.items():
+
+            prompt = construct_prompt(task_data)
+
+            prompts.append(
+                dict(
+                    task_id = task_id,
+                    prompt = refine_text(prompt)
+                )
+            )
+        return prompts
+
+    def postprocess_generation(self, generation):
+        """
+        Postprocess the generations.
+        """
+
+        def one_line_sql(response):
+
+            response = program_extract(response, "sql", "last").strip()
+
+            lines = response.splitlines()
+            lines = [l.strip() for l in lines if l.strip()]
+            sql = " ".join(lines)
+
+            return sql
+
+        result = dict(
+            task_id = generation['task_id'],
+            completion_id = generation['completion_id'],
+            solution = one_line_sql(generation['completion'])
+        )
+
+        return result
+
+    def process_results(self, solution):
+        """
+        Takes the list of LM generations and evaluates them against the test cases
+        """
+
+        task_data = self.tasks[solution['task_id']]
+
+        db_path = self.database + f"/{task_data['db_id']}/{task_data['db_id']}.sqlite"
+
+        result, passed, sql_return = check_correctness(
+            solution['solution'],
+            task_data['sql'],
+            db_path,
+            self.time_out,
+            "set_match"
+        )
+
+        return dict(
+            task_id = solution['task_id'],
+            completion_id = solution['completion_id'],
+            passed = passed,
+            result = result,
+            solution = solution['solution'],
+            sql_return = sql_return
+        )
\ No newline at end of file
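
For reviewers, a minimal sketch of how this benchmark class is driven end to end. The `Bird` class, its method names, and the dict shapes mirror the file above; the `generate()` stub, the `completion_id` of 0, and the tally at the end are illustrative assumptions, not part of this commit (it also assumes `program_extract(..., "sql", "last")` pulls the last fenced sql block, as the call in `one_line_sql` suggests):

```python
# Hypothetical driver for the Bird benchmark added in this commit.
# Assumptions: generate() stands in for a real model call, and the
# OpenCodeEval package with its data files is importable from here;
# everything else mirrors the class above.

from OpenCodeEval.benchmark.Bird import Bird

FENCE = "`" * 3  # avoids nesting literal backtick fences in this example

def generate(prompt: str) -> str:
    # Stand-in for an LM call; real output would contain a fenced sql
    # block, which postprocess_generation extracts and flattens.
    return f"{FENCE}sql\nSELECT 1;\n{FENCE}"

bench = Bird(split="dev", time_out=30.0, prompt_type="Instruction")

results = []
for item in bench.get_prompt():  # [{'task_id': ..., 'prompt': ...}, ...]
    completion = generate(item["prompt"])
    solution = bench.postprocess_generation(
        dict(task_id=item["task_id"], completion_id=0, completion=completion)
    )  # -> {'task_id', 'completion_id', 'solution': one-line SQL}
    results.append(bench.process_results(solution))  # runs vs. gold SQL

passed = sum(r["passed"] for r in results)
print(f"{passed}/{len(results)} passed")
```

Each prompt is the database schema (`o_schema`), an `-- External Knowledge:` hint, a fixed SQLite instruction line, and the question; the `"set_match"` argument to `check_correctness` presumably selects order-insensitive comparison of the predicted and gold queries' result sets, the usual execution-accuracy metric for BIRD.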
