Skip to content

Commit

Permalink
feat: upcast schemas if needed during set ops
Browse files Browse the repository at this point in the history
  • Loading branch information
NickCrews committed Jan 26, 2025
1 parent 0044845 commit 331145b
Show file tree
Hide file tree
Showing 8 changed files with 150 additions and 19 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t1"
UNION ALL
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."bigint_col" + 256 AS BIGINT) AS "i",
CAST("t0"."string_col" AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t2"
) AS "t3"
ORDER BY
"t3"."id" ASC,
"t3"."i" ASC,
"t3"."s" ASC
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
SELECT
*
FROM (
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."tinyint_col" AS BIGINT) AS "i",
CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t1"
UNION
SELECT
*
FROM (
SELECT
"t0"."id",
CAST("t0"."bigint_col" + 256 AS BIGINT) AS "i",
CAST("t0"."string_col" AS TEXT) AS "s"
FROM "functional_alltypes" AS "t0"
) AS "t2"
) AS "t3"
ORDER BY
"t3"."id" ASC,
"t3"."i" ASC,
"t3"."s" ASC
17 changes: 17 additions & 0 deletions ibis/backends/tests/sql/test_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -663,3 +663,20 @@ def test_ctes_in_order():

sql = ibis.to_sql(expr, dialect="duckdb")
assert sql.find('"first" AS (') < sql.find('"second" AS (')


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
def test_union_unified_schemas(snapshot, functional_alltypes, distinct):
    """Snapshot the SQL generated for a union whose inputs differ in dtype.

    Both projections expose the columns ``id``, ``i`` and ``s``, but ``i`` is
    built from ``tinyint_col`` on one side and from ``bigint_col + 256`` on
    the other, and ``s`` is cast to ``"!string"`` on one side and to
    ``"string"`` on the other. The union is expected to report the second
    projection's dtypes for ``i`` and ``s``.
    """
    narrow = functional_alltypes.select(
        "id", i="tinyint_col", s=_.string_col.cast("!string")
    )
    # Adding 256 guarantees the values cannot fit in a tinyint.
    wide = functional_alltypes.select(
        "id", i=_.bigint_col + 256, s=_.string_col.cast("string")
    )

    unioned = ibis.union(narrow, wide, distinct=distinct)
    expr = unioned.order_by("id", "i", "s")

    # The unified schema must match the wider of the two inputs.
    assert expr.i.type() == wide.i.type()
    assert expr.s.type() == wide.s.type()
    snapshot.assert_match(to_sql(expr), "out.sql")
39 changes: 38 additions & 1 deletion ibis/backends/tests/test_set_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@
import ibis
import ibis.expr.types as ir
from ibis import _
from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError
from ibis.backends.tests.errors import (
OracleDatabaseError,
PsycoPg2InternalError,
PyDruidProgrammingError,
)

pd = pytest.importorskip("pandas")

Expand Down Expand Up @@ -49,6 +53,39 @@ def test_union(backend, union_subsets, distinct):
backend.assert_frame_equal(result, expected)


@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"])
@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
@pytest.mark.notyet(
    ["oracle"], raises=OracleDatabaseError, reason="does not support NOT NULL types"
)
def test_unified_schemas(backend, con, distinct):
    """Execute a union whose inputs need upcasting and compare with pandas.

    The two projections of ``functional_alltypes`` share column names but
    not dtypes. The union must report the second projection's dtypes for
    ``i`` and ``s``, and its result must equal the concatenation of both
    inputs (de-duplicated when ``distinct`` is true), sorted by all three
    columns.
    """
    sort_keys = ["id", "i", "s"]

    narrow = con.table("functional_alltypes").select(
        "id", i="tinyint_col", s=_.string_col.cast("!string")
    )
    # Adding 256 guarantees the values cannot fit in a tinyint.
    wide = con.table("functional_alltypes").select(
        "id", i=_.bigint_col + 256, s=_.string_col.cast("string")
    )

    expr = ibis.union(narrow, wide, distinct=distinct).order_by(*sort_keys)
    assert expr.i.type() == wide.i.type()
    assert expr.s.type() == wide.s.type()
    result = expr.execute()

    # Build the reference result in pandas from the two inputs.
    stacked = pd.concat([narrow.execute(), wide.execute()], axis=0)
    expected = stacked.sort_values(sort_keys).reset_index(drop=True)
    expected = expected.drop_duplicates(sort_keys) if distinct else expected

    # Backends may return different physical dtypes, so only compare values.
    backend.assert_frame_equal(result, expected, check_dtype=False)


