1 | """Utility functions, node construction macros, etc."""
|
---|
2 | # Author: Collin Winter
|
---|
3 |
|
---|
4 | from itertools import islice
|
---|
5 |
|
---|
6 | # Local imports
|
---|
7 | from .pgen2 import token
|
---|
8 | from .pytree import Leaf, Node
|
---|
9 | from .pygram import python_symbols as syms
|
---|
10 | from . import patcomp
|
---|
11 |
|
---|
12 |
|
---|
13 | ###########################################################
|
---|
14 | ### Common node-construction "macros"
|
---|
15 | ###########################################################
|
---|
16 |
|
---|
def KeywordArg(keyword, value):
    """Build an ``argument`` node of the form keyword=value.

    keyword and value become the first and last children; a new EQUAL
    leaf is placed between them.
    """
    equals = Leaf(token.EQUAL, u"=")
    return Node(syms.argument, [keyword, equals, value])
20 |
|
---|
def LParen():
    """A left parenthesis leaf"""
    paren = Leaf(token.LPAR, u"(")
    return paren
23 |
|
---|
def RParen():
    """A right parenthesis leaf"""
    paren = Leaf(token.RPAR, u")")
    return paren
26 |
|
---|
def Assign(target, source):
    """Build an assignment statement.

    target and source may each be a single node or a list of nodes.
    A single source node has its prefix set to one space (the node is
    modified in place); a list of source nodes is used unchanged.
    """
    targets = target if isinstance(target, list) else [target]
    if isinstance(source, list):
        sources = source
    else:
        source.prefix = u" "
        sources = [source]

    equals = Leaf(token.EQUAL, u"=", prefix=u" ")
    return Node(syms.atom, targets + [equals] + sources)
37 |
|
---|
def Name(name, prefix=None):
    """A NAME leaf with the given value and prefix"""
    leaf = Leaf(token.NAME, name, prefix=prefix)
    return leaf
41 |
|
---|
def Attr(obj, attr):
    """The two-element node list for obj.attr.

    Returns [obj, trailer], where the trailer node holds a new dot leaf
    followed by attr.
    """
    trailer = Node(syms.trailer, [Dot(), attr])
    return [obj, trailer]
45 |
|
---|
def Comma():
    """A comma (,) leaf"""
    comma = Leaf(token.COMMA, u",")
    return comma
49 |
|
---|
def Dot():
    """A dot (.) leaf"""
    dot = Leaf(token.DOT, u".")
    return dot
53 |
|
---|
def ArgList(args, lparen=LParen(), rparen=RParen()):
    """A parenthesised argument list, used by Call().

    lparen and rparen are cloned, so the (shared) default leaves are
    never attached to a tree.  When args is empty or None the trailer
    contains only the two parentheses.
    """
    trailer = Node(syms.trailer, [lparen.clone(), rparen.clone()])
    if not args:
        return trailer
    # Slot the arglist node between the two parentheses.
    trailer.insert_child(1, Node(syms.arglist, args))
    return trailer
60 |
|
---|
def Call(func_name, args=None, prefix=None):
    """A function call: func_name followed by an ArgList built from args.

    The node's prefix is only assigned when prefix is not None.
    """
    call = Node(syms.power, [func_name, ArgList(args)])
    if prefix is None:
        return call
    call.prefix = prefix
    return call
67 |
|
---|
def Newline():
    """A NEWLINE leaf whose value is a line feed"""
    newline = Leaf(token.NEWLINE, u"\n")
    return newline
71 |
|
---|
def BlankLine():
    """A NEWLINE leaf with an empty value (a blank line)"""
    blank = Leaf(token.NEWLINE, u"")
    return blank
75 |
|
---|
def Number(n, prefix=None):
    """A NUMBER leaf with value n and the given prefix"""
    leaf = Leaf(token.NUMBER, n, prefix=prefix)
    return leaf
