Skip to content

Commit 7b2d8fc

Browse files
committed
feat: upcast schemas if needed during set ops
1 parent 0044845 commit 7b2d8fc

File tree

8 files changed

+145
-18
lines changed

8 files changed

+145
-18
lines changed
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
SELECT
2+
*
3+
FROM (
4+
SELECT
5+
*
6+
FROM (
7+
SELECT
8+
"t0"."id",
9+
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
10+
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
11+
FROM "functional_alltypes" AS "t0"
12+
) AS "t1"
13+
UNION ALL
14+
SELECT
15+
*
16+
FROM (
17+
SELECT
18+
"t0"."id",
19+
CAST("t0"."bigint_col" + 256 AS BIGINT) AS "i",
20+
CAST("t0"."string_col" AS TEXT) AS "s"
21+
FROM "functional_alltypes" AS "t0"
22+
) AS "t2"
23+
) AS "t3"
24+
ORDER BY
25+
"t3"."id" ASC,
26+
"t3"."i" ASC,
27+
"t3"."s" ASC
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
SELECT
2+
*
3+
FROM (
4+
SELECT
5+
*
6+
FROM (
7+
SELECT
8+
"t0"."id",
9+
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
10+
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
11+
FROM "functional_alltypes" AS "t0"
12+
) AS "t1"
13+
UNION
14+
SELECT
15+
*
16+
FROM (
17+
SELECT
18+
"t0"."id",
19+
CAST("t0"."bigint_col" + 256 AS BIGINT) AS "i",
20+
CAST("t0"."string_col" AS TEXT) AS "s"
21+
FROM "functional_alltypes" AS "t0"
22+
) AS "t2"
23+
) AS "t3"
24+
ORDER BY
25+
"t3"."id" ASC,
26+
"t3"."i" ASC,
27+
"t3"."s" ASC

ibis/backends/tests/sql/test_sql.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -663,3 +663,23 @@ def test_ctes_in_order():
663663

664664
sql = ibis.to_sql(expr, dialect="duckdb")
665665
assert sql.find('"first" AS (') < sql.find('"second" AS (')
666+
667+
668+
@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
def test_union_unified_schemas(snapshot, functional_alltypes, distinct):
    """Snapshot the SQL compiled for a union whose inputs have differing dtypes.

    Both sides select ``id``, ``i`` and ``s`` from ``functional_alltypes``:

    * ``a`` takes ``i`` from ``tinyint_col`` and casts ``s`` to ``"!string"``.
    * ``b`` takes ``i`` from ``bigint_col + 256`` and casts ``s`` to
      ``"string"``.

    The union is expected to expose the column types of ``b`` for ``i`` and
    ``s``, and its compiled SQL is compared against the stored ``out.sql``
    snapshot. ``distinct`` selects between ``UNION ALL`` (False) and
    ``UNION`` (True).
    """
    a = functional_alltypes.select(
        "id", i="tinyint_col", s=_.string_col.cast("!string")
    )
    b = functional_alltypes.select(
        "id",
        i=_.bigint_col + 256,  # ensure doesn't fit in a tinyint
        s=_.string_col.cast("string"),
    )

    expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s")

    # The unified schema must take the wider of the two input types.
    assert expr.i.type() == b.i.type()
    assert expr.s.type() == b.s.type()
    snapshot.assert_match(to_sql(expr), "out.sql")

ibis/backends/tests/test_set_ops.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,36 @@ def test_union(backend, union_subsets, distinct):
4949
backend.assert_frame_equal(result, expected)
5050

5151

52+
@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_unified_schemas(backend, con, distinct):
    """Execute a union whose inputs have differing dtypes and check the rows.

    ``a`` and ``b`` both select ``id``, ``i`` and ``s`` from
    ``functional_alltypes``, but ``i`` comes from ``tinyint_col`` in ``a`` and
    from ``bigint_col + 256`` in ``b``, and ``s`` is cast to ``"!string"`` in
    ``a`` and ``"string"`` in ``b``. The union is expected to expose the column
    types of ``b`` and to return the same rows as concatenating the two
    executed inputs in pandas (de-duplicated when ``distinct`` is true).
    """
    a = con.table("functional_alltypes").select(
        "id",
        i="tinyint_col",
        s=_.string_col.cast("!string"),
    )
    b = con.table("functional_alltypes").select(
        "id",
        i=_.bigint_col + 256,  # ensure doesn't fit in a tinyint
        s=_.string_col.cast("string"),
    )

    sort_keys = ["id", "i", "s"]
    expr = ibis.union(a, b, distinct=distinct).order_by(*sort_keys)

    # The unified schema must take the wider of the two input types.
    assert expr.i.type() == b.i.type()
    assert expr.s.type() == b.s.type()
    result = expr.execute()

    expected = pd.concat([a.execute(), b.execute()], axis=0)
    if distinct:
        expected = expected.drop_duplicates(sort_keys)
    # De-duplicate before sorting and resetting the index: drop_duplicates
    # keeps the surviving rows' original labels, so resetting first would
    # leave gaps in the expected index whenever a row is dropped.
    expected = expected.sort_values(sort_keys).reset_index(drop=True)

    # Dtypes are not compared; only the values are checked here.
    backend.assert_frame_equal(result, expected, check_dtype=False)
