Skip to content

Commit cfe2c11

Browse files
committed
Conversion from pyarrow Expressions to QueryBuilder expressions
Introduces `ExpressionNode.from_pyarrow_expression_str`. This is required for predicate pushdown integration with polars. E.g. when we do:

```
lf = polars.scan_arcticdb(arctic_identifier)
lf = lf.filter(pl.col("float_col") <= 40.1)
lf.collect()
```

When calling `collect`, the filter will get pushed down to our (to be implemented) `polars.scan_arcticdb` via a callback. The filter is given as a string which can be evaluated to a pyarrow expression. For reference, string construction happens [here](https://github.com/pola-rs/polars/blob/4e286d8c83b0fcb56f0a7ea06d2eb731f179e01e/crates/polars-plan/src/plans/python/pyarrow.rs#L25).

I've decided to keep this conversion code in core arcticdb instead of polars for a few reasons:
- Gives us more flexibility if we decide to update our `QueryBuilder` expressions
- We'll have less code to maintain in a repository which is not ours
- Automated tests allow us to not accidentally break polars predicate pushdown to arcticdb

For additional reference, see how pyiceberg handles predicate pushdown from polars [here](https://github.com/pola-rs/polars/blob/4e286d8c83b0fcb56f0a7ea06d2eb731f179e01e/py-polars/polars/io/iceberg.py#L197-L200), and how this might fit in the prototype `polars.scan_arcticdb` implementation [here](https://github.com/IvoDD/polars/pull/1/files).
1 parent bdc5d24 commit cfe2c11

File tree

2 files changed

+292
-1
lines changed

2 files changed

+292
-1
lines changed

python/arcticdb/version_store/processing.py

+156-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
"""
88
from collections import namedtuple
99
import copy
10+
from collections.abc import Callable
1011
from dataclasses import dataclass
1112
import datetime
1213
from math import inf
@@ -15,7 +16,12 @@
1516
import pandas as pd
1617
from pandas.tseries.frequencies import to_offset
1718

18-
from typing import Dict, NamedTuple, Optional, Tuple, Union
19+
from typing import Dict, NamedTuple, Optional, Tuple, Union, Any
20+
21+
import sys
22+
import ast
23+
24+
from functools import singledispatch
1925

2026
from arcticdb.exceptions import ArcticDbNotYetImplemented, ArcticNativeException, UserInputException
2127
from arcticdb.version_store._normalization import normalize_dt_range_to_ts
@@ -250,6 +256,155 @@ def get_name(self):
250256
return self.name
251257

252258

259+
@classmethod
def from_pyarrow_expression_str(cls, expression_str: str, function_map: Optional[Dict[str, Callable]] = None) -> "ExpressionNode":
    """
    Builds an ExpressionNode from a pyarrow expression string.

    It is required for an integration with polars predicate pushdown. We get the pyarrow expression as a string
    because pyarrow doesn't provide any APIs for traversing the expression tree.

    Any of pyarrow's `is_null`, `is_nan` and `is_valid` will get converted to our ArcticDB's `isnull` and `notnull`,
    which don't differentiate nulls and nans.

    Parameters
    ----------
    expression_str : str
        Python source of a single pyarrow expression, e.g. ``"pc.field('int_col') < 50"``. It is parsed with
        `ast.parse` in ``"eval"`` mode and never executed.
    function_map : Optional[Dict[str, Callable]], default=None
        Callables that the expression is allowed to call, keyed by the (possibly dotted) name used in the
        expression string, e.g. ``{"datetime.datetime": datetime.datetime}``. ``None`` means no extra callables.

    Returns
    -------
    ExpressionNode
        The ArcticDB expression equivalent to `expression_str`.

    Raises
    ------
    ValueError
        If the string cannot be parsed or contains a construct that cannot be converted. The original exception
        is attached as ``__cause__``.
    """
    if function_map is None:
        function_map = {}
    try:
        expression_ast = ast.parse(expression_str, mode="eval").body
        return _ast_to_expression(expression_ast, function_map)
    except Exception as e:
        # Every failure (syntax error, unsupported node, failing callable) is reported as a ValueError.
        # Chain explicitly so the traceback shows the root cause as the cause, not as an error in this handler.
        msg = f"Could not parse pyarrow expression as an arcticdb expression: {e}"
        raise ValueError(msg) from e
278+
279+
280+
@singledispatch
def _ast_to_expression(a: Any, function_map) -> Any:
    """
    Fallback of the AST walker that converts a PyArrow expression to an ArcticDB expression.

    Handlers for the supported node types are registered on this dispatcher; any node type without a
    registered handler ends up here and is rejected.
    """
    message = f"Unexpected symbol: {a}"
    raise ValueError(message)
284+
285+
286+
@_ast_to_expression.register(ast.Constant)
def _(a: ast.Constant, function_map) -> Any:
    """Converts a literal node to the plain Python value it holds."""
    literal = a.value
    return literal
289+
290+
291+
if sys.version_info < (3, 8):
    # These handlers are only registered on interpreters older than 3.8, where string and numeric
    # literals are represented by dedicated node types.

    @_ast_to_expression.register(ast.Str)
    def _(a: ast.Str, function_map) -> Any:
        """Converts a string literal node to its `str` value."""
        text = a.s
        return text

    @_ast_to_expression.register(ast.Num)
    def _(a: ast.Num, function_map) -> Any:
        """Converts a numeric literal node to its numeric value."""
        number = a.n
        return number
299+
300+
@_ast_to_expression.register(ast.Name)
def _(a: ast.Name, function_map) -> Any:
    """
    Converts a bare identifier to its name as a string.

    The name is not looked up here; the attribute and call handlers decide what a name such as `pc` or
    `abs` stands for.
    """
    identifier = a.id
    return identifier
303+
304+
305+
@_ast_to_expression.register(ast.UnaryOp)
def _(a: ast.UnaryOp, function_map) -> Any:
    """
    Converts a unary operation by applying the same operator to the converted operand.

    Only `~` and unary `-` are supported; any other unary operator raises ValueError.
    """
    # The operand is converted first, so an invalid operand is reported before an unsupported operator.
    converted = _ast_to_expression(a.operand, function_map)
    unary_op = a.op
    if isinstance(unary_op, ast.USub):
        # pyarrow expressions don't support unary subtract, so this branch will not be reached.
        # Leaving as future-proofing in case they ever introduce it.
        return -converted
    if isinstance(unary_op, ast.Invert):
        return ~converted
    raise ValueError(f"Unexpected UnaryOp: {unary_op}")
315+
316+
317+
@_ast_to_expression.register(ast.Call)
def _(a: ast.Call, function_map) -> Any:
    """
    Converts a function or method call by invoking the corresponding Python callable.

    The call target is converted first. If it converts to a callable (for example a bound method returned by
    the attribute handler) it is invoked directly. If it converts to a string, that string is looked up in
    `function_map` and the mapped callable is invoked. Positional and keyword arguments are converted
    recursively and forwarded.

    Raises ValueError for `**` unpacking and for a call target that is neither callable nor in `function_map`.
    """
    f = _ast_to_expression(a.func, function_map)
    args = [_ast_to_expression(arg, function_map) for arg in a.args]

    # Forward keyword arguments instead of silently dropping them: ignoring e.g. `hour=5` in a
    # `datetime.datetime(...)` call would produce a different predicate than the one requested.
    kwargs = {}
    for keyword in a.keywords:
        if keyword.arg is None:
            # `f(**mapping)` has no keyword name to forward.
            raise ValueError("Unexpected ** unpacking in function call")
        kwargs[keyword.arg] = _ast_to_expression(keyword.value, function_map)

    if callable(f):
        return f(*args, **kwargs)
    if isinstance(f, str) and f in function_map:
        return function_map[f](*args, **kwargs)
    raise ValueError(f"Unexpected function call: {f}")
327+
328+
329+
@_ast_to_expression.register(ast.Attribute)
def _(a: ast.Attribute, function_map) -> Any:
    """
    Converts an attribute access.

    On an ExpressionNode, the pyarrow method names `isin`, `is_null`, `is_nan` and `is_valid` resolve to the
    bound ArcticDB methods `isin`, `isnull`, `isnull` and `notnull`. On a name (a string), `field` resolves to
    `ExpressionNode.column_ref`, `scalar` resolves to an identity function, and anything else is joined into a
    dotted name to be resolved later by the call handler. Everything else raises ValueError.
    """
    owner = _ast_to_expression(a.value, function_map)
    attr_name = a.attr

    if isinstance(owner, ExpressionNode):
        # Handles expression function attributes like (<some expression>).isin([1, 2, 3])
        pyarrow_to_arcticdb = {
            "isin": "isin",
            "is_null": "isnull",
            "is_nan": "isnull",
            "is_valid": "notnull",
        }
        arcticdb_method = pyarrow_to_arcticdb.get(attr_name)
        if arcticdb_method is not None:
            return getattr(owner, arcticdb_method)

    if isinstance(owner, str):
        # Handles attributes like "pa.compute.field" or "pc.field"
        if attr_name == "field":
            return ExpressionNode.column_ref
        if attr_name == "scalar":
            return lambda x: x
        return f"{owner}.{attr_name}"

    raise ValueError(f"Unexpected attribute {attr_name} of {owner}")
349+
350+
351+
@_ast_to_expression.register(ast.BinOp)
def _(a: ast.BinOp, function_map) -> Any:
    """
    Converts a binary operation by applying the same operator to the converted operands.

    Supports `&`, `|`, `^`, `+`, `-`, `*` and `/`; any other binary operator raises ValueError.
    """
    # Both operands are converted before the operator is inspected.
    left_operand = _ast_to_expression(a.left, function_map)
    right_operand = _ast_to_expression(a.right, function_map)
    node_op = a.op

    appliers = (
        (ast.BitAnd, lambda x, y: x & y),
        (ast.BitOr, lambda x, y: x | y),
        # pyarrow expressions don't support BitXor, so this entry will not be reached.
        # Leaving as future-proofing in case they ever introduce it.
        (ast.BitXor, lambda x, y: x ^ y),
        (ast.Add, lambda x, y: x + y),
        (ast.Sub, lambda x, y: x - y),
        (ast.Mult, lambda x, y: x * y),
        (ast.Div, lambda x, y: x / y),
    )
    for op_type, apply in appliers:
        if isinstance(node_op, op_type):
            return apply(left_operand, right_operand)
    raise ValueError(f"Unexpected BinOp: {node_op}")
375+
376+
377+
@_ast_to_expression.register(ast.Compare)
def _(a: ast.Compare, function_map) -> Any:
    """
    Converts a comparison by applying the same comparison operator to the converted operands.

    Supports `>`, `>=`, `==`, `!=`, `<` and `<=`.

    Raises ValueError for chained comparisons and for any other comparison operator (e.g. `in`, `is`).
    """
    # Compares in pyarrow Expression contain exactly one comparison (i.e. 1 < field("asdf") < 3 is not supported).
    # Raise explicitly instead of asserting: asserts are stripped under `python -O`, which would silently
    # reduce a chained comparison to its first link and change the meaning of the filter.
    if len(a.ops) != 1 or len(a.comparators) != 1:
        raise ValueError(
            f"Chained comparisons are not supported: expected exactly 1 comparison, got {len(a.ops)}"
        )
    op = a.ops[0]
    lhs = _ast_to_expression(a.left, function_map)
    rhs = _ast_to_expression(a.comparators[0], function_map)

    if isinstance(op, ast.Gt):
        return lhs > rhs
    if isinstance(op, ast.GtE):
        return lhs >= rhs
    if isinstance(op, ast.Eq):
        return lhs == rhs
    if isinstance(op, ast.NotEq):
        return lhs != rhs
    if isinstance(op, ast.Lt):
        return lhs < rhs
    if isinstance(op, ast.LtE):
        return lhs <= rhs
    raise ValueError(f"Unknown comparison: {op}")
401+
402+
403+
@_ast_to_expression.register(ast.List)
def _(a: ast.List, function_map) -> Any:
    """Converts a list literal to a Python list of its converted elements, in order."""
    converted_items = []
    for element in a.elts:
        converted_items.append(_ast_to_expression(element, function_map))
    return converted_items
406+
407+
253408
def is_supported_sequence(obj):
    """Returns True if `obj` is a list, set, frozenset, tuple or numpy array (or an instance of a subclass)."""
    supported_types = (list, set, frozenset, tuple, np.ndarray)
    return isinstance(obj, supported_types)
255410

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
import datetime
2+
import numpy as np
3+
import pandas as pd
4+
import pyarrow as pa
5+
import pyarrow.compute as pc
6+
import pytest
7+
8+
from arcticdb.version_store.processing import QueryBuilder, ExpressionNode
9+
from arcticdb.util.test import assert_frame_equal
10+
11+
12+
def df_with_all_column_types(num_rows=100):
    """
    Builds a test DataFrame with one int, float, string, bool and datetime column.

    The frame has `num_rows` rows and a daily DatetimeIndex starting at 2025-01-01. Row `i` holds `i` in
    `int_col` and `float_col`, except that `float_col` is NaN wherever `i % 20 == 5`.
    """

    def daily_from_2025():
        return pd.date_range(start=pd.Timestamp(2025, 1, 1), periods=num_rows)

    row_numbers = range(num_rows)
    columns = {}
    columns["int_col"] = np.arange(num_rows, dtype=np.int64)
    columns["float_col"] = [i if i % 20 != 5 else np.nan for i in row_numbers]
    columns["str_col"] = ["str_" + str(i) for i in row_numbers]
    columns["bool_col"] = [not i % 2 for i in row_numbers]
    columns["datetime_col"] = daily_from_2025()
    return pd.DataFrame(data=columns, index=daily_from_2025())
22+
23+
24+
def compare_against_pyarrow(pyarrow_expr_str, expected_adb_expr, lib, function_map = None, expect_equal=True):
    """
    Checks the conversion of `pyarrow_expr_str` and compares filter results between ArcticDB and pyarrow.

    First asserts that the converted expression has the same string form as `expected_adb_expr`. Then writes
    the test frame to `lib`, filters it through a QueryBuilder with the converted expression, filters the same
    data through pyarrow with the evaluated expression, and asserts that the two results are equal. With
    `expect_equal=False` it instead asserts that the two results have a different number of rows.
    """
    converted = ExpressionNode.from_pyarrow_expression_str(pyarrow_expr_str, function_map)
    assert str(converted) == str(expected_adb_expr)
    # The expression strings are literals written in this test module, so evaluating them is fine here.
    arrow_filter = eval(pyarrow_expr_str)

    # Setup
    symbol = "sym"
    frame = df_with_all_column_types()
    lib.write(symbol, frame)
    arrow_table = pa.Table.from_pandas(frame)

    # Apply filter to adb
    query = QueryBuilder()[converted]
    from_arcticdb = lib.read(symbol, query_builder=query).data

    # Apply filter to pyarrow
    from_pyarrow = arrow_table.filter(arrow_filter).to_pandas()

    if not expect_equal:
        assert len(from_arcticdb) != len(from_pyarrow)
        return
    assert_frame_equal(from_arcticdb, from_pyarrow)
47+
48+
49+
def test_basic_filters(lmdb_version_store_v1):
    """Covers single-feature filters: column refs, comparisons, unary/binary operators and method calls."""
    col = ExpressionNode.column_ref

    def check(pyarrow_str, expected, **options):
        compare_against_pyarrow(pyarrow_str, expected, lmdb_version_store_v1, **options)

    # Filter by boolean column
    check("pc.field('bool_col')", col('bool_col'))

    # Filter by comparison
    for comparison in ("<", "<=", "==", ">=", ">"):
        check(
            f"pc.field('int_col') {comparison} 50",
            eval(f"ExpressionNode.column_ref('int_col') {comparison} 50"),
        )

    # Filter with unary operators
    check("~pc.field('bool_col')", ~col('bool_col'))

    # Filter with binary operators
    for arithmetic in ("+", "-", "*", "/"):
        check(
            f"pc.field('float_col') {arithmetic} 5.0 < 50.0",
            eval(f"ExpressionNode.column_ref('float_col') {arithmetic} 5.0 < 50.0"),
        )

    for logical in ("&", "|"):
        check(
            f"pc.field('bool_col') {logical} (pc.field('int_col') < 50)",
            eval(f"ExpressionNode.column_ref('bool_col') {logical} (ExpressionNode.column_ref('int_col') < 50)"),
        )

    # Filter with expression method calls
    wanted = ['str_0', 'str_10', 'str_20']
    check("pc.field('str_col').isin(['str_0', 'str_10', 'str_20'])", col('str_col').isin(wanted))

    # We expect a different result between adb and pyarrow because of the different nan/null handling
    check("pc.field('float_col').is_nan()", col('float_col').isnull(), expect_equal=False)

    check("pc.field('float_col').is_null()", col('float_col').isnull())

    check("pc.field('float_col').is_valid()", col('float_col').notnull())
96+
97+
def test_complex_filters(lmdb_version_store_v1):
    """Covers nested filters and filters that call functions supplied through `function_map`."""
    lib = lmdb_version_store_v1
    col = ExpressionNode.column_ref

    # Nested complex filters
    pyarrow_str = (
        "((pc.field('float_col') * 2) > 20.0)"
        " & (pc.field('int_col') <= pc.scalar(60))"
        " | pc.field('bool_col')"
    )
    expected = (col('float_col') * 2 > 20.0) & (col('int_col') <= 60) | col('bool_col')
    compare_against_pyarrow(pyarrow_str, expected, lib)

    pyarrow_str = (
        "((pc.field('float_col') / 2) > 20.0)"
        " & (pc.field('float_col') <= pc.scalar(60))"
        " & pc.field('str_col').isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])"
    )
    expected = (
        (col('float_col') / 2 > 20.0)
        & (col('float_col') <= 60)
        & col('str_col').isin(['str_30', 'str_41', 'str_42', 'str_53', 'str_99'])
    )
    compare_against_pyarrow(pyarrow_str, expected, lib)

    # Filters with function calls
    callables = {"datetime.datetime": datetime.datetime, "abs": abs}

    pyarrow_str = "pc.field('datetime_col') < datetime.datetime(2025, 1, 20)"
    expected = col('datetime_col') < datetime.datetime(2025, 1, 20)
    compare_against_pyarrow(pyarrow_str, expected, lib, callables)

    pyarrow_str = "(pc.field('datetime_col') < datetime.datetime(2025, 1, abs(-20))) & (pc.field('int_col') >= abs(-5))"
    before_cutoff = col('datetime_col') < datetime.datetime(2025, 1, abs(-20))
    expected = before_cutoff & (col('int_col') >= abs(-5))
    compare_against_pyarrow(pyarrow_str, expected, lib, callables)
121+
122+
def test_broken_filters():
    """Checks that expressions which cannot be converted are rejected with ValueError."""
    rejected_expressions = (
        # ill-formatted filter
        "pc.field('float_col'",
        # pyarrow expressions only support single comparisons
        "1 < pc.field('int_col') < 10",
        # calling a missing function
        "some.missing.function(5)",
    )
    for broken in rejected_expressions:
        with pytest.raises(ValueError):
            ExpressionNode.from_pyarrow_expression_str(broken)

0 commit comments

Comments
 (0)