summaryrefslogtreecommitdiff
path: root/code_eval/OpenCodeEval/eval/sanitize.py
blob: a82ea652056708ad22d97c57a01cc240ef0ceb50 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""Post-processing LLM-generated Python code implemented using AST."""
import ast
import traceback
import signal

from typing import Dict, List, Optional, Set, Tuple
from OpenCodeEval.utils import refine_text

def syntax_check(code, verbose=False):
    """Return True if *code* parses as Python source, False otherwise.

    Args:
        code: Source text handed to ``ast.parse``.
        verbose: When true, print the traceback of a parse failure.

    Besides ``SyntaxError``, ``ast.parse`` can raise ``ValueError`` (null bytes
    in the source on Python < 3.12), ``RecursionError`` (very deeply nested
    input) and ``MemoryError``. All of them mean "not parseable" here rather
    than an exception that aborts the caller's search.
    """
    try:
        ast.parse(code)
    except (SyntaxError, ValueError, RecursionError, MemoryError):
        if verbose:
            traceback.print_exc()
        return False
    return True

def extract_longest_valid_code(text: str) -> str:
    """Return the contiguous run of lines of *text* that parses as Python and
    has the most non-blank lines.

    Every window ``lines[i:j+1]`` is a candidate. A window replaces the current
    best only if it passes ``syntax_check`` and has strictly more non-blank
    lines, so ties go to the window found first in (i, j) ascending order.
    Comment-only lines count as non-blank. Returns ``""`` when no window with
    at least one non-blank line parses (including empty *text*).

    Window sizes come from prefix sums, and a window is parsed only when it
    could beat the current best. Once a long valid run is found, most of the
    quadratically many candidates are therefore skipped without calling
    ``ast.parse``. The result is identical to checking every window.
    """
    lines = text.splitlines()
    total = len(lines)

    # nonblank_before[k] == number of non-blank lines in lines[:k]
    nonblank_before = [0]
    for line in lines:
        nonblank_before.append(nonblank_before[-1] + (1 if line.strip() else 0))

    max_valid_lines = 0
    max_valid_snippet = ""

    for i in range(total):
        # The non-blank count available from start i only shrinks as i grows,
        # while the best only grows, so once start i cannot win, no later one can.
        if nonblank_before[total] - nonblank_before[i] <= max_valid_lines:
            break
        for j in range(i, total):
            valid_line_count = nonblank_before[j + 1] - nonblank_before[i]
            # Parsing is the expensive step; skip windows that cannot win.
            if valid_line_count <= max_valid_lines:
                continue
            current_snippet = "\n".join(lines[i:j + 1])
            if syntax_check(current_snippet):
                max_valid_lines = valid_line_count
                max_valid_snippet = current_snippet

    return max_valid_snippet

def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
    """Map each named definition to the identifiers referenced inside it.

    Args:
        nodes: ``(name, node)`` pairs; each node's descendants are scanned.

    Returns:
        ``{name: identifiers}``. The identifiers are every ``ast.Name.id`` and
        every ``ast.Attribute.attr`` found below ``node``, including the base
        object of an attribute access (``CONST`` in ``CONST.get(x)``). Function
        parameters (``ast.arg``) are not collected. If a name occurs twice in
        *nodes*, the later entry wins.
    """
    name2deps = {}
    for name, node in nodes:
        deps = set()
        stack = [node]
        while stack:
            current = stack.pop()
            for child in ast.iter_child_nodes(current):
                if isinstance(child, ast.Name):
                    deps.add(child.id)
                    continue  # a Name carries no further identifiers
                if isinstance(child, ast.Attribute):
                    # Record the attribute, then keep walking into child.value
                    # so the object being accessed is recorded too; otherwise a
                    # module-level name used only as ``NAME.attr`` is dropped.
                    deps.add(child.attr)
                stack.append(child)
        name2deps[name] = deps
    return name2deps

def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from *entrypoint* by following *call_graph*.

    *call_graph* maps a name to the names it references. Names absent from the
    mapping have no outgoing edges. The result always contains *entrypoint*.
    """
    seen: Set[str] = set()
    pending = [entrypoint]
    while pending:
        name = pending.pop()
        if name in seen:
            continue
        seen.add(name)
        pending.extend(call_graph.get(name, set()) - seen)
    return seen

def get_definition_name(node: ast.AST) -> Optional[str]:
    """Return the name a statement defines, or None.

    Functions and classes yield their own name. An assignment yields its first
    target when that target is a plain name (``x = y = 1`` -> ``"x"``). Anything
    else, including attribute or tuple targets, yields None.
    """
    if isinstance(node, ast.Assign):
        first_target = node.targets[0] if node.targets else None
        return first_target.id if isinstance(first_target, ast.Name) else None
    if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
        return node.name
    return None

def has_return_statement(node: ast.AST) -> bool:
    """Return True if a ``return`` statement occurs anywhere under *node*.

    The search uses ``ast.walk``, so returns inside nested functions count too.
    """
    for descendant in ast.walk(node):
        if isinstance(descendant, ast.Return):
            return True
    return False

def has_yield_statement(node: ast.AST) -> bool:
    """Return True if a ``yield`` or ``yield from`` occurs anywhere under *node*.

    ``yield from`` parses to ``ast.YieldFrom``, not ``ast.Yield``. Without it, a
    generator written only with ``yield from`` is reported as yield-free.
    The search uses ``ast.walk``, so nested functions are included.
    """
    return any(isinstance(n, (ast.Yield, ast.YieldFrom)) for n in ast.walk(node))

class TimeoutException(Exception):
    """Signals that a time limit expired; raised by ``_timeout_handler``."""

