-
Notifications
You must be signed in to change notification settings - Fork 1
/
programSliceTransformer.py
121 lines (87 loc) · 3.94 KB
/
programSliceTransformer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import ast
from visitor.functionDeclarationVisitor import FunctionDeclarationVisitor
class ProgramSliceTransformer(ast.NodeTransformer):
    '''
    Prune a Python AST down to a program slice.

    This class derives from `ast.NodeTransformer`.
    The NodeTransformer will walk the AST and use the return value of the visitor methods to replace or remove the old node.
    If the return value of the visitor method is None, the node will be removed from its location, otherwise it is replaced with the return value.
    The return value may be the original node in which case no replacement takes place.
    See https://docs.python.org/3/library/ast.html#ast.NodeTransformer for more details

    The tree passed to `getSlicedProgram` is modified in place. The line numbers of every
    statement this class drops are collected in `removedLineNumbers`, which is not cleared
    between `getSlicedProgram` calls.
    '''

    def __init__(self) -> None:
        super().__init__()
        # Line numbers of statements removed from the tree (accumulates across calls).
        self.removedLineNumbers = set()
        self.functionDeclarationVisitor = FunctionDeclarationVisitor()
        # Stack of visiting contexts; 'expr' is pushed while the call that forms a whole
        # expression statement is being visited.
        self.context = []
        # Both are replaced by getSlicedProgram; initialised here so the visitor methods
        # do not raise AttributeError if they are driven directly.
        self.lineNumbers = set()
        self.functionNames = set()

    def getSlicedProgram(self, lineNumbers: set, node: ast.AST) -> ast.AST:
        '''
        Return `node` with only the statements selected by `lineNumbers` kept.

        - Assignments and augmented assignments are dropped unless their first line is in `lineNumbers`.
        - A call that forms a whole expression statement is kept if its line (or the line of a
          call nested inside it) is in `lineNumbers`, or if it, or any call nested inside it,
          invokes a user-defined function by name. Otherwise it is dropped.
        - `if` statements are dropped when both branches end up empty. An empty `if` body becomes `pass`.
        - `for`, `async for` and `while` loops are dropped (with any `else` clause) when their body ends up empty.
        - Function and class definitions are always kept. An emptied body becomes `pass`.
        - Other statements (imports, returns, ...) are never removed by this class.

        User-defined function names come from `FunctionDeclarationVisitor.getAllFunctionDefinitionNames`.
        The tree is modified in place and also returned.
        '''
        self.lineNumbers = lineNumbers
        self.functionNames = self.functionDeclarationVisitor.getAllFunctionDefinitionNames(node)
        # generic_visit (not visit) so the root node itself is never dropped or replaced.
        return self.generic_visit(node)

    def _keepIfSelected(self, node: ast.stmt):
        '''Keep `node` if its first line is in the slice; otherwise record that line and drop it.'''
        if node.lineno in self.lineNumbers:
            return node
        self.removedLineNumbers.add(node.lineno)
        return None

    def visit_Assign(self, node: ast.Assign):
        return self._keepIfSelected(node)

    def visit_AugAssign(self, node: ast.AugAssign):
        return self._keepIfSelected(node)

    def visit_Expr(self, node: ast.Expr):
        '''
        Drop an expression statement whose call `visit_Call` decides not to keep.

        Only a call that forms the whole statement (e.g. `foo(x)`) is subject to removal.
        Any other expression statement (docstrings, bare names, `not f()`, `await f()`, ...)
        is kept intact. Deleting a call nested inside a larger expression would leave its
        parent node without a required child, i.e. a malformed AST.
        '''
        if not isinstance(node.value, ast.Call):
            return node
        self.context.append('expr')
        try:
            kept = self.visit(node.value)
        finally:
            # Always restore the context stack, even if a visitor raises.
            self.context.pop()
        return None if kept is None else node

    def visit_Call(self, node: ast.Call):
        '''
        Decide whether a call that forms an expression statement is kept.

        Calls visited outside that context (e.g. inside an assignment or a condition) are
        returned unchanged. A statement-level call is kept when its line is in
        `self.lineNumbers`, or when it or any call nested inside it is on a selected line or
        invokes a user-defined function by name.
        e.g. if foo is a user defined function, print(foo(arr)) and foo(len(arr)) are kept.
        Otherwise the call's line is recorded in `self.removedLineNumbers` and None is returned.
        '''
        if not self.context or self.context[-1] != 'expr':
            return node
        if self._callReachesSlice(node):
            return node
        self.removedLineNumbers.add(node.lineno)
        return None

    def _callReachesSlice(self, node: ast.Call) -> bool:
        '''
        Return True if `node` or any call nested anywhere inside it is on a selected line or
        calls a user-defined function by plain name. Has no side effects.
        '''
        for child in ast.walk(node):
            if not isinstance(child, ast.Call):
                continue
            if child.lineno in self.lineNumbers:
                return True
            if isinstance(child.func, ast.Name) and child.func.id in self.functionNames:
                return True
        return False

    def visit_If(self, node: ast.If):
        result: ast.If = self.generic_visit(node)
        if not result.body and not result.orelse:
            self.removedLineNumbers.add(node.lineno)
            return None
        # An empty body is replaced with pass, because the orelse block may hold statements
        # that must be kept. The Pass gets the `if`'s location so the tree stays compilable.
        if not result.body:
            result.body.append(ast.copy_location(ast.Pass(), node))
        return result

    def _visitLoop(self, node: ast.stmt):
        '''
        Visit a loop's children and drop the whole loop if its body became empty.

        NOTE(review): any `else` clause is discarded together with the loop in that case.
        '''
        result = self.generic_visit(node)
        if not result.body:
            self.removedLineNumbers.add(node.lineno)
            return None
        return result

    def visit_For(self, node: ast.For):
        return self._visitLoop(node)

    def visit_AsyncFor(self, node: ast.AsyncFor):
        return self._visitLoop(node)

    def visit_While(self, node: ast.While):
        return self._visitLoop(node)

    def _visitDefinition(self, node: ast.stmt):
        '''Visit a definition's body; the definition itself is always kept, padded with pass if emptied.'''
        result = self.generic_visit(node)
        if not result.body:
            result.body.append(ast.copy_location(ast.Pass(), node))
        return result

    def visit_FunctionDef(self, node: ast.FunctionDef):
        return self._visitDefinition(node)

    def visit_AsyncFunctionDef(self, node: ast.AsyncFunctionDef):
        return self._visitDefinition(node)

    def visit_ClassDef(self, node: ast.ClassDef):
        return self._visitDefinition(node)