1
1
from __future__ import annotations
2
2
3
3
import ast
4
+ import sys
5
+ from collections .abc import Sequence
4
6
from copy import copy , deepcopy
5
7
from dataclasses import dataclass
6
8
9
# True when running on Python 3.9 or older, where `match` statements (and the
# corresponding ast.Match nodes) do not exist.
# sys.version_info is a 5-field tuple, e.g. (3, 9, 7, "final", 0), which compares
# greater than (3, 9); the check therefore has to be "< (3, 10)".
PY_39 = sys.version_info < (3, 10)
10
+
7
11
# TODO: make walrus throw ValueError
8
- # TODO: match ... case
9
12
10
13
11
- def build_polars_when_then_otherwise (test : ast .expr , then : ast .expr , orelse : ast .expr ) -> ast .Call :
12
- when_node = ast .Call (
13
- func = ast .Attribute (value = ast .Name (id = "pl" , ctx = ast .Load ()), attr = "when" , ctx = ast .Load ()),
14
- args = [test ],
15
- keywords = [],
16
- )
14
@dataclass
class UnresolvedCase:
    """
    One branch of a conditional statement (if, match, etc.) before resolution.
    It pairs the branch's test expression with the state of the branch body;
    the value of that state has not been resolved yet.
    """

    test: ast.expr
    state: State

    def __init__(self, test: ast.expr, then: State):
        # The constructor argument is called `then`, the stored field `state`.
        self.test, self.state = test, then
28
+
29
+
30
@dataclass
class ResolvedCase:
    """
    One branch of a conditional statement (if, match, etc.) after resolution.
    It pairs the branch's test expression with the already resolved
    expression for the branch value.
    """

    test: ast.expr
    state: ast.expr

    def __init__(self, test: ast.expr, then: ast.expr):
        # The constructor argument is called `then`, the stored field `state`.
        self.test, self.state = test, then

    def __iter__(self):
        # Allows unpacking a case as `test, then = case`.
        return iter((self.test, self.state))
47
+
48
+
49
def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.Call:
    """
    Build the AST of a chained polars conditional expression.

    For cases (t1, v1), (t2, v2), ... the result is the AST of
    ``pl.when(t1).then(v1).when(t2).then(v2)....otherwise(orelse)``.

    Args:
        body: The cases in evaluation order. Each item must unpack into a
            ``(test, then)`` pair of AST expressions.
        orelse: The expression used when none of the tests apply.

    Returns:
        The ``ast.Call`` node of the final ``.otherwise(...)`` call.

    Raises:
        ValueError: If ``body`` is empty, since there is nothing to chain
            ``.otherwise`` onto.
    """
    # An AST node is always truthy, so `assert body or orelse` could never
    # fail and an empty body ended in an IndexError; check it explicitly.
    if not body:
        raise ValueError("No when-then cases provided.")

    # The first `.when` hangs off the `pl` module, every later one off the
    # preceding `.then(...)` call.
    chain: ast.expr = ast.Name(id="pl", ctx=ast.Load())

    for test, then in body:
        when_node = ast.Call(
            func=ast.Attribute(value=chain, attr="when", ctx=ast.Load()),
            args=[test],
            keywords=[],
        )
        chain = ast.Call(
            func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
            args=[then],
            keywords=[],
        )

    final_node = ast.Call(
        func=ast.Attribute(value=chain, attr="otherwise", ctx=ast.Load()),
        args=[orelse],
        keywords=[],
    )
    return final_node
@@ -63,7 +110,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Call:
63
110
test = self .visit (node .test )
64
111
body = self .visit (node .body )
65
112
orelse = self .visit (node .orelse )
66
- return build_polars_when_then_otherwise (test , body , orelse )
113
+ return build_polars_when_then_otherwise ([ ResolvedCase ( test , body )] , orelse )
67
114
68
115
def visit_Constant (self , node : ast .Constant ) -> ast .Constant :
69
116
return node
@@ -122,11 +169,11 @@ class ReturnState:
122
169
@dataclass
class ConditionalState:
    """
    A list of conditional cases together with a fallback state.
    Each case consists of a test expression and the state of its branch.
    """

    # The cases of the conditional, each holding a test and a branch state.
    body: Sequence[UnresolvedCase]
    # The state of the fallback (else / catch-all) branch.
    orelse: State
