Skip to content

Commit 9fb5ff9

Browse files
authored
Fix join on arrays of unhashable types and allow hash join on all types supported at run-time (#13388)
* Remove unused code paths from create_hashes. The `downcast_primitive_array!` macro handles all primitive types and only then delegates to fallbacks. It handles Decimal128 and Decimal256 internally.
* Fix join on arrays of unhashable types and allow hash join on all types supported at run-time (#13388). Update can_hash to match currently supported hashes.
* Rename table_with_many_types in tests.
* Test join on binary is hash join.
1 parent c44b613 commit 9fb5ff9

File tree

5 files changed

+79
-36
lines changed

5 files changed

+79
-36
lines changed

datafusion/common/src/hash_utils.rs

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use arrow_buffer::IntervalMonthDayNano;
3232
use crate::cast::{
3333
as_binary_view_array, as_boolean_array, as_fixed_size_list_array,
3434
as_generic_binary_array, as_large_list_array, as_list_array, as_map_array,
35-
as_primitive_array, as_string_array, as_string_view_array, as_struct_array,
35+
as_string_array, as_string_view_array, as_struct_array,
3636
};
3737
use crate::error::Result;
3838
#[cfg(not(feature = "force_hash_collisions"))]
@@ -392,14 +392,6 @@ pub fn create_hashes<'a>(
392392
let array: &FixedSizeBinaryArray = array.as_any().downcast_ref().unwrap();
393393
hash_array(array, random_state, hashes_buffer, rehash)
394394
}
395-
DataType::Decimal128(_, _) => {
396-
let array = as_primitive_array::<Decimal128Type>(array)?;
397-
hash_array_primitive(array, random_state, hashes_buffer, rehash)
398-
}
399-
DataType::Decimal256(_, _) => {
400-
let array = as_primitive_array::<Decimal256Type>(array)?;
401-
hash_array_primitive(array, random_state, hashes_buffer, rehash)
402-
}
403395
DataType::Dictionary(_, _) => downcast_dictionary_array! {
404396
array => hash_dictionary(array, random_state, hashes_buffer, rehash)?,
405397
_ => unreachable!()

datafusion/expr/src/utils.rs

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use crate::{
2929
};
3030
use datafusion_expr_common::signature::{Signature, TypeSignature};
3131

32-
use arrow::datatypes::{DataType, Field, Schema, TimeUnit};
32+
use arrow::datatypes::{DataType, Field, Schema};
3333
use datafusion_common::tree_node::{
3434
Transformed, TransformedResult, TreeNode, TreeNodeRecursion,
3535
};
@@ -958,7 +958,7 @@ pub(crate) fn find_column_indexes_referenced_by_expr(
958958

959959
/// Can this data type be used in hash join equal conditions??
960960
/// Data types here come from function 'equal_rows', if more data types are supported
961-
/// in equal_rows(hash join), add those data types here to generate join logical plan.
961+
/// in create_hashes, add those data types here to generate join logical plan.
962962
pub fn can_hash(data_type: &DataType) -> bool {
963963
match data_type {
964964
DataType::Null => true,
@@ -971,31 +971,38 @@ pub fn can_hash(data_type: &DataType) -> bool {
971971
DataType::UInt16 => true,
972972
DataType::UInt32 => true,
973973
DataType::UInt64 => true,
974+
DataType::Float16 => true,
974975
DataType::Float32 => true,
975976
DataType::Float64 => true,
976-
DataType::Timestamp(time_unit, _) => match time_unit {
977-
TimeUnit::Second => true,
978-
TimeUnit::Millisecond => true,
979-
TimeUnit::Microsecond => true,
980-
TimeUnit::Nanosecond => true,
981-
},
977+
DataType::Decimal128(_, _) => true,
978+
DataType::Decimal256(_, _) => true,
979+
DataType::Timestamp(_, _) => true,
982980
DataType::Utf8 => true,
983981
DataType::LargeUtf8 => true,
984982
DataType::Utf8View => true,
985-
DataType::Decimal128(_, _) => true,
983+
DataType::Binary => true,
984+
DataType::LargeBinary => true,
985+
DataType::BinaryView => true,
986986
DataType::Date32 => true,
987987
DataType::Date64 => true,
988+
DataType::Time32(_) => true,
989+
DataType::Time64(_) => true,
990+
DataType::Duration(_) => true,
991+
DataType::Interval(_) => true,
988992
DataType::FixedSizeBinary(_) => true,
989-
DataType::Dictionary(key_type, value_type)
990-
if *value_type.as_ref() == DataType::Utf8 =>
991-
{
992-
DataType::is_dictionary_key_type(key_type)
993+
DataType::Dictionary(key_type, value_type) => {
994+
DataType::is_dictionary_key_type(key_type) && can_hash(value_type)
993995
}
994-
DataType::List(_) => true,
995-
DataType::LargeList(_) => true,
996-
DataType::FixedSizeList(_, _) => true,
996+
DataType::List(value_type) => can_hash(value_type.data_type()),
997+
DataType::LargeList(value_type) => can_hash(value_type.data_type()),
998+
DataType::FixedSizeList(value_type, _) => can_hash(value_type.data_type()),
999+
DataType::Map(map_struct, true | false) => can_hash(map_struct.data_type()),
9971000
DataType::Struct(fields) => fields.iter().all(|f| can_hash(f.data_type())),
998-
_ => false,
1001+
1002+
DataType::ListView(_)
1003+
| DataType::LargeListView(_)
1004+
| DataType::Union(_, _)
1005+
| DataType::RunEndEncoded(_, _) => false,
9991006
}
10001007
}
10011008

@@ -1403,6 +1410,7 @@ mod tests {
14031410
test::function_stub::max_udaf, test::function_stub::min_udaf,
14041411
test::function_stub::sum_udaf, Cast, ExprFunctionExt, WindowFunctionDefinition,
14051412
};
1413+
use arrow::datatypes::{UnionFields, UnionMode};
14061414

14071415
#[test]
14081416
fn test_group_window_expr_by_sort_keys_empty_case() -> Result<()> {
@@ -1805,4 +1813,21 @@ mod tests {
18051813
assert!(accum.contains(&Column::from_name("a")));
18061814
Ok(())
18071815
}
1816+
1817+
#[test]
1818+
fn test_can_hash() {
1819+
let union_fields: UnionFields = [
1820+
(0, Arc::new(Field::new("A", DataType::Int32, true))),
1821+
(1, Arc::new(Field::new("B", DataType::Float64, true))),
1822+
]
1823+
.into_iter()
1824+
.collect();
1825+
1826+
let union_type = DataType::Union(union_fields, UnionMode::Sparse);
1827+
assert!(!can_hash(&union_type));
1828+
1829+
let list_union_type =
1830+
DataType::List(Arc::new(Field::new("my_union", union_type, true)));
1831+
assert!(!can_hash(&list_union_type));
1832+
}
18081833
}

datafusion/sqllogictest/src/test_context.rs

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,8 @@ impl TestContext {
106106
let example_udf = create_example_udf();
107107
test_ctx.ctx.register_udf(example_udf);
108108
register_partition_table(&mut test_ctx).await;
109+
info!("Registering table with many types");
110+
register_table_with_many_types(test_ctx.session_ctx()).await;
109111
}
110112
"metadata.slt" => {
111113
info!("Registering metadata table tables");
@@ -251,8 +253,11 @@ pub async fn register_table_with_many_types(ctx: &SessionContext) {
251253
.unwrap();
252254
ctx.register_catalog("my_catalog", Arc::new(catalog));
253255

254-
ctx.register_table("my_catalog.my_schema.t2", table_with_many_types())
255-
.unwrap();
256+
ctx.register_table(
257+
"my_catalog.my_schema.table_with_many_types",
258+
table_with_many_types(),
259+
)
260+
.unwrap();
256261
}
257262

258263
pub async fn register_table_with_map(ctx: &SessionContext) {

datafusion/sqllogictest/test_files/information_schema_columns.slt

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -37,17 +37,17 @@ query TTTTITTTIIIIIIT rowsort
3737
SELECT * from information_schema.columns;
3838
----
3939
my_catalog my_schema t1 i 0 NULL YES Int32 NULL NULL 32 2 NULL NULL NULL
40-
my_catalog my_schema t2 binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL
41-
my_catalog my_schema t2 float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL
42-
my_catalog my_schema t2 int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL
43-
my_catalog my_schema t2 large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL
44-
my_catalog my_schema t2 large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL
45-
my_catalog my_schema t2 timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL
46-
my_catalog my_schema t2 utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL
40+
my_catalog my_schema table_with_many_types binary_col 4 NULL NO Binary NULL 2147483647 NULL NULL NULL NULL NULL
41+
my_catalog my_schema table_with_many_types float64_col 1 NULL YES Float64 NULL NULL 24 2 NULL NULL NULL
42+
my_catalog my_schema table_with_many_types int32_col 0 NULL NO Int32 NULL NULL 32 2 NULL NULL NULL
43+
my_catalog my_schema table_with_many_types large_binary_col 5 NULL NO LargeBinary NULL 9223372036854775807 NULL NULL NULL NULL NULL
44+
my_catalog my_schema table_with_many_types large_utf8_col 3 NULL NO LargeUtf8 NULL 9223372036854775807 NULL NULL NULL NULL NULL
45+
my_catalog my_schema table_with_many_types timestamp_nanos 6 NULL NO Timestamp(Nanosecond, None) NULL NULL NULL NULL NULL NULL NULL
46+
my_catalog my_schema table_with_many_types utf8_col 2 NULL YES Utf8 NULL 2147483647 NULL NULL NULL NULL NULL
4747

4848
# Cleanup
4949
statement ok
5050
drop table t1
5151

5252
statement ok
53-
drop table t2
53+
drop table table_with_many_types

datafusion/sqllogictest/test_files/joins.slt

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4292,3 +4292,24 @@ query T
42924292
select * from table1 as t1 natural join table1_stringview as t2;
42934293
----
42944294
foo
4295+
4296+
query TT
4297+
EXPLAIN SELECT count(*)
4298+
FROM my_catalog.my_schema.table_with_many_types AS l
4299+
JOIN my_catalog.my_schema.table_with_many_types AS r ON l.binary_col = r.binary_col
4300+
----
4301+
logical_plan
4302+
01)Aggregate: groupBy=[[]], aggr=[[count(Int64(1)) AS count(*)]]
4303+
02)--Projection:
4304+
03)----Inner Join: l.binary_col = r.binary_col
4305+
04)------SubqueryAlias: l
4306+
05)--------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col]
4307+
06)------SubqueryAlias: r
4308+
07)--------TableScan: my_catalog.my_schema.table_with_many_types projection=[binary_col]
4309+
physical_plan
4310+
01)AggregateExec: mode=Single, gby=[], aggr=[count(*)]
4311+
02)--ProjectionExec: expr=[]
4312+
03)----CoalesceBatchesExec: target_batch_size=3
4313+
04)------HashJoinExec: mode=CollectLeft, join_type=Inner, on=[(binary_col@0, binary_col@0)]
4314+
05)--------MemoryExec: partitions=1, partition_sizes=[1]
4315+
06)--------MemoryExec: partitions=1, partition_sizes=[1]

0 commit comments

Comments (0)