Diffstat (limited to 'code_eval/OpenCodeEval/benchmark/base.py')
-rw-r--r--  code_eval/OpenCodeEval/benchmark/base.py  124
1 file changed, 124 insertions, 0 deletions
diff --git a/code_eval/OpenCodeEval/benchmark/base.py b/code_eval/OpenCodeEval/benchmark/base.py
new file mode 100644
index 0000000..4bd9750
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/base.py
@@ -0,0 +1,124 @@
+import os
+import sys
+from abc import ABC, abstractmethod
+
+ROOT = os.path.dirname(os.path.abspath(__file__))
+
+PYTHON_STOP = ["\nif __name__",
+               "\ndef main(",
+               "\nprint("
+               ]
+
+PYTHON_IMPORTS = ["import math",
+                  "import re",
+                  "import sys",
+                  "import copy",
+                  "import datetime",
+                  "import itertools",
+                  "import collections",
+                  "import heapq",
+                  "import functools",
+                  "import hashlib",
+                  "import numpy",
+                  "import numpy as np",
+                  "import string",
+                  "from typing import *",
+                  "from collections import *"
+                  ]
+
+LEETCODE_IMPORTS = [
+    'from typing import *',
+    'from functools import *',
+    'from collections import *',
+    'from itertools import *',
+    'from heapq import *',
+    'from bisect import *',
+    'from string import *',
+    'from operator import *',
+    'from math import *',
+    'import math',
+    'import datetime',
+    "inf = float('inf')",
+]
+
+
+class Benchmark(ABC):
+
+    name: str = None
+    split: str = None
+    path: str = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/"))
+
+    imports = []
+    chat_stop = []
+    base_stop = []
+
+    def __init__(self):
+        """
+        Subclasses configure the benchmark through the class attributes
+        above: name, split, and path locate the dataset, while imports,
+        chat_stop, and base_stop hold the import prelude and the stop
+        tokens used when post-processing generations.
+        """
+        pass
+
+    def fewshot_examples(self):
+        """Loads and returns the few-shot examples for the task if they exist."""
+        pass
+
+    @abstractmethod
+    def get_task(self):
+        """Loads the dataset and returns the task samples for the LM to generate from.
+        """
+        pass
+
+    @abstractmethod
+    def get_prompt(self, doc):
+        """Builds the prompt for the LM to generate from.
+        :param doc: dict[str: str]
+            sample from the test dataset
+        """
+        pass
+
+    def get_reference(self, doc):
+        """Builds the reference solution for the doc.
+        :param doc: dict[str: str]
+            sample from the test dataset
+        """
+        pass
+
+    @abstractmethod
+    def postprocess_generation(self, task, generation):
+        """Defines the postprocessing for an LM generation.
+        :param task: dict[str: str]
+            sample from the test dataset that the generation belongs to
+        :param generation: str
+            code generation from the LM
+        """
+        pass
+
+    @abstractmethod
+    def process_results(self, generations, references):
+        """Takes the list of LM generations and evaluates them against ground-truth references,
+        returning the metric for the generations as in {"metric_name": result}.
+        :param generations: list(list(str))
+            list of lists containing generations
+        :param references: list(str)
+            list of str containing references
+        :return: dict[str: float]
+        """
+        pass
+
+    @staticmethod
+    def _stop_at_stop_token(decoded_string, stop_tokens):
+        """
+        Produces the prefix of decoded_string that ends at the first occurrence of
+        a stop token.
+        WARNING: decoded_string *must not* include the prompt, which may itself
+        contain stop tokens.
+        """
+        min_stop_index = len(decoded_string)
+        for stop_token in stop_tokens:
+            stop_index = decoded_string.find(stop_token)
+            if stop_index != -1 and stop_index < min_stop_index:
+                min_stop_index = stop_index
+        return decoded_string[:min_stop_index]
\ No newline at end of file
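
For orientation, a minimal sketch of a concrete benchmark built on this base class follows. SquareBenchmark, its one-sample toy dataset, and the import path OpenCodeEval.benchmark.base are illustrative assumptions, not part of this commit; the sketch only shows how the abstract methods and _stop_at_stop_token fit together.

    # Hypothetical usage sketch -- SquareBenchmark, its toy data, and the
    # import path are assumptions for illustration, not part of this commit.
    from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP

    class SquareBenchmark(Benchmark):
        name = "Square"
        split = "test"
        base_stop = PYTHON_STOP

        def get_task(self):
            # A real subclass would load this from a file under self.path.
            return [{"task_id": 0,
                     "prompt": "def square(x):\n",
                     "test": "assert square(3) == 9"}]

        def get_prompt(self, doc):
            return doc["prompt"]

        def get_reference(self, doc):
            return doc["test"]

        def postprocess_generation(self, task, generation):
            # Cut the raw completion at the first stop token, then re-attach
            # the prompt so the result is a complete program.
            return task["prompt"] + self._stop_at_stop_token(generation, self.base_stop)

        def process_results(self, generations, references):
            # Toy metric: fraction of tasks whose first sample passes its test.
            passed = 0
            for gens, test in zip(generations, references):
                try:
                    exec(gens[0] + "\n" + test, {})
                    passed += 1
                except Exception:
                    pass
            return {"pass@1": passed / len(references)}

Note how postprocess_generation re-attaches the prompt after truncation, so process_results can exec a complete program against its reference test.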

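The truncation helper can also be exercised in isolation; the completion string below is made up:

    from OpenCodeEval.benchmark.base import Benchmark, PYTHON_STOP

    # A raw base-model completion that keeps generating past the function body.
    completion = (
        "    return x * x\n"
        "\n"
        "if __name__ == '__main__':\n"
        "    print(square(4))\n"
    )

    # The earliest stop token found is "\nif __name__"; everything from
    # there onward is dropped.
    print(repr(Benchmark._stop_at_stop_token(completion, PYTHON_STOP)))
    # prints: '    return x * x\n'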