131
178
132
179
@@ -139,25 +186,106 @@ class State:
139
186
140
187
node : UnresolvedState | ReturnState | ConditionalState
141
188
189
+ def translate_match (
190
+ self ,
191
+ subj : ast .expr | Sequence [ast .expr ] | ast .Tuple ,
192
+ pattern : ast .pattern ,
193
+ guard : ast .expr | None = None ,
194
+ ):
195
+ """
196
+ Translate a match_case statement into a regular AST expression.
197
+ translate_match takes a subject, a pattern and a guard.
198
+ patterns can be a MatchValue, MatchAs, MatchOr, or MatchSequence.
199
+ subjects can be a single expression (e.g x or (2 * x + 1)) or a list of expressions.
200
+ translate_match is called per each case in a match statement.
201
+ """
202
+
203
+ if isinstance (pattern , ast .MatchValue ):
204
+ equality_ast = ast .Compare (
205
+ left = subj ,
206
+ ops = [ast .Eq ()],
207
+ comparators = [pattern .value ],
208
+ )
209
+
210
+ if guard is not None :
211
+ return ast .BinOp (
212
+ left = guard ,
213
+ op = ast .BitAnd (),
214
+ right = equality_ast ,
215
+ )
216
+
217
+ return equality_ast
218
+ elif isinstance (pattern , ast .MatchAs ):
219
+ if pattern .name is not None :
220
+ self .handle_assign (
221
+ ast .Assign (
222
+ targets = [ast .Name (id = pattern .name , ctx = ast .Store ())],
223
+ value = subj ,
224
+ )
225
+ )
226
+ return guard
227
+ elif isinstance (pattern , ast .MatchOr ):
228
+ return ast .BinOp (
229
+ left = self .translate_match (subj , pattern .patterns [0 ], guard ),
230
+ op = ast .BitOr (),
231
+ right = (
232
+ self .translate_match (subj , ast .MatchOr (patterns = pattern .patterns [1 :]))
233
+ if pattern .patterns [2 :]
234
+ else self .translate_match (subj , pattern .patterns [1 ])
235
+ ),
236
+ )
237
+ elif isinstance (pattern , ast .MatchSequence ):
238
+ if isinstance (pattern .patterns [- 1 ], ast .MatchStar ):
239
+ raise ValueError ("starred patterns are not supported." )
240
+
241
+ if isinstance (subj , ast .Tuple ):
242
+ # TODO: Use polars list operations in the future
243
+ left = self .translate_match (subj .elts [0 ], pattern .patterns [0 ], guard )
244
+ right = (
245
+ self .translate_match (
246
+ ast .Tuple (elts = subj .elts [1 :]),
247
+ ast .MatchSequence (patterns = pattern .patterns [1 :]),
248
+ )
249
+ if pattern .patterns [2 :]
250
+ else self .translate_match (subj .elts [1 ], pattern .patterns [1 ])
251
+ )
252
+
253
+ return (
254
+ left or right
255
+ if left is None or right is None
256
+ else ast .BinOp (left = left , op = ast .BitAnd (), right = right )
257
+ )
258
+ raise ValueError ("Matching lists is not supported." )
259
+ else :
260
+ raise ValueError (
261
+ f"Incompatible match and subject types: { type (pattern )} and { type (subj )} ."
262
+ )
263
+
142
264
def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
    """
    Record an assignment statement in this state.

    An annotated assignment is first rewritten as a plain ``ast.Assign``.
    An unresolved state records the assignment itself; a conditional state
    forwards it to the state of every case and to the fallback state.
    Any other kind of node is left untouched.
    """
    assignment = (
        ast.Assign(targets=[expr.target], value=expr.value)
        if isinstance(expr, ast.AnnAssign)
        else expr
    )

    current = self.node
    if isinstance(current, UnresolvedState):
        current.handle_assign(assignment)
        return

    if isinstance(current, ConditionalState):
        for branch in current.body:
            branch.state.handle_assign(assignment)
        current.orelse.handle_assign(assignment)
151
274
152
275
def handle_if(self, stmt: ast.If):
    """
    Record an ``if`` statement in this state.

    An unresolved state becomes a conditional state with a single case: the
    test has the known assignments inlined, and the ``if`` body and the
    ``else`` body are each parsed with their own copy of those assignments.
    A conditional state forwards the statement to the state of every case
    and to the fallback state. Any other kind of node is left untouched.
    """
    current = self.node

    if isinstance(current, UnresolvedState):
        condition = InlineTransformer.inline_expr(stmt.test, current.assignments)
        then_state = parse_body(stmt.body, copy(current.assignments))
        else_state = parse_body(stmt.orelse, copy(current.assignments))
        self.node = ConditionalState(
            body=[UnresolvedCase(condition, then_state)],
            orelse=else_state,
        )
        return

    if isinstance(current, ConditionalState):
        for branch in current.body:
            branch.state.handle_if(stmt)
        current.orelse.handle_if(stmt)
