1
1
import ast
2
2
3
3
4
- class AsyncTransformer (ast . NodeTransformer ):
4
+ class AsyncTransformer ():
5
5
"""Converts all async nodes into their synchronous counterparts."""
6
6
7
7
def visit_Await (self , node ):
@@ -16,3 +16,53 @@ def visit_AsyncFor(self, node):
16
16
17
17
def visit_AsyncWith (self , node ):
18
18
return self .visit (ast .With (** node .__dict__ ))
19
+
20
+
21
+ class ChainedFunctionTransformer ():
22
+ def visit_chain (self , node , depth = 1 ):
23
+ if (
24
+ isinstance (node .value , ast .Call ) and
25
+ isinstance (node .value .func , ast .Attribute ) and
26
+ isinstance (node .value .func .value , ast .Call )
27
+ ):
28
+ # Node is assignment or return with value like `b.c().d()`
29
+ call_node = node .value
30
+ # If we want to handle nested functions in future, depth needs fixing
31
+ temp_var_id = '__chain_tmp_{}' .format (depth )
32
+ # AST tree is from right to left, so d() is the outer Call and b.c() is the inner Call
33
+ unvisited_inner_call = ast .Assign (
34
+ targets = [ast .Name (id = temp_var_id , ctx = ast .Store ())],
35
+ value = call_node .func .value ,
36
+ )
37
+ ast .copy_location (unvisited_inner_call , node )
38
+ inner_calls = self .visit_chain (unvisited_inner_call , depth + 1 )
39
+ for inner_call_node in inner_calls :
40
+ ast .copy_location (inner_call_node , node )
41
+ outer_call = self .generic_visit (type (node )(
42
+ value = ast .Call (
43
+ func = ast .Attribute (
44
+ value = ast .Name (id = temp_var_id , ctx = ast .Load ()),
45
+ attr = call_node .func .attr ,
46
+ ctx = ast .Load (),
47
+ ),
48
+ args = call_node .args ,
49
+ keywords = call_node .keywords ,
50
+ ),
51
+ ** {field : value for field , value in ast .iter_fields (node ) if field != 'value' } # e.g. targets
52
+ ))
53
+ ast .copy_location (outer_call , node )
54
+ ast .copy_location (outer_call .value , node )
55
+ ast .copy_location (outer_call .value .func , node )
56
+ return [* inner_calls , outer_call ]
57
+ else :
58
+ return [self .generic_visit (node )]
59
+
60
+ def visit_Assign (self , node ):
61
+ return self .visit_chain (node )
62
+
63
+ def visit_Return (self , node ):
64
+ return self .visit_chain (node )
65
+
66
+
67
+ class PytTransformer (AsyncTransformer , ChainedFunctionTransformer , ast .NodeTransformer ):
68
+ pass
0 commit comments