|
""" |
|
@Time : 2024/7/24 16:37 |
|
@Author : didi |
|
@File : utils.py |
|
@Acknowledgement https://github.com/evalplus/evalplus/blob/master/evalplus/sanitize.py |
|
""" |
|
|
|
import ast |
|
import traceback |
|
from enum import Enum |
|
from typing import Dict, Generator, List, Optional, Set, Tuple |
|
|
|
import tree_sitter_python |
|
from tree_sitter import Language, Node, Parser |
|
|
|
|
|
class NodeType(Enum): |
|
CLASS = "class_definition" |
|
FUNCTION = "function_definition" |
|
IMPORT = ["import_statement", "import_from_statement"] |
|
IDENTIFIER = "identifier" |
|
ATTRIBUTE = "attribute" |
|
RETURN = "return_statement" |
|
EXPRESSION = "expression_statement" |
|
ASSIGNMENT = "assignment" |
|
|
|
|
|
def traverse_tree(node: Node) -> Generator[Node, None, None]: |
|
""" |
|
Traverse the tree structure starting from the given node. |
|
|
|
:param node: The root node to start the traversal from. |
|
:return: A generator object that yields nodes in the tree. |
|
""" |
|
cursor = node.walk() |
|
depth = 0 |
|
|
|
visited_children = False |
|
while True: |
|
if not visited_children: |
|
yield cursor.node |
|
if not cursor.goto_first_child(): |
|
depth += 1 |
|
visited_children = True |
|
elif cursor.goto_next_sibling(): |
|
visited_children = False |
|
elif not cursor.goto_parent() or depth == 0: |
|
break |
|
else: |
|
depth -= 1 |
|
|
|
|
|
def syntax_check(code, verbose=False): |
|
try: |
|
ast.parse(code) |
|
return True |
|
except (SyntaxError, MemoryError): |
|
if verbose: |
|
traceback.print_exc() |
|
return False |
|
|
|
|
|
def code_extract(text: str) -> str: |
|
lines = text.split("\n") |
|
longest_line_pair = (0, 0) |
|
longest_so_far = 0 |
|
|
|
for i in range(len(lines)): |
|
for j in range(i + 1, len(lines)): |
|
current_lines = "\n".join(lines[i : j + 1]) |
|
if syntax_check(current_lines): |
|
current_length = sum(1 for line in lines[i : j + 1] if line.strip()) |
|
if current_length > longest_so_far: |
|
longest_so_far = current_length |
|
longest_line_pair = (i, j) |
|
|
|
return "\n".join(lines[longest_line_pair[0] : longest_line_pair[1] + 1]) |
|
|
|
|
|
def get_definition_name(node: Node) -> str: |
|
for child in node.children: |
|
if child.type == NodeType.IDENTIFIER.value: |
|
return child.text.decode("utf8") |
|
|
|
|
|
def has_return_statement(node: Node) -> bool: |
|
traverse_nodes = traverse_tree(node) |
|
for node in traverse_nodes: |
|
if node.type == NodeType.RETURN.value: |
|
return True |
|
return False |
|
|
|
|
|
def get_deps(nodes: List[Tuple[str, Node]]) -> Dict[str, Set[str]]: |
|
def dfs_get_deps(node: Node, deps: Set[str]) -> None: |
|
for child in node.children: |
|
if child.type == NodeType.IDENTIFIER.value: |
|
deps.add(child.text.decode("utf8")) |
|
else: |
|
dfs_get_deps(child, deps) |
|
|
|
name2deps = {} |
|
for name, node in nodes: |
|
deps = set() |
|
dfs_get_deps(node, deps) |
|
name2deps[name] = deps |
|
return name2deps |
|
|
|
|
|
def get_function_dependency(entrypoint: str, call_graph: Dict[str, str]) -> Set[str]: |
|
queue = [entrypoint] |
|
visited = {entrypoint} |
|
while queue: |
|
current = queue.pop(0) |
|
if current not in call_graph: |
|
continue |
|
for neighbour in call_graph[current]: |
|
if neighbour not in visited: |
|
visited.add(neighbour) |
|
queue.append(neighbour) |
|
return visited |
|
|
|
|
|
def sanitize(code: str, entrypoint: Optional[str] = None) -> str: |
|
""" |
|
Sanitize and extract relevant parts of the given Python code. |
|
This function parses the input code, extracts import statements, class and function definitions, |
|
and variable assignments. If an entrypoint is provided, it only includes definitions that are |
|
reachable from the entrypoint in the call graph. |
|
|
|
:param code: The input Python code as a string. |
|
:param entrypoint: Optional name of a function to use as the entrypoint for dependency analysis. |
|
:return: A sanitized version of the input code, containing only relevant parts. |
|
""" |
|
code = code_extract(code) |
|
code_bytes = bytes(code, "utf8") |
|
parser = Parser(Language(tree_sitter_python.language())) |
|
tree = parser.parse(code_bytes) |
|
class_names = set() |
|
function_names = set() |
|
variable_names = set() |
|
|
|
root_node = tree.root_node |
|
import_nodes = [] |
|
definition_nodes = [] |
|
|
|
for child in root_node.children: |
|
if child.type in NodeType.IMPORT.value: |
|
import_nodes.append(child) |
|
elif child.type == NodeType.CLASS.value: |
|
name = get_definition_name(child) |
|
if not (name in class_names or name in variable_names or name in function_names): |
|
definition_nodes.append((name, child)) |
|
class_names.add(name) |
|
elif child.type == NodeType.FUNCTION.value: |
|
name = get_definition_name(child) |
|
if not (name in function_names or name in variable_names or name in class_names) and has_return_statement( |
|
child |
|
): |
|
definition_nodes.append((name, child)) |
|
function_names.add(get_definition_name(child)) |
|
elif child.type == NodeType.EXPRESSION.value and child.children[0].type == NodeType.ASSIGNMENT.value: |
|
subchild = child.children[0] |
|
name = get_definition_name(subchild) |
|
if not (name in variable_names or name in function_names or name in class_names): |
|
definition_nodes.append((name, subchild)) |
|
variable_names.add(name) |
|
|
|
if entrypoint: |
|
name2deps = get_deps(definition_nodes) |
|
reacheable = get_function_dependency(entrypoint, name2deps) |
|
|
|
sanitized_output = b"" |
|
|
|
for node in import_nodes: |
|
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" |
|
|
|
for pair in definition_nodes: |
|
name, node = pair |
|
if entrypoint and name not in reacheable: |
|
continue |
|
sanitized_output += code_bytes[node.start_byte : node.end_byte] + b"\n" |
|
return sanitized_output[:-1].decode("utf8") |
|
|