summaryrefslogtreecommitdiff
path: root/code_eval/OpenCodeEval/factory.py
blob: 5d3fefc273a5d3bc8ba06a6fbd9cb18ec2a54062 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
from loguru import logger

from OpenCodeEval.benchmark.HumanEval import HumanEval
from OpenCodeEval.benchmark.mbpp import mbpp
from OpenCodeEval.benchmark.MBPP import MBPP
from OpenCodeEval.benchmark.LeetCode import LeetCode
from OpenCodeEval.benchmark.BigCodeBench import BigCodeBench
from OpenCodeEval.benchmark.Bird import Bird
from OpenCodeEval.benchmark.Spider import Spider

from OpenCodeEval.benchmark.understandml import understandml

from OpenCodeEval.backend.vllm import VllmGenerator
from OpenCodeEval.backend.sglang import SglangGenerator
from OpenCodeEval.backend.openai import OpenaiGenerator

class BenchmarkFactory:
    """Build benchmark task objects from parsed command-line arguments."""

    @staticmethod
    def get_task(args):
        """Instantiate the benchmark named by ``args.task``.

        Args:
            args: Parsed arguments object. Reads ``task`` (the benchmark
                name, matched case-sensitively -- note that ``"mbpp"`` and
                ``"MBPP"`` map to different classes), plus ``split``,
                ``time_out`` and ``prompt_type``, which are forwarded as
                keyword arguments to the benchmark constructor.

        Returns:
            An instance of the benchmark class registered under
            ``args.task``.

        Raises:
            ValueError: If ``args.task`` is not a registered benchmark name.
        """
        task_map = {
            "HumanEval": HumanEval,
            "mbpp": mbpp,
            "MBPP": MBPP,
            "LeetCode": LeetCode,
            "BigCodeBench": BigCodeBench,
            "Bird": Bird,
            "Spider": Spider,
            "understandml": understandml,
        }

        task_class = task_map.get(args.task)
        if task_class is None:
            # Fail fast with an explicit error instead of logging and then
            # letting the dict lookup raise an uninformative KeyError.
            logger.error(f"Unknown Task type: {args.task}")
            raise ValueError(
                f"Unknown Task type: {args.task}. "
                f"Available tasks: {', '.join(task_map)}"
            )

        return task_class(
            split = args.split,
            time_out = args.time_out,
            prompt_type = args.prompt_type,
        )

class BackendFactory:
    """Build a text-generation backend from parsed command-line arguments."""

    @staticmethod
    def get_backend(args):
        """Instantiate the generator selected by ``args.backend``.

        Supported values are ``"vllm"``, ``"sglang"`` and ``"openai"``. Each
        backend is constructed from a different subset of ``args``:

        * ``vllm``   -- model_name, model_type, tokenizer_name, num_gpus,
          batch_size, temperature, top_p, num_samples, max_tokens, seed,
          trust_remote_code.
        * ``sglang`` -- the same fields as ``vllm`` except ``seed``.
        * ``openai`` -- model_name, model_type, batch_size, temperature,
          top_p, num_samples, max_tokens.

        Args:
            args: Parsed arguments object providing ``backend`` and the
                fields listed above for the selected backend.

        Returns:
            The constructed generator instance.

        Raises:
            ValueError: If ``args.backend`` is not one of the supported names.
        """
        if args.backend == "vllm":
            return VllmGenerator(
                model_name = args.model_name,
                model_type = args.model_type,
                tokenizer_name = args.tokenizer_name,
                num_gpus = args.num_gpus,
                batch_size = args.batch_size,
                temperature = args.temperature,
                top_p = args.top_p,
                num_samples = args.num_samples,
                max_tokens = args.max_tokens,
                seed = args.seed,
                trust_remote_code = args.trust_remote_code,
            )

        if args.backend == "sglang":
            # NOTE(review): unlike the vllm branch, ``seed`` is not passed
            # here -- confirm whether SglangGenerator accepts/needs it.
            return SglangGenerator(
                model_name = args.model_name,
                model_type = args.model_type,
                tokenizer_name = args.tokenizer_name,
                num_gpus = args.num_gpus,
                batch_size = args.batch_size,
                temperature = args.temperature,
                top_p = args.top_p,
                num_samples = args.num_samples,
                max_tokens = args.max_tokens,
                trust_remote_code = args.trust_remote_code,
            )

        if args.backend == "openai":
            return OpenaiGenerator(
                model_name = args.model_name,
                model_type = args.model_type,
                batch_size = args.batch_size,
                temperature = args.temperature,
                top_p = args.top_p,
                num_samples = args.num_samples,
                max_tokens = args.max_tokens,
            )

        # Previously this path only logged and implicitly returned None,
        # deferring the failure to an unrelated AttributeError in the caller.
        logger.error(f"Unknown Backend type: {args.backend}")
        raise ValueError(
            f"Unknown Backend type: {args.backend}. "
            "Available backends: vllm, sglang, openai"
        )