162
290
163
291
def handle_return (self , value : ast .expr ):
@@ -166,9 +294,58 @@ def handle_return(self, value: ast.expr):
166
294
expr = InlineTransformer .inline_expr (value , self .node .assignments )
167
295
)
168
296
elif isinstance (self .node , ConditionalState ):
169
- self .node .then .handle_return (value )
297
+ for case in self .node .body :
298
+ case .state .handle_return (value )
170
299
self .node .orelse .handle_return (value )
171
300
301
def handle_match(self, stmt: ast.Match):
    """
    Record a ``match`` statement in this state.

    An unresolved state becomes a conditional state with one case per
    ``case`` clause. A guard-less wildcard clause (``case _:``) supplies the
    fallback body instead of a case, and clauses that are skipped by
    ``ignore_case`` below are dropped. A conditional state forwards the
    statement to the state of every case and to the fallback state.
    """

    def is_catch_all(case: ast.match_case) -> bool:
        # Only a wildcard without a guard counts as catch-all; with a guard
        # it is treated as a regular case.
        pattern = case.pattern
        return isinstance(pattern, ast.MatchAs) and pattern.name is None and case.guard is None

    def ignore_case(case: ast.match_case) -> bool:
        # if the length of the pattern is not equal to the length of the subject, python ignores the case
        subject_is_tuple = isinstance(stmt.subject, ast.Tuple)
        length_mismatch = (
            isinstance(case.pattern, ast.MatchSequence)
            and subject_is_tuple
            and len(stmt.subject.elts) != len(case.pattern.patterns)
        )
        return length_mismatch or (isinstance(case.pattern, ast.MatchValue) and subject_is_tuple)

    current = self.node

    if isinstance(current, UnresolvedState):
        # We can always rewrite catch-all patterns to orelse since python throws a
        # SyntaxError if the catch-all pattern is not the last case.
        fallback_body = next(
            (case.body for case in stmt.cases if is_catch_all(case)),
            [],
        )

        cases = []
        for case in stmt.cases:
            if is_catch_all(case) or ignore_case(case):
                continue
            # translate_match transforms the match statement case into regular AST
            # expressions so that the InlineTransformer can handle assignments correctly.
            # Note that by the time parse_body is called this has mutated the assignments.
            condition = self.translate_match(stmt.subject, case.pattern, case.guard)
            test = InlineTransformer.inline_expr(condition, current.assignments)
            cases.append(
                UnresolvedCase(test, parse_body(case.body, copy(current.assignments)))
            )

        # The fallback is parsed last, after every case has been translated.
        fallback_state = parse_body(fallback_body, copy(current.assignments))
        self.node = ConditionalState(body=cases, orelse=fallback_state)
        return

    if isinstance(current, ConditionalState):
        for branch in current.body:
            branch.state.handle_match(stmt)
        current.orelse.handle_match(stmt)
348
+
172
349
173
350
def parse_body (full_body : list [ast .stmt ], assignments : dict [str , ast .expr ] | None = None ) -> State :
174
351
if assignments is None :
@@ -182,9 +359,11 @@ def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | Non
182
359
elif isinstance (stmt , ast .Return ):
183
360
if stmt .value is None :
184
361
raise ValueError ("return needs a value" )
185
-
186
362
state .handle_return (stmt .value )
187
363
break
364
+ elif isinstance (stmt , ast .Match ):
365
+ assert not PY_39
366
+ state .handle_match (stmt )
188
367
else :
189
368
raise ValueError (f"Unsupported statement type: { type (stmt )} " )
190
369
return state
@@ -194,9 +373,15 @@ def transform_tree_into_expr(node: State) -> ast.expr:
194
373
if isinstance (node .node , ReturnState ):
195
374
return node .node .expr
196
375
elif isinstance (node .node , ConditionalState ):
376
+ if not node .node .body :
377
+ # this happens if none of the cases will ever match or exist
378
+ # in these cases we just need to return the orelse body
379
+ return transform_tree_into_expr (node .node .orelse )
197
380
return build_polars_when_then_otherwise (
198
- node .node .test ,
199
- transform_tree_into_expr (node .node .then ),
381
+ [
382
+ ResolvedCase (case .test , transform_tree_into_expr (case .state ))
383
+ for case in node .node .body
384
+ ],
200
385
transform_tree_into_expr (node .node .orelse ),
201
386
)
202
387
else :
0 commit comments