def _timeout_handler(signum, frame):
    """Signal handler that always raises ``TimeoutException``.

    ``signum`` and ``frame`` are the two arguments the ``signal`` module passes
    to every handler; both are ignored. In the code visible here it is only
    installed by the disabled SIGALRM-based ``sanitize`` kept in the string
    literal below.
    """
    raise TimeoutException()
# NOTE(review): The triple-quoted string below is an earlier, disabled
# implementation of ``sanitize`` that enforced a 30-second limit with
# ``signal.alarm``. It is a bare string expression: evaluated and discarded at
# import time, never executed as code. The active ``sanitize`` is defined after
# the ``multiprocessing`` import that follows it. Its inline comments are in
# Chinese and read: "set the signal handler", "send SIGALRM after 30 seconds",
# "cancel the timer". Consider deleting it once it is no longer needed.
"""
def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
     # 设置信号处理器
    signal.signal(signal.SIGALRM, _timeout_handler)
    signal.alarm(30)  # 30秒后发送SIGALRM
    
    text = refine_text(text)

    try:
        code = extract_longest_valid_code(text)

        tree = ast.parse(code)
            
        definitions = {}

        imports = []

        for node in tree.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                imports.append(node)
            elif isinstance(node, ast.ClassDef):
                name = node.name
                definitions[name] = ('class', node)
            elif isinstance(node, ast.FunctionDef):
                name = node.name
                if has_return_statement(node) or has_yield_statement(node):
                    definitions[name] = ('function', node)
            elif isinstance(node, ast.Assign):
                name = get_definition_name(node)
                if name:
                    definitions[name] = ('variable', node)

        if entrypoint:
            name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
            reachable = get_function_dependency(entrypoint, name2deps)

        sanitized_output = []

        for node in imports:
            sanitized_output.append(ast.unparse(node))

        for name, (_, node) in definitions.items():
            if not entrypoint or name in reachable:
                sanitized_output.append(ast.unparse(node))

        return "\n".join(sanitized_output)

    except Exception as e:
        print(f"Error extracting longest valid code: {e}")
        return ""
    finally:
        signal.alarm(0)  # 取消定时器
"""
from multiprocessing import Process, Queue, TimeoutError

def _sanitize_code(text: str, entrypoint: Optional[str] = None) -> str:
    """Reduce raw model output *text* to imports plus selected top-level definitions.

    ``refine_text`` (from ``OpenCodeEval.utils``) is applied first. The longest
    syntactically valid run of lines is then parsed, and from its top-level
    statements this keeps:

    * every ``import`` / ``from ... import`` statement (emitted first, in order);
    * every class definition;
    * every non-async function for which ``has_return_statement`` or
      ``has_yield_statement`` is true;
    * every assignment whose first target is a plain name.

    If *entrypoint* is given, only definitions reachable from it through the
    identifiers collected by ``get_deps`` are emitted; imports are always kept.
    A later definition with an already-seen name replaces the earlier node but
    keeps its output position.

    Raises:
        Whatever the helpers raise; ``_sanitize_worker`` turns that into "".
    """
    text = refine_text(text)
    code = extract_longest_valid_code(text)
    tree = ast.parse(code)

    definitions = {}  # name -> (kind, node); insertion order is output order
    imports = []
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append(node)
        elif isinstance(node, ast.ClassDef):
            definitions[node.name] = ('class', node)
        elif isinstance(node, ast.FunctionDef):
            # Functions without any return/yield are dropped.
            if has_return_statement(node) or has_yield_statement(node):
                definitions[node.name] = ('function', node)
        elif isinstance(node, ast.Assign):
            name = get_definition_name(node)
            if name:
                definitions[name] = ('variable', node)

    reachable = None  # None means "keep every definition"
    if entrypoint:
        name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
        reachable = get_function_dependency(entrypoint, name2deps)

    sanitized_output = [ast.unparse(node) for node in imports]
    for name, (_, node) in definitions.items():
        if reachable is None or name in reachable:
            sanitized_output.append(ast.unparse(node))
    return "\n".join(sanitized_output)


def _sanitize_worker(q, text, entrypoint):
    """Child-process target: put ``_sanitize_code(text, entrypoint)`` on *q*.

    Defined at module level, not nested inside ``sanitize``, so that it can be
    pickled when multiprocessing uses the "spawn" or "forkserver" start method.
    Any exception is reported and replaced by "" so that the parent gets an
    answer instead of waiting out the whole timeout.
    """
    try:
        result = _sanitize_code(text, entrypoint)
    except Exception as e:
        print(f"Error extracting longest valid code: {e}")
        result = ""
    q.put(result)


def sanitize(text: str, entrypoint: Optional[str] = None, timeout: float = 3) -> str:
    """Sanitize LLM-generated code in a child process under a time limit.

    The work runs in a separate process so that a pathological input (for
    example a slow ``extract_longest_valid_code`` search) can be killed.

    Args:
        text: Raw model output.
        entrypoint: Optional name whose transitive dependencies are kept; when
            omitted, all qualifying top-level definitions are kept.
        timeout: Seconds to wait for the child's result (default 3).

    Returns:
        The sanitized code, or "" if sanitization failed or timed out.
    """
    import queue  # multiprocessing.Queue.get signals a timeout with queue.Empty

    q = Queue()
    p = Process(target=_sanitize_worker, args=(q, text, entrypoint))
    p.start()
    try:
        return q.get(timeout=timeout)
    except queue.Empty:
        # multiprocessing.TimeoutError is never raised by Queue.get; catching it
        # let queue.Empty escape to the caller and skipped the cleanup below.
        print(f"Function timed out after {timeout:g} seconds")
        return ""
    finally:
        # Always stop and reap the child: on success it has already delivered
        # its result, on timeout it may still be running.
        p.terminate()
        p.join()