78 |
|
---|
def Subscript(index_node):
    """A numeric or string subscript.

    Returns a trailer node of the form [index_node].  The bracket leaves
    use the square-bracket token types (LSQB/RSQB); LBRACE/RBRACE are
    the token types of "{" and "}" and do not belong on "[" / "]".
    """
    open_bracket = Leaf(token.LSQB, u"[")
    close_bracket = Leaf(token.RSQB, u"]")
    return Node(syms.trailer, [open_bracket, index_node, close_bracket])
84 |
|
---|
def String(string, prefix=None):
    """A STRING leaf with the given value and prefix"""
    leaf = Leaf(token.STRING, string, prefix=prefix)
    return leaf
88 |
|
---|
def ListComp(xp, fp, it, test=None):
    """A list comprehension of the form [xp for fp in it if test].

    If test is None, the "if test" part is omitted.

    The prefixes of xp, fp, it (and test, when given) are overwritten in
    place so the pieces are separated by single spaces.  The enclosing
    bracket leaves use the square-bracket token types (LSQB/RSQB);
    LBRACE/RBRACE are the token types of "{" and "}".
    """
    xp.prefix = u""
    fp.prefix = u" "
    it.prefix = u" "
    for_leaf = Leaf(token.NAME, u"for")
    for_leaf.prefix = u" "
    in_leaf = Leaf(token.NAME, u"in")
    in_leaf.prefix = u" "
    inner_args = [for_leaf, fp, in_leaf, it]
    if test:
        test.prefix = u" "
        if_leaf = Leaf(token.NAME, u"if")
        if_leaf.prefix = u" "
        # The optional filter hangs off the end of the comp_for children.
        inner_args.append(Node(syms.comp_if, [if_leaf, test]))
    inner = Node(syms.listmaker, [xp, Node(syms.comp_for, inner_args)])
    return Node(syms.atom,
                [Leaf(token.LSQB, u"["),
                 inner,
                 Leaf(token.RSQB, u"]")])
112 |
|
---|
def FromImport(package_name, name_leafs):
    """Return an import statement in the form:
        from package import name_leafs

    Every leaf in name_leafs is first removed from whatever tree it
    currently belongs to, then placed under a new import_as_names node.
    """
    # XXX: May not handle dotted imports properly (eg, package_name='foo.bar');
    # this has not been tested with dotted package names.

    # Detach the leaves from their old tree before re-parenting them.
    for leaf in name_leafs:
        leaf.remove()

    from_kw = Leaf(token.NAME, u"from")
    package = Leaf(token.NAME, package_name, prefix=u" ")
    import_kw = Leaf(token.NAME, u"import", prefix=u" ")
    names = Node(syms.import_as_names, name_leafs)
    return Node(syms.import_from, [from_kw, package, import_kw, names])
