1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
|
"""Post-processing LLM-generated Python code implemented using AST."""
import ast
import traceback
import signal
from typing import Dict, List, Optional, Set, Tuple
from OpenCodeEval.utils import refine_text
def syntax_check(code: str, verbose: bool = False) -> bool:
    """Report whether *code* can be parsed as Python source.

    Args:
        code: Candidate source text.
        verbose: When true, print the traceback of the parse failure.

    Returns:
        True if ``ast.parse`` accepts the text, False if parsing fails.
    """
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError, ValueError, RecursionError):
        # SyntaxError / MemoryError: the failures originally handled.
        # ValueError: older interpreters reject source containing null bytes
        # with ValueError rather than SyntaxError.
        # RecursionError: extremely nested input can exhaust the recursion
        # limit while the tree is being built.
        if verbose:
            traceback.print_exc()
        return False
    return True
def extract_longest_valid_code(text: str) -> str:
    """Return the contiguous run of lines that parses and has the most non-blank lines.

    Every window ``lines[i:j+1]`` is a candidate. Among the windows accepted by
    ``syntax_check``, the one with the highest count of non-blank lines wins;
    on ties the first one encountered (smallest ``i``, then smallest ``j``) is
    kept. An empty string is returned when no window with at least one
    non-blank line parses.

    Args:
        text: Raw text that may contain Python code mixed with other content.

    Returns:
        The selected snippet, lines joined with ``"\\n"``.
    """
    lines = text.splitlines()
    total = len(lines)

    # non_blank_before[k] == number of non-blank lines in lines[:k], so the
    # count for any window is a subtraction instead of a rescan.
    non_blank_before = [0]
    for line in lines:
        non_blank_before.append(non_blank_before[-1] + (1 if line.strip() else 0))

    max_valid_lines = 0
    max_valid_snippet = ""
    for i in range(total):
        # Even the widest window starting here cannot beat the current best,
        # and later starts only have fewer lines available.
        if non_blank_before[total] - non_blank_before[i] <= max_valid_lines:
            break
        for j in range(i, total):
            valid_line_count = non_blank_before[j + 1] - non_blank_before[i]
            # Parsing is the expensive step; skip windows that could not
            # replace the current best even if they were valid.
            if valid_line_count <= max_valid_lines:
                continue
            current_snippet = "\n".join(lines[i:j + 1])
            if syntax_check(current_snippet):
                max_valid_lines = valid_line_count
                max_valid_snippet = current_snippet
    return max_valid_snippet
def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
    """Collect, for each named node, the identifiers referenced beneath it.

    Every ``ast.Name`` id and every ``ast.Attribute`` attr found among the
    descendants of a node is recorded. Attribute nodes are also descended
    into, so the base of an attribute access (``TABLE`` in ``TABLE.get(x)``)
    is recorded as well.

    Args:
        nodes: ``(name, node)`` pairs to analyse.

    Returns:
        Mapping from each name to the set of identifiers found under its node.
    """
    name2deps = {}
    for name, node in nodes:
        deps = set()
        stack = [node]
        while stack:
            current = stack.pop()
            for child in ast.iter_child_nodes(current):
                if isinstance(child, ast.Name):
                    deps.add(child.id)
                    continue
                if isinstance(child, ast.Attribute):
                    deps.add(child.attr)
                # Keep descending, including through Attribute nodes, so the
                # object an attribute is read from is not lost.
                stack.append(child)
        name2deps[name] = deps
    return name2deps
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Return every name reachable from *entrypoint* through *call_graph*.

    The entry point itself is always part of the result, even when it has no
    entry in the graph.

    Args:
        entrypoint: Name to start the traversal from.
        call_graph: Mapping from a name to the set of names it refers to.

    Returns:
        The set of names reached, including *entrypoint*.
    """
    reached = set()
    pending = [entrypoint]
    while pending:
        name = pending.pop()
        if name in reached:
            continue
        reached.add(name)
        # Names missing from the graph simply contribute no outgoing edges.
        pending.extend(call_graph.get(name, set()) - reached)
    return reached
def get_definition_name(node: ast.AST) -> Optional[str]:
    """Return the name that *node* defines, or None when it defines none.

    Functions and classes yield their own name. An assignment yields the id of
    its first target, but only when that target is a plain name. Every other
    node yields None.
    """
    if isinstance(node, ast.Assign):
        first_target = node.targets[0] if node.targets else None
        if isinstance(first_target, ast.Name):
            return first_target.id
        return None
    if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
        return node.name
    return None
def has_return_statement(node: ast.AST) -> bool:
    """Tell whether an ``ast.Return`` occurs anywhere in the tree rooted at *node*.

    The whole subtree is searched, so a ``return`` inside a nested definition
    counts too.
    """
    for descendant in ast.walk(node):
        if isinstance(descendant, ast.Return):
            return True
    return False
def has_yield_statement(node: ast.AST) -> bool:
    """Tell whether a ``yield`` or ``yield from`` occurs in the tree rooted at *node*.

    ``yield from`` is represented by ``ast.YieldFrom``, a separate node type
    from ``ast.Yield``, so both are checked. Otherwise a generator that only
    delegates would be reported as having no yield.
    """
    return any(
        isinstance(n, (ast.Yield, ast.YieldFrom)) for n in ast.walk(node)
    )
class TimeoutException(Exception):
    """Exception raised by this module's signal-based timeout handler."""