@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError)
def test_union_mixed_distinct(backend, union_subsets):
(a, b, c), (da, db, dc) = union_subsets
Expand Down
52 changes: 38 additions & 14 deletions ibis/expr/operations/relations.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
import ibis.expr.datatypes as dt
from ibis.common.annotations import attribute
from ibis.common.collections import (
ConflictingValuesError,
FrozenDict,
FrozenOrderedDict,
)
Expand Down Expand Up @@ -338,19 +337,44 @@ class Set(Relation):
values = FrozenOrderedDict()

def __init__(self, left, right, **kwargs):
err_msg = "Table schemas must be equal for set operations."
try:
missing_from_left = right.schema - left.schema
missing_from_right = left.schema - right.schema
except ConflictingValuesError as e:
raise RelationError(err_msg + "\n" + str(e)) from e
if missing_from_left or missing_from_right:
msgs = [err_msg]
if missing_from_left:
msgs.append(f"Columns missing from the left:\n{missing_from_left}.")
if missing_from_right:
msgs.append(f"Columns missing from the right:\n{missing_from_right}.")
raise RelationError("\n".join(msgs))
# TODO: hoist this up into the user-facing API so we can see
# all the tables at once and give a better error message
errs = ["Table schemas must be unifiable for set operations."]
missing_from_left = set(right.schema.names) - set(left.schema.names)
missing_from_right = set(left.schema.names) - set(right.schema.names)
if missing_from_left:
errs.append(f"Columns missing from the left:\n{missing_from_left}.")
if missing_from_right:
errs.append(f"Columns missing from the right:\n{missing_from_right}.")

Check warning on line 348 in ibis/expr/operations/relations.py

View check run for this annotation

Codecov / codecov/patch

ibis/expr/operations/relations.py#L348

Added line #L348 was not covered by tests
if len(errs) > 1:
raise RelationError("\n".join(errs))

upcasts = {}
for name in left.schema.names:
ltype, rtype = left.schema[name], right.schema[name]
try:
unified_dt = dt.highest_precedence([ltype, rtype])
if unified_dt != ltype or unified_dt != rtype:
upcasts[name] = unified_dt
except IbisTypeError:
errs.append(f"Unable to find a common dtype for column {name}")
errs.append(f"Left dtype: {ltype!s}")
errs.append(f"Right dtype: {rtype!s}")
if len(errs) > 1:
raise RelationError("\n".join(errs))

if upcasts:
from ibis.expr.operations.generic import Cast

def get_new_val(relation, name):
if name not in upcasts:
return Field(relation, name)
return Cast(Field(relation, name), upcasts[name])

lcols = {name: get_new_val(left, name) for name in left.schema.names}
rcols = {name: get_new_val(right, name) for name in left.schema.names}
left = Project(left, lcols)
right = Project(right, rcols)

if left.schema.names != right.schema.names:
# rewrite so that both sides have the columns in the same order making it
Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_set_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class D:

@pytest.mark.parametrize("method", ["union", "intersect", "difference"])
def test_operation_requires_equal_schemas(method):
with pytest.raises(RelationError, match="`c`: string != float64"):
with pytest.raises(RelationError, match="Left dtype: int64\nRight dtype: string"):
getattr(a, method)(d)


Expand Down
2 changes: 1 addition & 1 deletion ibis/tests/expr/test_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

set_ops_schema_top = [("key", "string"), ("value", "double")]
set_ops_schema_bottom = [("key", "string"), ("key2", "string"), ("value", "double")]
setops_relation_error_message = "Table schemas must be equal for set operations"
setops_relation_error_message = "Table schemas must be unifiable for set operations"


@pytest.fixture
Expand Down
3 changes: 1 addition & 2 deletions ibis/tests/expr/test_value_exprs.py
Original file line number Diff line number Diff line change
Expand Up @@ -641,8 +641,7 @@ def test_null_column():
def test_null_column_union():
s = ibis.table([("a", "string"), ("b", "double")])
t = ibis.table([("a", "string")])
with pytest.raises(ibis.common.exceptions.RelationError):
s.union(t.mutate(b=ibis.null())) # needs a type
assert s.union(t.mutate(b=ibis.null())).schema() == s.schema()
assert s.union(t.mutate(b=ibis.null().cast("double"))).schema() == s.schema()


Expand Down

0 comments on commit 331145b

Please sign in to comment.