
Commit 4327520

feat: Improve array(), map(), and struct
fixes ibis-project#8289

This makes a number of changes; it was hard for me to separate them out as I implemented them. Now that it's all hashed out, I can try to split this into separate commits if you want, but that might be hard in some cases.

- Add support for passing None to all of these constructors. They use the new `ibis.null(<type>)` API to return `ops.Literal(None, <type>)`s.
- Make these constructors idempotent: you can pass existing Expressions into array(), etc.
- The type argument for all of these now always has an effect, not just when passing in Python literals, so it basically acts like a cast.
- A bigger structural change: ops.Array now has an optional "dtype" attribute, so when you pass in a 0-length sequence of values the op still knows what dtype it is. Several of the backends were always broken here, they just weren't getting caught; I marked them as broken, and we can fix them in a follow-up. You can test this locally with e.g. `pytest -m <backend> -k factory ibis/backends/tests/test_array.py ibis/backends/tests/test_map.py ibis/backends/tests/test_struct.py`.
- Fix a typing bug: map() can accept ArrayValues, not just ArrayColumns.
- Fix executing Literal(None) on pandas and polars, and 0-length arrays on polars.
- Fix dtype conversion on ClickHouse: Structs should be converted to non-nullable dtypes.
- Implement ops.StructColumn on pandas and dask.
1 parent 142c105 commit 4327520
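
A minimal sketch of the new constructor behavior described in the commit message, distilled from the test_array_factory* tests added below; execution results assume a backend with array support:

import ibis
import ibis.expr.datatypes as dt

# passing None now works, as long as an explicit type is provided
none_typed = ibis.array(None, type="array<string>")
assert none_typed.type() == dt.Array(value_type=dt.string)

# the constructors are idempotent: existing expressions pass through
a = ibis.array([1, 2, 3])
a2 = ibis.array(a)

# the type argument always takes effect, acting like a cast
casted = ibis.array([1, 2, 3], type="array<string>")  # executes to ["1", "2", "3"]

# a 0-length sequence is allowed when the element type is explicit,
# because ops.Array now carries an optional dtype
empty = ibis.array([], type="array<string>")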

File tree

20 files changed: +434, -178 lines

ibis/backends/dask/executor.py

Lines changed: 7 additions & 1 deletion

@@ -155,11 +155,17 @@ def mapper(df, cases):
         return cls.partitionwise(mapper, kwargs, name=op.name, dtype=dtype)

     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(
             lambda row: np.array(row, dtype=object), exprs, name=op.name, dtype=object
         )

+    @classmethod
+    def visit(cls, op: ops.StructColumn, names, values):
+        return cls.rowwise(
+            lambda row: dict(zip(names, row)), values, name=op.name, dtype=object
+        )
+
     @classmethod
     def visit(cls, op: ops.ArrayConcat, arg):
         dtype = PandasType.from_ibis(op.dtype)

ibis/backends/dask/helpers.py

Lines changed: 2 additions & 2 deletions

@@ -27,7 +27,7 @@ def concat(cls, dfs, **kwargs):

     @classmethod
     def asseries(cls, value, like=None):
-        """Ensure that value is a pandas Series object, broadcast if necessary."""
+        """Ensure that value is a dask Series object, broadcast if necessary."""

         if isinstance(value, dd.Series):
             return value
@@ -47,7 +47,7 @@ def asseries(cls, value, like=None):
         elif isinstance(value, pd.Series):
             return dd.from_pandas(value, npartitions=1)
         elif like is not None:
-            if isinstance(value, (tuple, list, dict)):
+            if isinstance(value, (tuple, list, dict, np.ndarray)):
                 fn = lambda df: pd.Series([value] * len(df), index=df.index)
             else:
                 fn = lambda df: pd.Series(value, index=df.index)
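
For context on the np.ndarray case added above: when broadcasting a constant against a frame, an array-valued constant has to be repeated as one cell per row; otherwise pandas would spread its elements across the rows. A small plain-pandas illustration (not the dask helper itself):

import numpy as np
import pandas as pd

df = pd.DataFrame({"x": [10, 20, 30]})
value = np.array([1, 2, 3])

# one array cell per row, matching the tuple/list/dict branch
per_row = pd.Series([value] * len(df), index=df.index)

# without that branch, the array's elements would become the column's values
as_column = pd.Series(value, index=df.index)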

ibis/backends/exasol/compiler.py

Lines changed: 1 addition & 0 deletions

@@ -75,6 +75,7 @@ class ExasolCompiler(SQLGlotCompiler):
         ops.StringSplit,
         ops.StringToDate,
         ops.StringToTimestamp,
+        ops.StructColumn,
         ops.TimeDelta,
         ops.TimestampAdd,
         ops.TimestampBucket,

ibis/backends/pandas/executor.py

Lines changed: 12 additions & 6 deletions

@@ -49,12 +49,14 @@ def visit(cls, op: ops.Node, **kwargs):

     @classmethod
     def visit(cls, op: ops.Literal, value, dtype):
+        if value is None:
+            return None
         if dtype.is_interval():
-            value = pd.Timedelta(value, dtype.unit.short)
-        elif dtype.is_array():
-            value = np.array(value)
-        elif dtype.is_date():
-            value = pd.Timestamp(value, tz="UTC").tz_localize(None)
+            return pd.Timedelta(value, dtype.unit.short)
+        if dtype.is_array():
+            return np.array(value)
+        if dtype.is_date():
+            return pd.Timestamp(value, tz="UTC").tz_localize(None)
         return value

     @classmethod
@@ -220,9 +222,13 @@ def visit(cls, op: ops.FindInSet, needle, values):
         return pd.Series(result, name=op.name)

     @classmethod
-    def visit(cls, op: ops.Array, exprs):
+    def visit(cls, op: ops.Array, exprs, dtype):
         return cls.rowwise(lambda row: np.array(row, dtype=object), exprs)

+    @classmethod
+    def visit(cls, op: ops.StructColumn, names, values):
+        return cls.rowwise(lambda row: dict(zip(names, row)), values)
+
     @classmethod
     def visit(cls, op: ops.ArrayConcat, arg):
         return cls.rowwise(lambda row: np.concatenate(row.values), arg)

ibis/backends/polars/compiler.py

Lines changed: 11 additions & 8 deletions

@@ -86,10 +86,14 @@ def _make_duration(value, dtype):
 def literal(op, **_):
     value = op.value
     dtype = op.dtype
+    if dtype.is_interval():
+        return _make_duration(value, dtype)

-    if dtype.is_array():
+    typ = PolarsType.from_ibis(dtype)
+    if value is None:
+        return pl.lit(None, dtype=typ)
+    elif dtype.is_array():
         value = pl.Series("", value)
-        typ = PolarsType.from_ibis(dtype)
         val = pl.lit(value, dtype=typ)
         return val.implode()
     elif dtype.is_struct():
@@ -98,14 +102,11 @@ def literal(op, **_):
             for k, v in value.items()
         ]
         return pl.struct(values)
-    elif dtype.is_interval():
-        return _make_duration(value, dtype)
     elif dtype.is_null():
         return pl.lit(value)
     elif dtype.is_binary():
         return pl.lit(value)
     else:
-        typ = PolarsType.from_ibis(dtype)
         return pl.lit(op.value, dtype=typ)


@@ -980,9 +981,11 @@ def array_concat(op, **kw):


 @translate.register(ops.Array)
-def array_column(op, **kw):
-    cols = [translate(col, **kw) for col in op.exprs]
-    return pl.concat_list(cols)
+def array_literal(op, **kw):
+    if len(op.exprs) > 0:
+        return pl.concat_list([translate(col, **kw) for col in op.exprs])
+    else:
+        return pl.lit([], dtype=PolarsType.from_ibis(op.dtype))


 @translate.register(ops.ArrayCollect)
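
The polars changes lean on pl.lit accepting an explicit dtype. A small standalone polars sketch (outside ibis) of the two cases the literal translator now handles, typed nulls and 0-length arrays:

import polars as pl

# a typed NULL literal, as used for Literal(None, <type>)
null_str = pl.select(pl.lit(None, dtype=pl.Utf8))

# a typed empty list literal, as used for a 0-length ops.Array
empty_list = pl.select(pl.lit([], dtype=pl.List(pl.Utf8)))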

ibis/backends/sql/compiler.py

Lines changed: 5 additions & 2 deletions

@@ -970,8 +970,11 @@ def visit_InSubquery(self, op, *, rel, needle):
         query = sg.select(STAR).from_(query)
         return needle.isin(query=query)

-    def visit_Array(self, op, *, exprs):
-        return self.f.array(*exprs)
+    def visit_Array(self, op, *, exprs, dtype):
+        result = self.f.array(*exprs)
+        if len(exprs) == 0:
+            return self.cast(result, dtype)
+        return result

     def visit_StructColumn(self, op, *, names, values):
         return sge.Struct.from_arg_list(
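
The motivation for the new dtype argument to visit_Array: an empty array() call gives the database nothing to infer an element type from, so the base compiler wraps the empty array in a cast to the expression's dtype. A rough sketch of how this surfaces at the expression level (the exact SQL depends on the dialect, and some backends override visit_Array):

import ibis

empty = ibis.array([], type="array<string>").name("a")
# with the base SQLGlot compiler, the empty array is wrapped in a cast to array<string>
print(ibis.to_sql(empty, dialect="duckdb"))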

ibis/backends/sql/datatypes.py

Lines changed: 4 additions & 2 deletions

@@ -1007,8 +1007,10 @@ class ClickHouseType(SqlglotType):
     def from_ibis(cls, dtype: dt.DataType) -> sge.DataType:
         """Convert a sqlglot type to an ibis type."""
         typ = super().from_ibis(dtype)
-        if dtype.nullable and not (dtype.is_map() or dtype.is_array()):
-            # map cannot be nullable in clickhouse
+        # nested types cannot be nullable in clickhouse
+        if dtype.nullable and not (
+            dtype.is_map() or dtype.is_array() or dtype.is_struct()
+        ):
             return sge.DataType(this=typecode.NULLABLE, expressions=[typ])
         else:
             return typ
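
A hedged sketch of what the ClickHouse rule above means in practice: ClickHouse rejects Nullable around Array, Map, and Tuple, so only scalar types get the Nullable wrapper. The rendered type names in the comments are illustrative:

import ibis.expr.datatypes as dt
from ibis.backends.sql.datatypes import ClickHouseType

# a nullable scalar ibis type still renders with a Nullable(...) wrapper, e.g. Nullable(Int64)
scalar_typ = ClickHouseType.from_ibis(dt.int64)
# a struct now renders as a bare Tuple(...), with no outer Nullable, matching arrays and maps
struct_typ = ClickHouseType.from_ibis(dt.Struct({"a": dt.int64}))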

ibis/backends/sqlite/compiler.py

Lines changed: 1 addition & 0 deletions

@@ -60,6 +60,7 @@ class SQLiteCompiler(SQLGlotCompiler):
         ops.TimestampDiff,
         ops.StringToDate,
         ops.StringToTimestamp,
+        ops.StructColumn,
         ops.TimeDelta,
         ops.DateDelta,
         ops.TimestampDelta,

ibis/backends/tests/test_array.py

Lines changed: 67 additions & 5 deletions

@@ -31,6 +31,7 @@
     PySparkAnalysisException,
     TrinoUserError,
 )
+from ibis.common.annotations import ValidationError
 from ibis.common.collections import frozendict

 pytestmark = [
@@ -72,6 +73,72 @@
 # list.


+def test_array_factory(con):
+    a = ibis.array([1, 2, 3])
+    assert a.type() == dt.Array(value_type=dt.Int8)
+    assert con.execute(a) == [1, 2, 3]
+
+    a2 = ibis.array(a)
+    assert a.type() == dt.Array(value_type=dt.Int8)
+    assert con.execute(a2) == [1, 2, 3]
+
+
+def test_array_factory_typed(con):
+    typed = ibis.array([1, 2, 3], type="array<string>")
+    assert con.execute(typed) == ["1", "2", "3"]
+
+    typed2 = ibis.array(ibis.array([1, 2, 3]), type="array<string>")
+    assert con.execute(typed2) == ["1", "2", "3"]
+
+
+@pytest.mark.notimpl("flink", raises=Py4JJavaError)
+@pytest.mark.notimpl(["pandas", "dask"], raises=ValueError)
+def test_array_factory_empty(con):
+    with pytest.raises(ValidationError):
+        ibis.array([])
+
+    empty_typed = ibis.array([], type="array<string>")
+    assert empty_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(empty_typed) == []
+
+
+@pytest.mark.notyet(
+    "clickhouse", raises=ClickHouseDatabaseError, reason="nested types can't be NULL"
+)
+def test_array_factory_null(con):
+    with pytest.raises(ValidationError):
+        ibis.array(None)
+    with pytest.raises(ValidationError):
+        ibis.array(None, type="int64")
+    none_typed = ibis.array(None, type="array<string>")
+    assert none_typed.type() == dt.Array(value_type=dt.string)
+    assert con.execute(none_typed) is None
+
+    nones = ibis.array([None, None], type="array<string>")
+    assert nones.type() == dt.Array(value_type=dt.string)
+    assert con.execute(nones) == [None, None]
+
+    # Execute a real value here, so the backends that don't support arrays
+    # actually xfail as we expect them to.
+    # Otherwise would have to @mark.xfail every test in this file besides this one.
+    assert con.execute(ibis.array([1, 2])) == [1, 2]
+
+
+@pytest.mark.broken(
+    ["datafusion", "polars"],
+    raises=AssertionError,
+    reason="[None, 1] executes to [np.nan, 1.0]",
+)
+def test_array_factory_null_mixed(con):
+    none_and_val = ibis.array([None, 1])
+    assert none_and_val.type() == dt.Array(value_type=dt.Int8)
+    assert con.execute(none_and_val) == [None, 1]
+
+    none_and_val_typed = ibis.array([None, 1], type="array<string>")
+    assert none_and_val_typed.type() == dt.Array(value_type=dt.String)
+    assert con.execute(none_and_val_typed) == [None, "1"]
+
+
 def test_array_column(backend, alltypes, df):
     expr = ibis.array(
         [alltypes["double_col"], alltypes["double_col"], 5.0, ibis.literal(6.0)]
@@ -912,11 +979,6 @@ def test_zip_null(con, fn):


 @builtin_array
-@pytest.mark.notyet(
-    ["clickhouse"],
-    raises=ClickHouseDatabaseError,
-    reason="https://github.com/ClickHouse/ClickHouse/issues/41112",
-)
 @pytest.mark.notimpl(["postgres"], raises=PsycoPg2SyntaxError)
 @pytest.mark.notimpl(["risingwave"], raises=PsycoPg2ProgrammingError)
 @pytest.mark.notimpl(["datafusion"], raises=com.OperationNotDefinedError)

ibis/backends/tests/test_generic.py

Lines changed: 1 addition & 3 deletions

@@ -1230,9 +1230,7 @@ def query(t, group_cols):
     snapshot.assert_match(str(ibis.to_sql(t3, dialect=con.name)), "out.sql")


-@pytest.mark.notimpl(
-    ["dask", "pandas", "oracle", "exasol"], raises=com.OperationNotDefinedError
-)
+@pytest.mark.notimpl(["oracle", "exasol"], raises=com.OperationNotDefinedError)
 @pytest.mark.notimpl(["druid"], raises=AssertionError)
 @pytest.mark.notyet(
     ["datafusion", "impala", "mssql", "mysql", "sqlite"],
