Skip to content

Commit c77307b

Browse files
Blizzara authored and Nirnay Roy committed
fix(substrait): fix regressed edge case in renaming inner struct fields (apache#15634)
* add a failing test case for apache#15239 (comment)
* fix invariants to use logical schema equality instead, and mark utf8 and utf8view as logically equivalent
* fix logical equivalence to be a strict superset of semantic equivalence, including ignoring decimal precision/scale
1 parent d0e7842 commit c77307b

File tree

4 files changed

+61
-12
lines changed

4 files changed

+61
-12
lines changed

datafusion/common/src/dfschema.rs

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -641,7 +641,7 @@ impl DFSchema {
641641
|| (!DFSchema::datatype_is_semantically_equal(
642642
f1.data_type(),
643643
f2.data_type(),
644-
) && !can_cast_types(f2.data_type(), f1.data_type()))
644+
))
645645
{
646646
_plan_err!(
647647
"Schema mismatch: Expected field '{}' with type {:?}, \
@@ -659,9 +659,12 @@ impl DFSchema {
659659
}
660660

661661
/// Checks if two [`DataType`]s are logically equal. This is a notably weaker constraint
662-
/// than datatype_is_semantically_equal in that a Dictionary<K,V> type is logically
663-
/// equal to a plain V type, but not semantically equal. Dictionary<K1, V1> is also
664-
/// logically equal to Dictionary<K2, V1>.
662+
/// than datatype_is_semantically_equal in that different representations of the same data can be
663+
/// logically but not semantically equivalent. Semantically equivalent types are always also
664+
/// logically equivalent. For example:
665+
/// - a Dictionary<K,V> type is logically equal to a plain V type
666+
/// - a Dictionary<K1, V1> is also logically equal to Dictionary<K2, V1>
667+
/// - Utf8 and Utf8View are logically equal
665668
pub fn datatype_is_logically_equal(dt1: &DataType, dt2: &DataType) -> bool {
666669
// check nested fields
667670
match (dt1, dt2) {
@@ -711,12 +714,15 @@ impl DFSchema {
711714
.zip(iter2)
712715
.all(|((t1, f1), (t2, f2))| t1 == t2 && Self::field_is_logically_equal(f1, f2))
713716
}
714-
_ => dt1 == dt2,
717+
// Utf8 and Utf8View are logically equivalent
718+
(DataType::Utf8, DataType::Utf8View) => true,
719+
(DataType::Utf8View, DataType::Utf8) => true,
720+
_ => Self::datatype_is_semantically_equal(dt1, dt2),
715721
}
716722
}
717723

718724
/// Returns true if two [`DataType`]s are semantically equal (same
719-
/// name and type), ignoring both metadata and nullability.
725+
/// name and type), ignoring both metadata and nullability, and decimal precision/scale.
720726
///
721727
/// request to upstream: <https://github.com/apache/arrow-rs/issues/3199>
722728
pub fn datatype_is_semantically_equal(dt1: &DataType, dt2: &DataType) -> bool {

datafusion/expr/src/logical_plan/invariants.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,11 @@ fn assert_valid_semantic_plan(plan: &LogicalPlan) -> Result<()> {
112112
/// Returns an error if the plan does not have the expected schema.
113113
/// Ignores metadata and nullability.
114114
pub fn assert_expected_schema(schema: &DFSchemaRef, plan: &LogicalPlan) -> Result<()> {
115-
let compatible = plan.schema().has_equivalent_names_and_types(schema);
115+
let compatible = plan.schema().logically_equivalent_names_and_types(schema);
116116

117-
if let Err(e) = compatible {
117+
if !compatible {
118118
internal_err!(
119-
"Failed due to a difference in schemas: {e}, original schema: {:?}, new schema: {:?}",
119+
"Failed due to a difference in schemas: original schema: {:?}, new schema: {:?}",
120120
schema,
121121
plan.schema()
122122
)

datafusion/optimizer/src/optimizer.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -506,8 +506,11 @@ mod tests {
506506
});
507507
let err = opt.optimize(plan, &config, &observe).unwrap_err();
508508

509-
// Simplify assert to check the error message contains the expected message, which is only the schema length mismatch
510-
assert_contains!(err.strip_backtrace(), "Schema mismatch: the schema length are not same Expected schema length: 3, got: 0");
509+
// Simplify assert to check the error message contains the expected message
510+
assert_contains!(
511+
err.strip_backtrace(),
512+
"Failed due to a difference in schemas: original schema: DFSchema"
513+
);
511514
}
512515

513516
#[test]

datafusion/substrait/tests/cases/roundtrip_logical_plan.rs

Lines changed: 41 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1061,7 +1061,7 @@ async fn roundtrip_literal_list() -> Result<()> {
10611061
async fn roundtrip_literal_struct() -> Result<()> {
10621062
let plan = generate_plan_from_sql(
10631063
"SELECT STRUCT(1, true, CAST(NULL AS STRING)) FROM data",
1064-
false,
1064+
true,
10651065
true,
10661066
)
10671067
.await?;
@@ -1076,6 +1076,46 @@ async fn roundtrip_literal_struct() -> Result<()> {
10761076
Ok(())
10771077
}
10781078

1079+
#[tokio::test]
1080+
async fn roundtrip_literal_named_struct() -> Result<()> {
1081+
let plan = generate_plan_from_sql(
1082+
"SELECT STRUCT(1 as int_field, true as boolean_field, CAST(NULL AS STRING) as string_field) FROM data",
1083+
true,
1084+
true,
1085+
)
1086+
.await?;
1087+
1088+
assert_snapshot!(
1089+
plan,
1090+
@r#"
1091+
Projection: Struct({int_field:1,boolean_field:true,string_field:}) AS named_struct(Utf8("int_field"),Int64(1),Utf8("boolean_field"),Boolean(true),Utf8("string_field"),NULL)
1092+
TableScan: data projection=[]
1093+
"#
1094+
);
1095+
Ok(())
1096+
}
1097+
1098+
#[tokio::test]
1099+
async fn roundtrip_literal_renamed_struct() -> Result<()> {
1100+
// This test aims to hit a case where the struct column itself has the expected name, but its
1101+
// inner field needs to be renamed.
1102+
let plan = generate_plan_from_sql(
1103+
"SELECT CAST((STRUCT(1)) AS Struct<\"int_field\"Int>) AS 'Struct({c0:1})' FROM data",
1104+
true,
1105+
true,
1106+
)
1107+
.await?;
1108+
1109+
assert_snapshot!(
1110+
plan,
1111+
@r#"
1112+
Projection: Struct({int_field:1}) AS Struct({c0:1})
1113+
TableScan: data projection=[]
1114+
"#
1115+
);
1116+
Ok(())
1117+
}
1118+
10791119
#[tokio::test]
10801120
async fn roundtrip_values() -> Result<()> {
10811121
// TODO: would be nice to have a struct inside the LargeList, but arrow_cast doesn't support that currently

0 commit comments

Comments
 (0)