1 | """
|
---|
2 | Convert use of sys.exitfunc to use the atexit module.
|
---|
3 | """
|
---|
4 |
|
---|
5 | # Author: Benjamin Peterson
|
---|
6 |
|
---|
7 | from lib2to3 import pytree, fixer_base
|
---|
8 | from lib2to3.fixer_util import Name, Attr, Call, Comma, Newline, syms
|
---|
9 |
|
---|
10 |
|
---|
11 | class FixExitfunc(fixer_base.BaseFix):
|
---|
12 | keep_line_order = True
|
---|
13 | BM_compatible = True
|
---|
14 |
|
---|
15 | PATTERN = """
|
---|
16 | (
|
---|
17 | sys_import=import_name<'import'
|
---|
18 | ('sys'
|
---|
19 | |
|
---|
20 | dotted_as_names< (any ',')* 'sys' (',' any)* >
|
---|
21 | )
|
---|
22 | >
|
---|
23 | |
|
---|
24 | expr_stmt<
|
---|
25 | power< 'sys' trailer< '.' 'exitfunc' > >
|
---|
26 | '=' func=any >
|
---|
27 | )
|
---|
28 | """
|
---|
29 |
|
---|
30 | def __init__(self, *args):
|
---|
31 | super(FixExitfunc, self).__init__(*args)
|
---|
32 |
|
---|
33 | def start_tree(self, tree, filename):
|
---|
34 | super(FixExitfunc, self).start_tree(tree, filename)
|
---|
35 | self.sys_import = None
|
---|
36 |
|
---|
37 | def transform(self, node, results):
|
---|
38 | # First, find a the sys import. We'll just hope it's global scope.
|
---|
39 | if "sys_import" in results:
|
---|
40 | if self.sys_import is None:
|
---|
41 | self.sys_import = results["sys_import"]
|
---|
42 | return
|
---|
43 |
|
---|
44 | func = results["func"].clone()
|
---|
45 | func.prefix = u""
|
---|
46 | register = pytree.Node(syms.power,
|
---|
47 | Attr(Name(u"atexit"), Name(u"register"))
|
---|
48 | )
|
---|
49 | call = Call(register, [func], node.prefix)
|
---|
50 | node.replace(call)
|
---|
51 |
|
---|
52 | if self.sys_import is None:
|
---|
53 | # That's interesting.
|
---|
54 | self.warning(node, "Can't find sys import; Please add an atexit "
|
---|
55 | "import at the top of your file.")
|
---|
56 | return
|
---|
57 |
|
---|
58 | # Now add an atexit import after the sys import.
|
---|
59 | names = self.sys_import.children[1]
|
---|
60 | if names.type == syms.dotted_as_names:
|
---|
61 | names.append_child(Comma())
|
---|
62 | names.append_child(Name(u"atexit", u" "))
|
---|
63 | else:
|
---|
64 | containing_stmt = self.sys_import.parent
|
---|
65 | position = containing_stmt.children.index(self.sys_import)
|
---|
66 | stmt_container = containing_stmt.parent
|
---|
67 | new_import = pytree.Node(syms.import_name,
|
---|
68 | [Name(u"import"), Name(u"atexit", u" ")]
|
---|
69 | )
|
---|
70 | new = pytree.Node(syms.simple_stmt, [new_import])
|
---|
71 | containing_stmt.insert_child(position + 1, Newline())
|
---|
72 | containing_stmt.insert_child(position + 2, new)
|
---|