131 |
|
---|
132 |
|
---|
133 | ###########################################################
|
---|
134 | ### Determine whether a node represents a given literal
|
---|
135 | ###########################################################
|
---|
136 |
|
---|
def is_tuple(node):
    """Does the node represent a tuple literal?

    True for a Node whose children are exactly "(" and ")", or whose
    three children are a "(" leaf, an inner Node and a ")" leaf.
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    # Empty tuple: just the two parentheses.
    if kids == [LParen(), RParen()]:
        return True
    return (len(kids) == 3
            and isinstance(kids[0], Leaf)
            and isinstance(kids[1], Node)
            and isinstance(kids[2], Leaf)
            and kids[0].value == u"("
            and kids[2].value == u")")
148 |
|
---|
def is_list(node):
    """Does the node represent a list literal?

    True for a Node with more than one child whose first and last
    children are leaves with the values "[" and "]".
    """
    if not isinstance(node, Node):
        return False
    kids = node.children
    if not len(kids) > 1:
        return False
    first, last = kids[0], kids[-1]
    return (isinstance(first, Leaf)
            and isinstance(last, Leaf)
            and first.value == u"["
            and last.value == u"]")
157 |
|
---|
158 |
|
---|
159 | ###########################################################
|
---|
160 | ### Misc
|
---|
161 | ###########################################################
|
---|
162 |
|
---|
def parenthesize(node):
    """Wrap node in an atom node between new "(" and ")" leaves"""
    children = [LParen(), node, RParen()]
    return Node(syms.atom, children)
165 |
|
---|
166 |
|
---|
# Names of callables collected in one set for membership tests.
consuming_calls = set([
    "sorted",
    "list",
    "set",
    "any",
    "all",
    "tuple",
    "sum",
    "min",
    "max",
    "enumerate",
])
169 |
|
---|
def attr_chain(obj, attr):
    """Follow an attribute chain.

    If you have a chain of objects where a.foo -> b, b.foo -> c, etc,
    use this to iterate over all objects in the chain.  The starting
    object itself is not yielded.  Iteration stops at the first link
    whose value is falsy (for example None).

    Args:
        obj: the starting object
        attr: the name of the chaining attribute

    Yields:
        Each successive object in the chain.
    """
    link = getattr(obj, attr)
    while link:
        yield link
        link = getattr(link, attr)
188 |
|
---|
# Pattern sources.  They are replaced by their compiled forms the first
# time in_special_context() runs (tracked by pats_built).
p0 = """for_stmt< 'for' any 'in' node=any ':' any* >
        | comp_for< 'for' any 'in' node=any any* >
     """
p1 = """
power<
    ( 'iter' | 'list' | 'tuple' | 'sorted' | 'set' | 'sum' |
      'any' | 'all' | 'enumerate' | (any* trailer< '.' 'join' >) )
    trailer< '(' node=any ')' >
    any*
>
"""
p2 = """
power<
    ( 'sorted' | 'enumerate' )
    trailer< '(' arglist<node=any any*> ')' >
    any*
