Skip to content

Commit 9aa94fa

Browse files
prsabahramiBela Stoyanpavelzw
authored
Add support for match ... case (#60)
Co-authored-by: Bela Stoyan <[email protected]> Co-authored-by: Pavel Zwerschke <[email protected]>
1 parent 97fb1da commit 9aa94fa

File tree

8 files changed

+2241
-1348
lines changed

8 files changed

+2241
-1348
lines changed

.github/workflows/build.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ jobs:
3535

3636
release:
3737
name: Publish package
38-
if: github.event_name == 'push' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true'
38+
if: github.event_name == 'push' && github.repository == 'Quantco/polarify' && github.ref_name == 'main' && needs.build.outputs.version-changed == 'true'
3939
needs: [build]
4040
runs-on: ubuntu-latest
4141
permissions:

README.md

+4-1
Original file line numberDiff line numberDiff line change
@@ -198,14 +198,17 @@ polarIFy is still in an early stage of development and doesn't support the full
198198
- assignments (like `x = 1`)
199199
- polars expressions (like `pl.col("x")`, TODO)
200200
- side-effect free functions that return a polars expression (can be generated by `@polarify`) (TODO)
201+
- `match` statements
201202

202203
### Unsupported operations
203204

204205
- `for` loops
205206
- `while` loops
206207
- `break` statements
207208
- `:=` walrus operator
208-
- `match ... case` statements (TODO)
209+
- dictionary mappings in `match` statements
210+
- list matching in `match` statements
211+
- star patterns in `match statements
209212
- functions with side-effects (`print`, `pl.write_csv`, ...)
210213

211214
## 🚀 Benchmarks

pixi.lock

+1,668-1,313
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pixi.toml

+7-7
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,13 @@ lint = "pre-commit run --all"
5959

6060
[environments]
6161
default = ["test"]
62-
pl014 = ["pl014", "py39", "test"]
63-
pl015 = ["pl015", "py39", "test"]
64-
pl016 = ["pl016", "py39", "test"]
65-
pl017 = ["pl017", "py39", "test"]
66-
pl018 = ["pl018", "py39", "test"]
67-
pl019 = ["pl019", "py39", "test"]
68-
pl020 = ["pl020", "py39", "test"]
62+
pl014 = ["pl014", "py310", "test"]
63+
pl015 = ["pl015", "py310", "test"]
64+
pl016 = ["pl016", "py310", "test"]
65+
pl017 = ["pl017", "py310", "test"]
66+
pl018 = ["pl018", "py310", "test"]
67+
pl019 = ["pl019", "py310", "test"]
68+
pl020 = ["pl020", "py310", "test"]
6969
py39 = ["py39", "test"]
7070
py310 = ["py310", "test"]
7171
py311 = ["py311", "test"]

polarify/main.py

+210-25
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,74 @@
11
from __future__ import annotations
22

33
import ast
4+
import sys
5+
from collections.abc import Sequence
46
from copy import copy, deepcopy
57
from dataclasses import dataclass
68

9+
PY_39 = sys.version_info <= (3, 9)
10+
711
# TODO: make walrus throw ValueError
8-
# TODO: match ... case
912

1013

11-
def build_polars_when_then_otherwise(test: ast.expr, then: ast.expr, orelse: ast.expr) -> ast.Call:
12-
when_node = ast.Call(
13-
func=ast.Attribute(value=ast.Name(id="pl", ctx=ast.Load()), attr="when", ctx=ast.Load()),
14-
args=[test],
15-
keywords=[],
16-
)
14+
@dataclass
15+
class UnresolvedCase:
16+
"""
17+
An unresolved case in a conditional statement. (if, match, etc.)
18+
Each case consists of a test expression and a state.
19+
The value of the state is not yet resolved.
20+
"""
1721

18-
then_node = ast.Call(
19-
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
20-
args=[then],
21-
keywords=[],
22-
)
22+
test: ast.expr
23+
state: State
24+
25+
def __init__(self, test: ast.expr, then: State):
26+
self.test = test
27+
self.state = then
28+
29+
30+
@dataclass
31+
class ResolvedCase:
32+
"""
33+
A resolved case in a conditional statement. (if, match, etc.)
34+
Each case consists of a test expression and a state.
35+
The value of the state is resolved.
36+
"""
37+
38+
test: ast.expr
39+
state: ast.expr
40+
41+
def __init__(self, test: ast.expr, then: ast.expr):
42+
self.test = test
43+
self.state = then
44+
45+
def __iter__(self):
46+
return iter([self.test, self.state])
47+
48+
49+
def build_polars_when_then_otherwise(body: Sequence[ResolvedCase], orelse: ast.expr) -> ast.Call:
50+
nodes: list[ast.Call] = []
51+
52+
assert body or orelse, "No when-then cases provided."
53+
54+
for test, then in body:
55+
when_node = ast.Call(
56+
func=ast.Attribute(
57+
value=nodes[-1] if nodes else ast.Name(id="pl", ctx=ast.Load()),
58+
attr="when",
59+
ctx=ast.Load(),
60+
),
61+
args=[test],
62+
keywords=[],
63+
)
64+
then_node = ast.Call(
65+
func=ast.Attribute(value=when_node, attr="then", ctx=ast.Load()),
66+
args=[then],
67+
keywords=[],
68+
)
69+
nodes.append(then_node)
2370
final_node = ast.Call(
24-
func=ast.Attribute(value=then_node, attr="otherwise", ctx=ast.Load()),
71+
func=ast.Attribute(value=nodes[-1], attr="otherwise", ctx=ast.Load()),
2572
args=[orelse],
2673
keywords=[],
2774
)
@@ -63,7 +110,7 @@ def visit_IfExp(self, node: ast.IfExp) -> ast.Call:
63110
test = self.visit(node.test)
64111
body = self.visit(node.body)
65112
orelse = self.visit(node.orelse)
66-
return build_polars_when_then_otherwise(test, body, orelse)
113+
return build_polars_when_then_otherwise([ResolvedCase(test, body)], orelse)
67114

68115
def visit_Constant(self, node: ast.Constant) -> ast.Constant:
69116
return node
@@ -122,11 +169,11 @@ class ReturnState:
122169
@dataclass
123170
class ConditionalState:
124171
"""
125-
A conditional state, with a test expression and two branches.
172+
A list of conditional states.
173+
Each case consists of a test expression and a state.
126174
"""
127175

128-
test: ast.expr
129-
then: State
176+
body: Sequence[UnresolvedCase]
130177
orelse: State
131178

132179

@@ -139,25 +186,106 @@ class State:
139186

140187
node: UnresolvedState | ReturnState | ConditionalState
141188

189+
def translate_match(
190+
self,
191+
subj: ast.expr | Sequence[ast.expr] | ast.Tuple,
192+
pattern: ast.pattern,
193+
guard: ast.expr | None = None,
194+
):
195+
"""
196+
Translate a match_case statement into a regular AST expression.
197+
translate_match takes a subject, a pattern and a guard.
198+
patterns can be a MatchValue, MatchAs, MatchOr, or MatchSequence.
199+
subjects can be a single expression (e.g x or (2 * x + 1)) or a list of expressions.
200+
translate_match is called per each case in a match statement.
201+
"""
202+
203+
if isinstance(pattern, ast.MatchValue):
204+
equality_ast = ast.Compare(
205+
left=subj,
206+
ops=[ast.Eq()],
207+
comparators=[pattern.value],
208+
)
209+
210+
if guard is not None:
211+
return ast.BinOp(
212+
left=guard,
213+
op=ast.BitAnd(),
214+
right=equality_ast,
215+
)
216+
217+
return equality_ast
218+
elif isinstance(pattern, ast.MatchAs):
219+
if pattern.name is not None:
220+
self.handle_assign(
221+
ast.Assign(
222+
targets=[ast.Name(id=pattern.name, ctx=ast.Store())],
223+
value=subj,
224+
)
225+
)
226+
return guard
227+
elif isinstance(pattern, ast.MatchOr):
228+
return ast.BinOp(
229+
left=self.translate_match(subj, pattern.patterns[0], guard),
230+
op=ast.BitOr(),
231+
right=(
232+
self.translate_match(subj, ast.MatchOr(patterns=pattern.patterns[1:]))
233+
if pattern.patterns[2:]
234+
else self.translate_match(subj, pattern.patterns[1])
235+
),
236+
)
237+
elif isinstance(pattern, ast.MatchSequence):
238+
if isinstance(pattern.patterns[-1], ast.MatchStar):
239+
raise ValueError("starred patterns are not supported.")
240+
241+
if isinstance(subj, ast.Tuple):
242+
# TODO: Use polars list operations in the future
243+
left = self.translate_match(subj.elts[0], pattern.patterns[0], guard)
244+
right = (
245+
self.translate_match(
246+
ast.Tuple(elts=subj.elts[1:]),
247+
ast.MatchSequence(patterns=pattern.patterns[1:]),
248+
)
249+
if pattern.patterns[2:]
250+
else self.translate_match(subj.elts[1], pattern.patterns[1])
251+
)
252+
253+
return (
254+
left or right
255+
if left is None or right is None
256+
else ast.BinOp(left=left, op=ast.BitAnd(), right=right)
257+
)
258+
raise ValueError("Matching lists is not supported.")
259+
else:
260+
raise ValueError(
261+
f"Incompatible match and subject types: {type(pattern)} and {type(subj)}."
262+
)
263+
142264
def handle_assign(self, expr: ast.Assign | ast.AnnAssign):
143265
if isinstance(expr, ast.AnnAssign):
144266
expr = ast.Assign(targets=[expr.target], value=expr.value)
145267

146268
if isinstance(self.node, UnresolvedState):
147269
self.node.handle_assign(expr)
148270
elif isinstance(self.node, ConditionalState):
149-
self.node.then.handle_assign(expr)
271+
for case in self.node.body:
272+
case.state.handle_assign(expr)
150273
self.node.orelse.handle_assign(expr)
151274

152275
def handle_if(self, stmt: ast.If):
153276
if isinstance(self.node, UnresolvedState):
154277
self.node = ConditionalState(
155-
test=InlineTransformer.inline_expr(stmt.test, self.node.assignments),
156-
then=parse_body(stmt.body, copy(self.node.assignments)),
278+
body=[
279+
UnresolvedCase(
280+
InlineTransformer.inline_expr(stmt.test, self.node.assignments),
281+
parse_body(stmt.body, copy(self.node.assignments)),
282+
)
283+
],
157284
orelse=parse_body(stmt.orelse, copy(self.node.assignments)),
158285
)
159286
elif isinstance(self.node, ConditionalState):
160-
self.node.then.handle_if(stmt)
287+
for case in self.node.body:
288+
case.state.handle_if(stmt)
161289
self.node.orelse.handle_if(stmt)
162290

163291
def handle_return(self, value: ast.expr):
@@ -166,9 +294,58 @@ def handle_return(self, value: ast.expr):
166294
expr=InlineTransformer.inline_expr(value, self.node.assignments)
167295
)
168296
elif isinstance(self.node, ConditionalState):
169-
self.node.then.handle_return(value)
297+
for case in self.node.body:
298+
case.state.handle_return(value)
170299
self.node.orelse.handle_return(value)
171300

301+
def handle_match(self, stmt: ast.Match):
302+
def is_catch_all(case: ast.match_case) -> bool:
303+
# We check if the case is a catch-all pattern without a guard
304+
# If it has a guard, we treat it as a regular case
305+
return (
306+
isinstance(case.pattern, ast.MatchAs)
307+
and case.pattern.name is None
308+
and case.guard is None
309+
)
310+
311+
def ignore_case(case: ast.match_case) -> bool:
312+
# if the length of the pattern is not equal to the length of the subject, python ignores the case
313+
return (
314+
isinstance(case.pattern, ast.MatchSequence)
315+
and isinstance(stmt.subject, ast.Tuple)
316+
and len(stmt.subject.elts) != len(case.pattern.patterns)
317+
) or (isinstance(case.pattern, ast.MatchValue) and isinstance(stmt.subject, ast.Tuple))
318+
319+
if isinstance(self.node, UnresolvedState):
320+
# We can always rewrite catch-all patterns to orelse since python throws a SyntaxError if the catch-all pattern is not the last case.
321+
orelse = next(
322+
iter([case.body for case in stmt.cases if is_catch_all(case)]),
323+
[],
324+
)
325+
self.node = ConditionalState(
326+
body=[
327+
UnresolvedCase(
328+
# translate_match transforms the match statement case into regular AST expressions so that the InlineTransformer can handle assignments correctly
329+
# Note that by the time parse_body is called this has mutated the assignments
330+
InlineTransformer.inline_expr(
331+
self.translate_match(stmt.subject, case.pattern, case.guard),
332+
self.node.assignments,
333+
),
334+
parse_body(case.body, copy(self.node.assignments)),
335+
)
336+
for case in stmt.cases
337+
if not is_catch_all(case) and not ignore_case(case)
338+
],
339+
orelse=parse_body(
340+
orelse,
341+
copy(self.node.assignments),
342+
),
343+
)
344+
elif isinstance(self.node, ConditionalState):
345+
for case in self.node.body:
346+
case.state.handle_match(stmt)
347+
self.node.orelse.handle_match(stmt)
348+
172349

173350
def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | None = None) -> State:
174351
if assignments is None:
@@ -182,9 +359,11 @@ def parse_body(full_body: list[ast.stmt], assignments: dict[str, ast.expr] | Non
182359
elif isinstance(stmt, ast.Return):
183360
if stmt.value is None:
184361
raise ValueError("return needs a value")
185-
186362
state.handle_return(stmt.value)
187363
break
364+
elif isinstance(stmt, ast.Match):
365+
assert not PY_39
366+
state.handle_match(stmt)
188367
else:
189368
raise ValueError(f"Unsupported statement type: {type(stmt)}")
190369
return state
@@ -194,9 +373,15 @@ def transform_tree_into_expr(node: State) -> ast.expr:
194373
if isinstance(node.node, ReturnState):
195374
return node.node.expr
196375
elif isinstance(node.node, ConditionalState):
376+
if not node.node.body:
377+
# this happens if none of the cases will ever match or exist
378+
# in these cases we just need to return the orelse body
379+
return transform_tree_into_expr(node.node.orelse)
197380
return build_polars_when_then_otherwise(
198-
node.node.test,
199-
transform_tree_into_expr(node.node.then),
381+
[
382+
ResolvedCase(case.test, transform_tree_into_expr(case.state))
383+
for case in node.node.body
384+
],
200385
transform_tree_into_expr(node.node.orelse),
201386
)
202387
else:

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ build-backend = "hatchling.build"
55
[project]
66
name = "polarify"
77
description = "Simplifying conditional Polars Expressions with Python 🐍 🐻‍❄️"
8-
version = "0.1.5"
8+
version = "0.2.0"
99
readme = "README.md"
1010
license = {file = "LICENSE"}
1111
requires-python = ">=3.9"

0 commit comments

Comments
 (0)