Diffstat (limited to 'code_eval/OpenCodeEval/benchmark/base.py')
-rw-r--r--  code_eval/OpenCodeEval/benchmark/base.py  125
1 file changed, 125 insertions(+), 0 deletions(-)
diff --git a/code_eval/OpenCodeEval/benchmark/base.py b/code_eval/OpenCodeEval/benchmark/base.py
new file mode 100644
index 0000000..4bd9750
--- /dev/null
+++ b/code_eval/OpenCodeEval/benchmark/base.py
@@ -0,0 +1,125 @@
+import os
+import sys
+
+ROOT = os.path.dirname(os.path.abspath(__file__))
+
+PYTHON_STOP = [ "\nif __name__",
+ "\ndef main(",
+ "\nprint("
+ ]
+
+PYTHON_IMPORTS = [ "import math",
+ "import re",
+ "import sys",
+ "import copy",
+ "import datetime",
+ "import itertools",
+ "import collections",
+ "import heapq",
+ "import functools",
+ "import hashlib",
+ "import numpy",
+ "import numpy as np",
+ "import string",
+ "from typing import *",
+ "from collections import *"
+ ]
+
+LEETCODE_IMPORTS = [
+ 'from typing import *',
+ 'from functools import *',
+ 'from collections import *',
+ 'from itertools import *',
+ 'from heapq import *',
+ 'from bisect import *',
+ 'from string import *',
+ 'from operator import *',
+ 'from math import *',
+ 'import math',
+ 'import datetime',
+ "inf = float('inf')",
+]
+
+from abc import ABC, abstractmethod
+
+class Benchmark(ABC):
+
+ name: str = None
+ split: str = None
+ path: str = os.path.abspath(os.path.join(os.path.dirname(__file__), "data/"))
+
+ imports = []
+ chat_stop = []
+ base_stop = []
+
+ def __init__(self):
+ """
+        Base class for a code-generation benchmark.
+        Subclasses set the class attributes `name`, `split`, `imports`,
+        `chat_stop` and `base_stop` (stop words used as a stopping criterion
+        during generation) and implement the abstract methods below.
+ """
+ pass
+
+ def fewshot_examples(self):
+ """Loads and returns the few-shot examples for the task if they exist."""
+ pass
+
+ @abstractmethod
+ def get_task(self):
+        """Loads the benchmark data and returns the tasks for the LM to generate from.
+ """
+ pass
+
+ @abstractmethod
+ def get_prompt(self, doc):
+ """Builds the prompt for the LM to generate from.
+        :param doc: dict[str, str]
+ sample from the test dataset
+ """
+ pass
+
+
+ def get_reference(self, doc):
+ """Builds the reference solution for the doc.
+        :param doc: dict[str, str]
+ sample from the test dataset
+ """
+ pass
+
+ @abstractmethod
+ def postprocess_generation(self, task, generation):
+ """Defines the postprocessing for a LM generation.
+        :param task: dict[str, str]
+            sample from the test dataset to which the generation belongs
+        :param generation: str
+            code generation from the LM
+ """
+ pass
+
+ @abstractmethod
+ def process_results(self, generations, references):
+ """Takes the list of LM generations and evaluates them against ground truth references,
+ returning the metric for the generations as in {"metric_name": result}.
+ :param generations: list(list(str))
+ list of lists containing generations
+ :param references: list(str)
+            list of str containing references
+        :return: dict[str, float]
+ """
+ pass
+
+    @staticmethod
+    def _stop_at_stop_token(decoded_string, stop_tokens):
+ """
+ Produces the prefix of decoded_string that ends at the first occurrence of
+ a stop_token.
+ WARNING: the decoded_string *must not* include the prompt, which may have stop tokens
+ itself.
+ """
+ min_stop_index = len(decoded_string)
+ for stop_token in stop_tokens:
+ stop_index = decoded_string.find(stop_token)
+ if stop_index != -1 and stop_index < min_stop_index:
+ min_stop_index = stop_index
+        return decoded_string[:min_stop_index]
\ No newline at end of file
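
As a usage illustration, the following is a minimal sketch of a concrete subclass of the Benchmark base class defined above. The DummyBench name, the dummy.jsonl dataset, its field names ("task_id", "prompt", "canonical_solution") and the exact-match metric are hypothetical and only show how the abstract methods fit together; they are not part of OpenCodeEval.

import json
import os

class DummyBench(Benchmark):
    # Hypothetical benchmark: the name, data layout and metric are assumptions.

    name = "DummyBench"
    split = "test"

    imports = PYTHON_IMPORTS
    chat_stop = PYTHON_STOP
    base_stop = PYTHON_STOP

    def get_task(self):
        # Assumed layout: one JSON object per line with "task_id",
        # "prompt" and "canonical_solution" fields.
        tasks = {}
        with open(os.path.join(self.path, "dummy.jsonl")) as f:
            for line in f:
                doc = json.loads(line)
                tasks[doc["task_id"]] = doc
        return tasks

    def get_prompt(self, doc):
        return doc["prompt"]

    def get_reference(self, doc):
        return doc["canonical_solution"]

    def postprocess_generation(self, task, generation):
        # Truncate the completion at the first configured stop token.
        return self._stop_at_stop_token(generation, self.base_stop)

    def process_results(self, generations, references):
        # Toy metric: exact string match of any sample against the reference.
        correct = sum(
            any(gen.strip() == ref.strip() for gen in gens)
            for gens, ref in zip(generations, references)
        )
        return {"exact_match": correct / max(len(references), 1)}

The only base-class behaviour this sketch relies on is the `path` attribute and `_stop_at_stop_token`, which trims a generation at the first occurrence of any of the configured stop words.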