>
"""
pats_built = False
def in_special_context(node):
    """ Returns true if node is in an environment where all that is required
        of it is being iterable (ie, it doesn't matter if it returns a list
        or an iterator).
        See test_map_nochange in test_fixers.py for some examples and tests.

        p0 is matched against node's parent, p1 against its grandparent
        and p2 against its great-grandparent; a match only counts when
        the pattern's "node" capture is node itself.
        """
    global p0, p1, p2, pats_built
    # Compile lazily, once, rebinding the module-level names.
    if not pats_built:
        p0 = patcomp.compile_pattern(p0)
        p1 = patcomp.compile_pattern(p1)
        p2 = patcomp.compile_pattern(p2)
        pats_built = True
    ancestors = attr_chain(node, "parent")
    for pattern, ancestor in zip([p0, p1, p2], ancestors):
        captured = {}
        if pattern.match(ancestor, captured) and captured["node"] is node:
            return True
    return False
226 |
|
---|
def is_probably_builtin(node):
    """
    Check that something isn't an attribute or function name etc.

    Returns False when node follows a dot, is a direct child of a
    funcdef or classdef, is the first child of an expr_stmt, or sits in
    an argument-name position of a parameter list; True otherwise.
    """
    prev = node.prev_sibling
    if prev is not None and prev.type == token.DOT:
        # Attribute lookup.
        return False

    parent = node.parent
    parent_type = parent.type
    if parent_type in (syms.funcdef, syms.classdef):
        return False
    if parent_type == syms.expr_stmt and parent.children[0] is node:
        # Assignment.
        return False

    # The name of an argument.
    if parent_type == syms.parameters:
        return False
    if parent_type == syms.typedargslist:
        after_comma = prev is not None and prev.type == token.COMMA
        if after_comma or parent.children[0] is node:
            return False
    return True
249 |
|
---|
def find_indentation(node):
    """Find the indentation of *node*.

    Walks from node up through its parents and returns the value of the
    INDENT leaf found at index 1 of the first suite that has more than
    two children and such a leaf.  Returns an empty string when no such
    suite is found (including when node is None).
    """
    current = node
    while current is not None:
        if current.type == syms.suite and len(current.children) > 2:
            candidate = current.children[1]
            if candidate.type == token.INDENT:
                return candidate.value
        current = current.parent
    return u""
259 |
|
---|
260 | ###########################################################
|
---|
261 | ### The following functions are to find bindings in a suite
|
---|
262 | ###########################################################
|
---|
263 |
|
---|
def make_suite(node):
    """Return node itself if it is a suite, else a new suite wrapping a clone.

    The clone's parent attribute is cleared before it is placed in the
    new suite, and the suite's parent is set to the value the clone's
    parent attribute held.
    """
    if node.type == syms.suite:
        return node
    copy = node.clone()
    # Remember the clone's parent, then detach it so Node() accepts it.
    parent = copy.parent
    copy.parent = None
    suite = Node(syms.suite, [copy])
    suite.parent = parent
    return suite
272 |
|
---|
def find_root(node):
    """Find the top level namespace.

    Climbs the parent links until a file_input node is reached.

    Raises:
        ValueError: if the chain of parents ends before a file_input
            node is found.
    """
    current = node
    while current.type != syms.file_input:
        current = current.parent
        if not current:
            raise ValueError("root found before file_input node was found.")
    return current
281 |
|
---|
def does_tree_import(package, name, node):
    """ Returns true if name is imported from package at the
        top level of the tree which node belongs to.
        To cover the case of an import like 'import foo', use
        None for the package and 'foo' for the name. """
    root = find_root(node)
    found = find_binding(name, root, package)
    return bool(found)
289 |
|
---|
def is_import(node):
    """Returns true if the node is an import_name or import_from node."""
    import_types = (syms.import_name, syms.import_from)
    return node.type in import_types
293 |
|
---|
def touch_import(package, name, node):
    """ Works like `does_tree_import` but adds an import statement
        if it was not imported.

        With package None the new statement is "import name"; otherwise
        it is "from package import name".  The statement is inserted
        into the root of the tree node belongs to. """
    def is_import_stmt(stmt):
        return (stmt.type == syms.simple_stmt and stmt.children and
                is_import(stmt.children[0]))

    root = find_root(node)

    if does_tree_import(package, name, root):
        return

    # Figure out where to insert the new import: locate the first import
    # statement, then scan forward over the run of imports that starts
    # there.  offset ends up as the position (relative to the first
    # import) of the first non-import statement, or of the last scanned
    # statement if the run reaches the end of root.children.
    insert_pos = offset = 0
    for first_idx, stmt in enumerate(root.children):
        if not is_import_stmt(stmt):
            continue
        for offset, later in enumerate(root.children[first_idx:]):
            if not is_import_stmt(later):
                break
        insert_pos = first_idx + offset
        break

    # If there are no imports where we can insert, find the docstring;
    # if that also fails, we stick to the beginning of the file.
    if insert_pos == 0:
        for doc_idx, stmt in enumerate(root.children):
            if (stmt.type == syms.simple_stmt and stmt.children and
                    stmt.children[0].type == token.STRING):
                insert_pos = doc_idx + 1
                break

    if package is None:
        import_kw = Leaf(token.NAME, u"import")
        name_leaf = Leaf(token.NAME, name, prefix=u" ")
        import_ = Node(syms.import_name, [import_kw, name_leaf])
    else:
        import_ = FromImport(package, [Leaf(token.NAME, name, prefix=u" ")])

    statement = Node(syms.simple_stmt, [import_, Newline()])
    root.insert_child(insert_pos, statement)
337 |
|
---|
338 |
|
---|
# Node types whose second child (index 1) carries the defined name.
_def_syms = set([syms.classdef, syms.funcdef])
def find_binding(name, node, package=None):
    """ Returns the node which binds variable name, otherwise None.
        If optional argument package is supplied, only imports will
        be returned.
        See test cases for examples.

        The children of node are examined in order; compound statements
        (for/if/while/try) are searched recursively through their
        suites, and simple_stmt children are searched recursively as
        well."""
    for child in node.children:
        ret = None
        if child.type == syms.for_stmt:
            # A name bound as the loop target returns the for statement
            # immediately, without the package filter applied below.
            if _find(name, child.children[1]):
                return child
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type in (syms.if_stmt, syms.while_stmt):
            # Only the last child of the statement is searched here.
            n = find_binding(name, make_suite(child.children[-1]), package)
            if n: ret = n
        elif child.type == syms.try_stmt:
            # children[2] is searched first; the remaining clauses are
            # only searched when it yields nothing.
            n = find_binding(name, make_suite(child.children[2]), package)
            if n:
                ret = n
            else:
                for i, kid in enumerate(child.children[3:]):
                    if kid.type == token.COLON and kid.value == ":":
                        # i+3 is the colon, i+4 is the suite
                        n = find_binding(name, make_suite(child.children[i+4]), package)
                        # No break: a hit in a later clause overwrites
                        # an earlier one.
                        if n: ret = n
        elif child.type in _def_syms and child.children[1].value == name:
            ret = child
        elif _is_import_binding(child, name, package):
            ret = child
        elif child.type == syms.simple_stmt:
            ret = find_binding(name, child, package)
        elif child.type == syms.expr_stmt:
            if _find(name, child.children[0]):
                ret = child

        if ret:
            # Without a package any binding is accepted; with one, only
            # import nodes are returned and other hits are skipped.
            if not package:
                return ret
            if is_import(ret):
                return ret
    return None
381 |
|
---|
# Node types that _find() does not descend into.
_block_syms = set([syms.funcdef, syms.classdef, syms.trailer])
def _find(name, node):
    """Search node and its descendants for a NAME leaf with value name.

    Nodes whose type is above 256 and not in _block_syms have their
    children searched; every other node is only compared directly.
    Returns the matching leaf, or None when there is none.
    """
    pending = [node]
    while pending:
        current = pending.pop()
        if current.type > 256 and current.type not in _block_syms:
            pending.extend(current.children)
        elif current.type == token.NAME and current.value == name:
            return current
    return None
392 |
|
---|
def _is_import_binding(node, name, package=None):
    """ Will return node if node will import name, or node
        will import * from package. None is returned otherwise.
        See test cases for examples. """

    # Plain "import ..." statements are only considered when no package
    # was requested.
    if node.type == syms.import_name and not package:
        imp = node.children[1]
        if imp.type == syms.dotted_as_names:
            for child in imp.children:
                if child.type == syms.dotted_as_name:
                    # The name is compared with the child at index 2.
                    if child.children[2].value == name:
                        return node
                elif child.type == token.NAME and child.value == name:
                    return node
        elif imp.type == syms.dotted_as_name:
            last = imp.children[-1]
            if last.type == token.NAME and last.value == name:
                return node
        elif imp.type == token.NAME and imp.value == name:
            return node
    elif node.type == syms.import_from:
        # unicode(...) is used to make life easier here, because
        # from a.b import parses to ['import', ['a', '.', 'b'], ...]
        if package and unicode(node.children[1]).strip() != package:
            return None
        n = node.children[3]
        if package and _find(u"as", n):
            # See test_from_import_as for explanation
            return None
        elif n.type == syms.import_as_names and _find(name, n):
            return node
        elif n.type == syms.import_as_name:
            # The name is compared with the child at index 2.
            child = n.children[2]
            if child.type == token.NAME and child.value == name:
                return node
        elif n.type == token.NAME and n.value == name:
            return node
        elif package and n.type == token.STAR:
            # A star import only counts when a package was requested.
            return node
    return None