80+
81+
5282
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
5383
def test_union_mixed_distinct(backend, union_subsets):
5484
(a, b, c), (da, db, dc) = union_subsets

ibis/expr/operations/relations.py

Lines changed: 38 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@
1313
import ibis.expr.datatypes as dt
1414
from ibis.common.annotations import attribute
1515
from ibis.common.collections import (
16-
ConflictingValuesError,
1716
FrozenDict,
1817
FrozenOrderedDict,
1918
)
@@ -338,19 +337,44 @@ class Set(Relation):
338337
values = FrozenOrderedDict()
339338

340339
def __init__(self, left, right, **kwargs):
341-
err_msg = "Table schemas must be equal for set operations."
342-
try:
343-
missing_from_left = right.schema - left.schema
344-
missing_from_right = left.schema - right.schema
345-
except ConflictingValuesError as e:
346-
raise RelationError(err_msg + "\n" + str(e)) from e
347-
if missing_from_left or missing_from_right:
348-
msgs = [err_msg]
349-
if missing_from_left:
350-
msgs.append(f"Columns missing from the left:\n{missing_from_left}.")
351-
if missing_from_right:
352-
msgs.append(f"Columns missing from the right:\n{missing_from_right}.")
353-
raise RelationError("\n".join(msgs))
340+
# TODO: hoist this up into the user facing API so we can see
341+
# all the tables at once and give a better error message
342+
errs = ["Table schemas must be unifiable for set operations."]
343+
missing_from_left = set(right.schema.names) - set(left.schema.names)
344+
missing_from_right = set(left.schema.names) - set(right.schema.names)
345+
if missing_from_left:
346+
errs.append(f"Columns missing from the left:\n{missing_from_left}.")
347+
if missing_from_right:
348+
errs.append(f"Columns missing from the right:\n{missing_from_right}.")
349+
if len(errs) > 1:
350+
raise RelationError("\n".join(errs))
351+
352+
upcasts = {}
353+
for name in left.schema.names:
354+
ltype, rtype = left.schema[name], right.schema[name]
355+
try:
356+
unified_dt = dt.highest_precedence([ltype, rtype])
357+
if unified_dt != ltype or unified_dt != rtype:
358+
upcasts[name] = unified_dt
359+
except IbisTypeError:
360+
errs.append(f"Unable to find a common dtype for column {name}")
361+
errs.append(f"Left dtype: {ltype!s}")
362+
errs.append(f"Right dtype: {rtype!s}")
363+
if len(errs) > 1:
364+
raise RelationError("\n".join(errs))
365+
366+
if upcasts:
367+
from ibis.expr.operations.generic import Cast
368+
369+
def get_new_val(relation, name):
370+
if name not in upcasts:
371+
return Field(relation, name)
372+
return Cast(Field(relation, name), upcasts[name])
373+
374+
lcols = {name: get_new_val(left, name) for name in left.schema.names}
375+
rcols = {name: get_new_val(right, name) for name in left.schema.names}
376+
left = Project(left, lcols)
377+
right = Project(right, rcols)
354378

355379
if left.schema.names != right.schema.names:
356380
# rewrite so that both sides have the columns in the same order making it

ibis/tests/expr/test_set_operations.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ class D:
4040

4141
@pytest.mark.parametrize("method", ["union", "intersect", "difference"])
4242
def test_operation_requires_equal_schemas(method):
43-
with pytest.raises(RelationError, match="`c`: string != float64"):
43+
with pytest.raises(RelationError, match="Left dtype: int64\nRight dtype: string"):
4444
getattr(a, method)(d)
4545

4646

ibis/tests/expr/test_table.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525

2626
set_ops_schema_top = [("key", "string"), ("value", "double")]
2727
set_ops_schema_bottom = [("key", "string"), ("key2", "string"), ("value", "double")]
28-
setops_relation_error_message = "Table schemas must be equal for set operations"
28+
setops_relation_error_message = "Table schemas must be unifiable for set operations"
2929

3030

3131
@pytest.fixture

ibis/tests/expr/test_value_exprs.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -641,8 +641,7 @@ def test_null_column():
641641
def test_null_column_union():
642642
s = ibis.table([("a", "string"), ("b", "double")])
643643
t = ibis.table([("a", "string")])
644-
with pytest.raises(ibis.common.exceptions.RelationError):
645-
s.union(t.mutate(b=ibis.null())) # needs a type
644+
assert s.union(t.mutate(b=ibis.null())).schema() == s.schema()
646645
assert s.union(t.mutate(b=ibis.null().cast("double"))).schema() == s.schema()
647646

648647

0 commit comments

Comments
 (0)