From 947d9dfdf16ae37109898111a5caacae7377b96d Mon Sep 17 00:00:00 2001
From: = <=>
Date: Wed, 4 Jun 2025 11:49:37 +0800
Subject: update code and kk eval

---
 code_eval/OpenCodeEval/benchmark/LeetCode.py | 121 +++++++++++++++++++++++++++
 1 file changed, 121 insertions(+)
 create mode 100644 code_eval/OpenCodeEval/benchmark/LeetCode.py
(limited to 'code_eval/OpenCodeEval/benchmark/LeetCode.py')

diff --git a/code_eval/OpenCodeEval/benchmark/LeetCode.py b/code_eval/OpenCodeEval/benchmark/LeetCode.py
new file mode 100644
index 0000000..97b3489
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/LeetCode.py
@@ -0,0 +1,121 @@
+import os
+from typing import Literal
+from loguru import logger
+
+from OpenCodeEval.benchmark.base import Benchmark, PYTHON_IMPORTS, LEETCODE_IMPORTS, PYTHON_STOP
+from OpenCodeEval.utils import refine_text, stream_jsonl
+from OpenCodeEval.eval.func_eval import check_correctness
+from OpenCodeEval.eval.sanitize import sanitize
+class LeetCode(Benchmark):
+
+    name: str = "LeetCode"
+
+    imports_code = PYTHON_IMPORTS + LEETCODE_IMPORTS
+    chat_stop = PYTHON_STOP
+    base_stop = ["\ndef ", "\nclass ", "\nimport ", "\nfrom ", "\nassert "]
+
+    def __init__(
+        self,
+        split: Literal["contest", "train", "validation", "test"] = "contest",
+        time_out: float = 3.0,
+        prompt_type: Literal["Completion", "Instruction"] = "Instruction"
+    ):
+
+        super().__init__()
+
+        self.name = LeetCode.name
+        self.split = split
+        self.time_out = time_out
+
+        self.prompt_type = prompt_type
+        if self.split != "contest" and self.prompt_type == "Completion":
+            logger.error(f"Completion prompt type is not supported for the {self.split} split")
+
+        self.path = os.path.join(self.path, f"{self.name}/{self.split}.jsonl")
+        self.tasks = self.get_task()
+
+    def get_task(self):
+        """
+        Load the task data from the jsonl file into a dictionary keyed by task id.
+        """
+
+        tasks = {}
+
+        for task_data in stream_jsonl(filename=self.path):
+
+            if self.split == "contest":
+                task_id = int(task_data["meta"]["questionId"])
+            else:
+                task_id = int(task_data["meta"]["question_id"])
+            tasks[task_id] = task_data
+
+        return tasks
+
+    def get_prompt(self):
+        """
+        Build the prompts for the LM to generate from.
+        """
+
+        prompts = []
+        for task_id, task_data in self.tasks.items():
+
+            if self.split == "contest":
+                if self.prompt_type == "Completion":
+                    prompt = task_data['prompt']
+                elif self.prompt_type == "Instruction":
+                    prompt = task_data['prompt_sft']
+            else:
+                prompt = task_data['meta']['query']
+
+            prompts.append(
+                dict(
+                    task_id = task_id,
+                    prompt = refine_text(prompt)
+                )
+            )
+
+        return prompts
+
+    def postprocess_generation(self, generation):
+        """
+        Postprocess a generation by sanitizing the model completion.
+ """ + + return dict( + task_id = generation['task_id'], + completion_id = generation['completion_id'], + solution = sanitize( + text = generation['completion'], + entrypoint = "Solution", + ) + ) + + def process_results(self, solution): + """ + Takes the list of LM generations and evaluates them against the test cases + """ + + task_data = self.tasks[solution['task_id']] + + if self.split == "contest": + code = ( + "\n".join(self.imports_code) + "\n\n" + + solution['solution'] + "\n\n" + + task_data['test'] + ) + else: + code = ( + "\n".join(self.imports_code) + "\n\n" + + task_data['meta']['lang_code'] + "\n" + + " pass\n" + "\n" + + solution['solution'] + "\n" + + task_data['test'] + "\n" + + f"check({task_data['entry_point']})" + ) + + result = check_correctness(solution['task_id'], + solution['completion_id'], + code, + self.time_out) + + return result \ No newline at end of file -- cgit v1.2.3