
Commit fb22996

feat: Improve array(), map(), and struct
fixes ibis-project#8289

This makes a lot of changes. It was hard for me to separate them out as I implemented them, but now that it's all hashed out I can try to split this up into separate commits if you want, though that might be hard in some cases.

- Add support for passing None to all of these constructors. They use the new `ibis.null(<type>)` API to return `ops.Literal(None, <type>)`s.
- Make these constructors idempotent: you can pass existing Expressions into array(), etc.
- The type argument for all of these now always has an effect, not just when passing in Python literals, so it basically acts like a cast.
- A bigger structural change: ops.Array now has an optional "dtype" attribute, so if you pass in a 0-length sequence of values the op still knows what dtype it is. Several of the backends were always broken here, they just weren't getting caught. I marked them as broken; we can fix them in a followup. You can test this locally with e.g. `pytest -m <backend> -k factory ibis/backends/tests/test_array.py ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py`.
- Fix a typing bug: map() can accept ArrayValues, not just ArrayColumns.
- Fix executing Literal(None) on pandas and polars, and 0-length arrays on polars.
- Fix dtype conversion on ClickHouse: Structs should be converted to non-nullable dtypes.
- Implement ops.StructColumn on pandas and dask.
1 parent 4707c44 commit fb22996
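A minimal sketch of the new constructor behavior described in the commit message, mirroring the tests added in this commit (assumes an ibis build that includes these changes):

import ibis
import ibis.expr.datatypes as dt

# idempotent: an existing expression can be passed back into the constructor
a = ibis.array([1, 2, 3])
a2 = ibis.array(a)

# the type argument now always takes effect, acting like a cast
typed = ibis.array([1, 2, 3], type="array<string>")
assert typed.type() == dt.Array(value_type=dt.string)

# empty sequences and None now work as long as an explicit type is given,
# so ops.Array / ops.Literal still know their dtype
empty = ibis.array([], type="array<string>")
none_typed = ibis.array(None, type="array<string>")
assert empty.type() == none_typed.type() == dt.Array(value_type=dt.string)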

File tree

18 files changed: +393, -150 lines

ibis/backends/dask/executor.py

Lines changed: 7 additions & 1 deletion
@@ -155,11 +155,17 @@ def mapper(df, cases):
         return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(
             lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
         )

+    @classmethod
+    def visit(cls, op: ops.StructColumn, names, values):
+        return cls.rowwise(
+            lambda row: dict(zip(names, row)), values, name=op.name, dtype=object
+        )
+
     @classmethod
     def visit(cls, op: ops.ArrayConcat, arg):
         dtype = PandasType.from_ibis(op.dtype)

ibis/backends/exasol/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ class ExasolCompiler(SQLGlotCompiler):
         ops.StringSplit,
         ops.StringToDate,
         ops.StringToTimestamp,
+        ops.StructColumn,
         ops.TimeDelta,
         ops.TimestampAdd,
         ops.TimestampBucket,

ibis/backends/pandas/executor.py

Lines changed: 12 additions & 6 deletions
@@ -49,12 +49,14 @@ def visit(cls, op: ops.Node, **kwargs):

     @classmethod
     def visit(cls, op: ops.Literal, value, dtype):
+        if value is None:
+            return None
         if dtype.is_interval():
-            value = pd.Timedelta(value, dtype.unit.short)
-        elif dtype.is_array():
-            value = np.array(value)
-        elif dtype.is_date():
-            value = pd.Timestamp(value, tz="UTC").tz_localize(None)
+            return pd.Timedelta(value, dtype.unit.short)
+        if dtype.is_array():
+            return np.array(value)
+        if dtype.is_date():
+            return pd.Timestamp(value, tz="UTC").tz_localize(None)
         return value

     @classmethod
@@ -220,9 +222,13 @@ def visit(cls, op: ops.FindInSet, needle, values):
         return pd.Series(result, name=op.name)

     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

+    @classmethod
+    def visit(cls, op: ops.StructColumn, names, values):
+        return cls.rowwise(lambda row: dict(zip(names, row)), values)
+
     @classmethod
     def visit(cls, op: ops.ArrayConcat, arg):
         return cls.rowwise(lambda row: np.concatenate(row.values), arg)

ibis/backends/polars/compiler.py

Lines changed: 11 additions & 8 deletions
@@ -86,10 +86,14 @@ def _make_duration(value, dtype):
 def literal(op, **_):
     value = op.value
     dtype = op.dtype
+    if dtype.is_interval():
+        return _make_duration(value, dtype)

-    if dtype.is_array():
+    typ = PolarsType.from_ibis(dtype)
+    if value is None:
+        return pl.lit(None, dtype=typ)
+    elif dtype.is_array():
         value = pl.Series("", value)
-        typ = PolarsType.from_ibis(dtype)
         val = pl.lit(value, dtype=typ)
         return val.implode()
     elif dtype.is_struct():
@@ -98,14 +102,11 @@ def literal(op, **_):
             for k, v in value.items()
         ]
         return pl.struct(values)
-    elif dtype.is_interval():
-        return _make_duration(value, dtype)
     elif dtype.is_null():
         return pl.lit(value)
     elif dtype.is_binary():
         return pl.lit(value)
     else:
-        typ = PolarsType.from_ibis(dtype)
         return pl.lit(op.value, dtype=typ)


@@ -980,9 +981,11 @@ def array_concat(op, **kw):


 @translate.register(ops.Array)
-def array_column(op, **kw):
-    cols = [translate(col, **kw) for col in op.exprs]
-    return pl.concat_list(cols)
+def array_literal(op, **kw):
+    if len(op.exprs) > 0:
+        return pl.concat_list([translate(col, **kw) for col in op.exprs])
+    else:
+        return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))


 @translate.register(ops.ArrayCollect)
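For context, a small polars-only sketch of why the dtype has to be threaded through here: without an explicit dtype, polars cannot tell what a NULL or empty-list literal is supposed to be. The column names below are illustrative only, and the exact behavior depends on the polars version ibis pins.

import polars as pl

# NULL and empty-list literals need an explicit dtype, which is what the
# compiler changes above pass through from the ibis dtype
null_lit = pl.lit(None, dtype=pl.List(pl.Utf8))
empty_lit = pl.lit([], dtype=pl.List(pl.Int64))

df = pl.select(null_lit.alias("null_arr"), empty_lit.alias("empty_arr"))
print(df.schema)  # both columns carry the requested list dtype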

ibis/backends/risingwave/compiler.py

Lines changed: 6 additions & 1 deletion
@@ -8,7 +8,7 @@
 import ibis.expr.datatypes as dt
 import ibis.expr.operations as ops
 from ibis.backends.postgres.compiler import PostgresCompiler
-from ibis.backends.sql.compiler import ALL_OPERATIONS
+from ibis.backends.sql.compiler import ALL_OPERATIONS, SQLGlotCompiler
 from ibis.backends.sql.datatypes import RisingWaveType
 from ibis.backends.sql.dialects import RisingWave

@@ -51,6 +51,11 @@ def visit_Correlation(self, op, *, left, right, how, where):
             op, left=left, right=right, how=how, where=where
         )

+    # def visit_StructColumn(self, op, *, names, values):
+    #     The parent Postgres compiler uses the ROW() function,
+    #     but the grandparent SQLGlot compiler uses the correct syntax
+    #     return SQLGlotCompiler.visit_StructColumn(self, op, names=names, values=values)
+
     def visit_TimestampTruncate(self, op, *, arg, unit):
         unit_mapping = {
             "Y": "year",

ibis/backends/sql/compiler.py

Lines changed: 5 additions & 2 deletions
@@ -970,8 +970,11 @@ def visit_InSubquery(self, op, *, rel, needle):
         query = sg.select(STAR).from_(query)
         return needle.isin(query=query)

-    def visit_Array(self, op, *, exprs):
-        return self.f.array(*exprs)
+    def visit_Array(self, op, *, exprs, dtype):
+        result = self.f.array(*exprs)
+        if len(exprs) == 0:
+            return self.cast(result, dtype)
+        return result

     def visit_StructColumn(self, op, *, names, values):
         return sge.Struct.from_arg_list(
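The cast is only applied in the zero-length case because a non-empty array's element type can be inferred from its elements. A quick way to see the effect, with DuckDB chosen only as an example dialect; the exact SQL text depends on the dialect and ibis version:

import ibis

empty = ibis.array([], type="array<int64>")
nonempty = ibis.array([1, 2, 3])

# the empty array is wrapped in a cast to its declared type,
# while the non-empty one compiles to a plain array constructor
print(ibis.to_sql(empty, dialect="duckdb"))
print(ibis.to_sql(nonempty, dialect="duckdb"))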

ibis/backends/sql/datatypes.py

Lines changed: 4 additions & 2 deletions
@@ -1007,8 +1007,10 @@ class ClickHouseType(SqlglotType):
     def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
         """Convert a sqlglot type to an ibis type."""
         typ = super().from_ibis(dtype)
-        if dtype.nullable and not (dtype.is_map() or dtype.is_array()):
-            # map cannot be nullable in clickhouse
+        # nested types cannot be nullable in clickhouse
+        if dtype.nullable and not (
+            dtype.is_map() or dtype.is_array() or dtype.is_struct()
+        ):
             return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
         else:
             return typ
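A hedged sketch of what this change affects, using the ClickHouseType mapper edited above; the exact rendered type string depends on the sqlglot version:

import ibis.expr.datatypes as dt
from ibis.backends.sql.datatypes import ClickHouseType

# struct types are now treated like map and array: they are not wrapped
# in Nullable(...), since ClickHouse rejects nullable nested types
struct_type = ClickHouseType.from_ibis(dt.Struct({"a": dt.int64, "b": dt.string}))
scalar_type = ClickHouseType.from_ibis(dt.int64)

print(struct_type.sql("clickhouse"))  # no Nullable(...) wrapper around the struct
print(scalar_type.sql("clickhouse"))  # scalar types are still wrapped in Nullable(...)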

ibis/backends/sqlite/compiler.py

Lines changed: 1 addition & 0 deletions
@@ -60,6 +60,7 @@ class SQLiteCompiler(SQLGlotCompiler):
         ops.TimestampDiff,
         ops.StringToDate,
         ops.StringToTimestamp,
+        ops.StructColumn,
         ops.TimeDelta,
         ops.DateDelta,
         ops.TimestampDelta,

ibis/backends/tests/test_array.py

Lines changed: 54 additions & 5 deletions
@@ -30,6 +30,7 @@
     PySparkAnalysisException,
     TrinoUserError,
 )
+from ibis.common.annotations import ValidationError
 from ibis.common.collections import frozendict

 pytestmark = [
@@ -66,11 +67,64 @@
     pytest.mark.notimpl(["druid", "oracle"], raises=Exception),
 ]

+mark_notyet_datafusion = pytest.mark.notyet(
+    "datafusion",
+    raises=Exception,
+    reason="datafusion can't handle array casts yet. https://github.com/apache/datafusion/issues/10464",
+)
+
 # NB: We don't check whether results are numpy arrays or lists because this
 # varies across backends. At some point we should unify the result type to be
 # list.


+def test_array_factory(con):
+    a = ibis.array([1, 2, 3])
+    assert con.execute(a) == [1, 2, 3]
+
+    a2 = ibis.array(a)
+    assert con.execute(a2) == [1, 2, 3]
+
+
+@mark_notyet_datafusion
+def test_array_factory_typed(con):
+    typed = ibis.array([1, 2, 3], type="array<string>")
+    assert con.execute(typed) == ["1", "2", "3"]
+
+    typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
+    assert con.execute(typed2) == ["1", "2", "3"]
+
+
+@mark_notyet_datafusion
+@pytest.mark.notimpl("flink", raises=Py4JJavaError)
+@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
+def test_array_factory_empty(con):
+    with pytest.raises(ValidationError):
+        ibis.array([])
+
+    empty_typed = ibis.array([], type="array<string>")
+    assert empty_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(empty_typed) == []
+
+
+@mark_notyet_datafusion
+@pytest.mark.notyet(
+    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
+)
+def test_array_factory_null(con):
+    with pytest.raises(ValidationError):
+        ibis.array(None)
+    with pytest.raises(ValidationError):
+        ibis.array(None, type="int64")
+    none_typed = ibis.array(None, type="array<string>")
+    assert none_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(none_typed) is None
+    # Execute a real value here, so the backends that don't support arrays
+    # actually xfail as we expect them to.
+    # Otherwise would have to @mark.xfail every test in this file besides this one.
+    assert con.execute(ibis.array([1, 2])) == [1, 2]
+
+
 def test_array_column(backend, alltypes, df):
     expr = ibis.array(
         [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
@@ -913,11 +967,6 @@ def test_zip_null(con, fn):


 @builtin_array
-@pytest.mark.notyet(
-    ["clickhouse"],
-    raises=ClickHouseDatabaseError,
-    reason="https://github.com/ClickHouse/ClickHouse/issues/41112",
-)
 @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError)
 @pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError)
 @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)

ibis/backends/tests/test_generic.py

Lines changed: 1 addition & 3 deletions
@@ -1230,9 +1230,7 @@ def query(t, group_cols):
     snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


-@pytest.mark.notimpl(
-    ["dask", "pandas", "oracle", "exasol"], raises=com.OperationNotDefinedError
-)
+@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(["druid"], raises=AssertionError)
 @pytest.mark.notyet(
     ["datafusion", "impala", "mssql", "mysql", "sqlite"],