def _timeout_handler(signum, frame):
    """Signal-handler callback that unconditionally raises ``TimeoutException``.

    Both arguments are accepted to fit the handler call signature and are
    otherwise ignored.
    """
    raise TimeoutException
"""
def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
# 设置信号处理器
signal.signal(signal.SIGALRM, _timeout_handler)
signal.alarm(30) # 30秒后发送SIGALRM
text = refine_text(text)
try:
code = extract_longest_valid_code(text)
tree = ast.parse(code)
definitions = {}
imports = []
for node in tree.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
imports.append(node)
elif isinstance(node, ast.ClassDef):
name = node.name
definitions[name] = ('class', node)
elif isinstance(node, ast.FunctionDef):
name = node.name
if has_return_statement(node) or has_yield_statement(node):
definitions[name] = ('function', node)
elif isinstance(node, ast.Assign):
name = get_definition_name(node)
if name:
definitions[name] = ('variable', node)
if entrypoint:
name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
reachable = get_function_dependency(entrypoint, name2deps)
sanitized_output = []
for node in imports:
sanitized_output.append(ast.unparse(node))
for name, (_, node) in definitions.items():
if not entrypoint or name in reachable:
sanitized_output.append(ast.unparse(node))
return "\n".join(sanitized_output)
except Exception as e:
print(f"Error extracting longest valid code: {e}")
return ""
finally:
signal.alarm(0) # 取消定时器
"""
from multiprocessing import Process, Queue, TimeoutError
def sanitize(text: str, entrypoint: Optional[str] = None, timeout: float = 3) -> str:
    """Extract the useful Python definitions from *text*, bounded by a timeout.

    The work runs in a child process so that it can be forcibly stopped when
    it takes too long.

    Args:
        text: Raw model output that may contain Python code.
        entrypoint: When given, only definitions reachable from this name
            (plus all imports) are kept; otherwise every collected definition
            is kept.
        timeout: Seconds to wait for the child process to deliver a result.

    Returns:
        The sanitized source, or an empty string when the worker failed or
        did not answer within *timeout* seconds.
    """
    # Queue.get signals a timeout with queue.Empty, not with
    # multiprocessing.TimeoutError.
    from queue import Empty

    q = Queue()
    p = Process(target=_sanitize_worker, args=(q, text, entrypoint))
    p.start()
    try:
        return q.get(timeout=timeout)
    except Empty:
        print(f"Function timed out after {timeout} seconds")
        return ""
    finally:
        # Always stop and reap the child, whether or not a result arrived.
        p.terminate()
        p.join()


def _sanitize_worker(q, text, entrypoint):
    """Child-process body for ``sanitize``: compute the result and put it on *q*.

    Defined at module level rather than nested inside ``sanitize`` so that it
    can be pickled when the process start method is ``spawn``.
    Exactly one string is put on the queue: the sanitized code, or ``""`` if
    any exception was raised while producing it.
    """
    try:
        text = refine_text(text)
        code = extract_longest_valid_code(text)
        tree = ast.parse(code)

        definitions = {}
        imports = []
        for node in tree.body:
            if isinstance(node, (ast.Import, ast.ImportFrom)):
                imports.append(node)
            elif isinstance(node, ast.ClassDef):
                definitions[node.name] = ('class', node)
            elif isinstance(node, ast.FunctionDef):
                # Functions that neither return nor yield are discarded.
                if has_return_statement(node) or has_yield_statement(node):
                    definitions[node.name] = ('function', node)
            elif isinstance(node, ast.Assign):
                name = get_definition_name(node)
                if name:
                    definitions[name] = ('variable', node)

        # None means "no filtering": keep every collected definition.
        reachable = None
        if entrypoint:
            name2deps = get_deps(
                [(name, node) for name, (_, node) in definitions.items()]
            )
            reachable = get_function_dependency(entrypoint, name2deps)

        # Imports are always emitted first, then the kept definitions in
        # their collected order.
        sanitized_output = [ast.unparse(node) for node in imports]
        sanitized_output.extend(
            ast.unparse(node)
            for name, (_, node) in definitions.items()
            if reachable is None or name in reachable
        )
        q.put("\n".join(sanitized_output))
    except Exception as e:
        print(f"Error extracting longest valid code: {e}")
        q.put("")
|