From 331145b10cef7a1c2157f98477f2204d4cdc9c50 Mon Sep 17 00:00:00 2001 From: Nick Crews Date: Sun, 26 Jan 2025 09:59:14 -0900 Subject: [PATCH] feat: upcast schemas if needed during set ops --- .../test_union_unified_schemas/all/out.sql | 27 ++++++++++ .../distinct/out.sql | 27 ++++++++++ ibis/backends/tests/sql/test_sql.py | 17 ++++++ ibis/backends/tests/test_set_ops.py | 39 +++++++++++++- ibis/expr/operations/relations.py | 52 ++++++++++++++----- ibis/tests/expr/test_set_operations.py | 2 +- ibis/tests/expr/test_table.py | 2 +- ibis/tests/expr/test_value_exprs.py | 3 +- 8 files changed, 150 insertions(+), 19 deletions(-) create mode 100644 ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/all/out.sql create mode 100644 ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/distinct/out.sql diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/all/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/all/out.sql new file mode 100644 index 000000000000..83da5e696a7c --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/all/out.sql @@ -0,0 +1,27 @@ +SELECT + * +FROM ( + SELECT + * + FROM ( + SELECT + "t0"."id", + CAST("t0"."tinyint_col" AS BIGINT) AS "i", + CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s" + FROM "functional_alltypes" AS "t0" + ) AS "t1" + UNION ALL + SELECT + * + FROM ( + SELECT + "t0"."id", + CAST("t0"."bigint_col" + 256 AS BIGINT) AS "i", + CAST("t0"."string_col" AS TEXT) AS "s" + FROM "functional_alltypes" AS "t0" + ) AS "t2" +) AS "t3" +ORDER BY + "t3"."id" ASC, + "t3"."i" ASC, + "t3"."s" ASC \ No newline at end of file diff --git a/ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/distinct/out.sql b/ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/distinct/out.sql new file mode 100644 index 000000000000..1f5bb2cd5949 --- /dev/null +++ b/ibis/backends/tests/sql/snapshots/test_sql/test_union_unified_schemas/distinct/out.sql @@ -0,0 +1,27 @@ +SELECT + * +FROM ( + SELECT + * + FROM ( + SELECT + "t0"."id", + CAST("t0"."tinyint_col" AS BIGINT) AS "i", + CAST(CAST("t0"."string_col" AS TEXT) AS TEXT) AS "s" + FROM "functional_alltypes" AS "t0" + ) AS "t1" + UNION + SELECT + * + FROM ( + SELECT + "t0"."id", + CAST("t0"."bigint_col" + 256 AS BIGINT) AS "i", + CAST("t0"."string_col" AS TEXT) AS "s" + FROM "functional_alltypes" AS "t0" + ) AS "t2" +) AS "t3" +ORDER BY + "t3"."id" ASC, + "t3"."i" ASC, + "t3"."s" ASC \ No newline at end of file diff --git a/ibis/backends/tests/sql/test_sql.py b/ibis/backends/tests/sql/test_sql.py index bca6f56578f9..6d0c3b8d142f 100644 --- a/ibis/backends/tests/sql/test_sql.py +++ b/ibis/backends/tests/sql/test_sql.py @@ -663,3 +663,20 @@ def test_ctes_in_order(): sql = ibis.to_sql(expr, dialect="duckdb") assert sql.find('"first" AS (') < sql.find('"second" AS (') + + +@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"]) +def test_union_unified_schemas(snapshot, functional_alltypes, distinct): + a = functional_alltypes.select( + "id", i="tinyint_col", s=_.string_col.cast("!string") + ) + b = functional_alltypes.select( + "id", + i=_.bigint_col + 256, # ensure doesn't fit in a tinyint + s=_.string_col.cast("string"), + ) + expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s") + + assert expr.i.type() == b.i.type() + assert expr.s.type() == b.s.type() + snapshot.assert_match(to_sql(expr), "out.sql") diff --git a/ibis/backends/tests/test_set_ops.py b/ibis/backends/tests/test_set_ops.py index 699438169e59..8d8e4eeca175 100644 --- a/ibis/backends/tests/test_set_ops.py +++ b/ibis/backends/tests/test_set_ops.py @@ -8,7 +8,11 @@ import ibis import ibis.expr.types as ir from ibis import _ -from ibis.backends.tests.errors import PsycoPg2InternalError, PyDruidProgrammingError +from ibis.backends.tests.errors import ( + OracleDatabaseError, + PsycoPg2InternalError, + PyDruidProgrammingError, +) pd = pytest.importorskip("pandas") @@ -49,6 +53,39 @@ def test_union(backend, union_subsets, distinct): backend.assert_frame_equal(result, expected) +@pytest.mark.parametrize("distinct", [False, True], ids=["all", "distinct"]) +@pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) +@pytest.mark.notyet( + ["oracle"], raises=OracleDatabaseError, reason="does not support NOT NULL types" +) +def test_unified_schemas(backend, con, distinct): + a = con.table("functional_alltypes").select( + "id", + i="tinyint_col", + s=_.string_col.cast("!string"), + ) + b = con.table("functional_alltypes").select( + "id", + i=_.bigint_col + 256, # ensure doesn't fit in a tinyint + s=_.string_col.cast("string"), + ) + + expr = ibis.union(a, b, distinct=distinct).order_by("id", "i", "s") + assert expr.i.type() == b.i.type() + assert expr.s.type() == b.s.type() + result = expr.execute() + + expected = ( + pd.concat([a.execute(), b.execute()], axis=0) + .sort_values(["id", "i", "s"]) + .reset_index(drop=True) + ) + if distinct: + expected = expected.drop_duplicates(["id", "i", "s"]) + + backend.assert_frame_equal(result, expected, check_dtype=False) + + @pytest.mark.notimpl(["druid"], raises=PyDruidProgrammingError) def test_union_mixed_distinct(backend, union_subsets): (a, b, c), (da, db, dc) = union_subsets diff --git a/ibis/expr/operations/relations.py b/ibis/expr/operations/relations.py index 90dc400a8ded..36acc70eb4db 100644 --- a/ibis/expr/operations/relations.py +++ b/ibis/expr/operations/relations.py @@ -13,7 +13,6 @@ import ibis.expr.datatypes as dt from ibis.common.annotations import attribute from ibis.common.collections import ( - ConflictingValuesError, FrozenDict, FrozenOrderedDict, ) @@ -338,19 +337,44 @@ class Set(Relation): values = FrozenOrderedDict() def __init__(self, left, right, **kwargs): - err_msg = "Table schemas must be equal for set operations." - try: - missing_from_left = right.schema - left.schema - missing_from_right = left.schema - right.schema - except ConflictingValuesError as e: - raise RelationError(err_msg + "\n" + str(e)) from e - if missing_from_left or missing_from_right: - msgs = [err_msg] - if missing_from_left: - msgs.append(f"Columns missing from the left:\n{missing_from_left}.") - if missing_from_right: - msgs.append(f"Columns missing from the right:\n{missing_from_right}.") - raise RelationError("\n".join(msgs)) + # TODO: hoist this up into the user facing API so we can see + # all the tables at once and give a better error message + errs = ["Table schemas must be unifiable for set operations."] + missing_from_left = set(right.schema.names) - set(left.schema.names) + missing_from_right = set(left.schema.names) - set(right.schema.names) + if missing_from_left: + errs.append(f"Columns missing from the left:\n{missing_from_left}.") + if missing_from_right: + errs.append(f"Columns missing from the right:\n{missing_from_right}.") + if len(errs) > 1: + raise RelationError("\n".join(errs)) + + upcasts = {} + for name in left.schema.names: + ltype, rtype = left.schema[name], right.schema[name] + try: + unified_dt = dt.highest_precedence([ltype, rtype]) + if unified_dt != ltype or unified_dt != rtype: + upcasts[name] = unified_dt + except IbisTypeError: + errs.append(f"Unable to find a common dtype for column {name}") + errs.append(f"Left dtype: {ltype!s}") + errs.append(f"Right dtype: {rtype!s}") + if len(errs) > 1: + raise RelationError("\n".join(errs)) + + if upcasts: + from ibis.expr.operations.generic import Cast + + def get_new_val(relation, name): + if name not in upcasts: + return Field(relation, name) + return Cast(Field(relation, name), upcasts[name]) + + lcols = {name: get_new_val(left, name) for name in left.schema.names} + rcols = {name: get_new_val(right, name) for name in left.schema.names} + left = Project(left, lcols) + right = Project(right, rcols) if left.schema.names != right.schema.names: # rewrite so that both sides have the columns in the same order making it diff --git a/ibis/tests/expr/test_set_operations.py b/ibis/tests/expr/test_set_operations.py index 3eb93f3c56be..ea2fafa6dd67 100644 --- a/ibis/tests/expr/test_set_operations.py +++ b/ibis/tests/expr/test_set_operations.py @@ -40,7 +40,7 @@ class D: @pytest.mark.parametrize("method", ["union", "intersect", "difference"]) def test_operation_requires_equal_schemas(method): - with pytest.raises(RelationError, match="`c`: string != float64"): + with pytest.raises(RelationError, match="Left dtype: int64\nRight dtype: string"): getattr(a, method)(d) diff --git a/ibis/tests/expr/test_table.py b/ibis/tests/expr/test_table.py index e79acecbb622..9c03d80d3bbc 100644 --- a/ibis/tests/expr/test_table.py +++ b/ibis/tests/expr/test_table.py @@ -25,7 +25,7 @@ set_ops_schema_top = [("key", "string"), ("value", "double")] set_ops_schema_bottom = [("key", "string"), ("key2", "string"), ("value", "double")] -setops_relation_error_message = "Table schemas must be equal for set operations" +setops_relation_error_message = "Table schemas must be unifiable for set operations" @pytest.fixture diff --git a/ibis/tests/expr/test_value_exprs.py b/ibis/tests/expr/test_value_exprs.py index 4b6e349aaf6e..94048bc97e74 100644 --- a/ibis/tests/expr/test_value_exprs.py +++ b/ibis/tests/expr/test_value_exprs.py @@ -641,8 +641,7 @@ def test_null_column(): def test_null_column_union(): s = ibis.table([("a", "string"), ("b", "double")]) t = ibis.table([("a", "string")]) - with pytest.raises(ibis.common.exceptions.RelationError): - s.union(t.mutate(b=ibis.null())) # needs a type + assert s.union(t.mutate(b=ibis.null())).schema() == s.schema() assert s.union(t.mutate(b=ibis.null().cast("double"))).schema() == s.schema()