Skip to content

Commit ec4a6da

Browse files
alamb and Lordworms
authored
coercion vec[Dictionary, Utf8] to Dictionary for coalesce function (#9958) (#10104)
* for debug finish remove print add space
* fix clippy
* finish
* fix clippy

Co-authored-by: Lordworms <[email protected]>
1 parent 9974cee commit ec4a6da

File tree

2 files changed

+75
-24
lines changed

2 files changed

+75
-24
lines changed

datafusion/expr/src/type_coercion/functions.rs

Lines changed: 35 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -311,33 +311,43 @@ fn coerced_from<'a>(
311311
type_from: &'a DataType,
312312
) -> Option<DataType> {
313313
use self::DataType::*;
314-
315-
match type_into {
314+
// match Dictionary first
315+
match (type_into, type_from) {
316+
// coerced dictionary first
317+
(cur_type, Dictionary(_, value_type)) | (Dictionary(_, value_type), cur_type)
318+
if coerced_from(cur_type, value_type).is_some() =>
319+
{
320+
Some(type_into.clone())
321+
}
316322
// coerced into type_into
317-
Int8 if matches!(type_from, Null | Int8) => Some(type_into.clone()),
318-
Int16 if matches!(type_from, Null | Int8 | Int16 | UInt8) => {
323+
(Int8, _) if matches!(type_from, Null | Int8) => Some(type_into.clone()),
324+
(Int16, _) if matches!(type_from, Null | Int8 | Int16 | UInt8) => {
319325
Some(type_into.clone())
320326
}
321-
Int32 if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) => {
327+
(Int32, _)
328+
if matches!(type_from, Null | Int8 | Int16 | Int32 | UInt8 | UInt16) =>
329+
{
322330
Some(type_into.clone())
323331
}
324-
Int64
332+
(Int64, _)
325333
if matches!(
326334
type_from,
327335
Null | Int8 | Int16 | Int32 | Int64 | UInt8 | UInt16 | UInt32
328336
) =>
329337
{
330338
Some(type_into.clone())
331339
}
332-
UInt8 if matches!(type_from, Null | UInt8) => Some(type_into.clone()),
333-
UInt16 if matches!(type_from, Null | UInt8 | UInt16) => Some(type_into.clone()),
334-
UInt32 if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => {
340+
(UInt8, _) if matches!(type_from, Null | UInt8) => Some(type_into.clone()),
341+
(UInt16, _) if matches!(type_from, Null | UInt8 | UInt16) => {
342+
Some(type_into.clone())
343+
}
344+
(UInt32, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32) => {
335345
Some(type_into.clone())
336346
}
337-
UInt64 if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => {
347+
(UInt64, _) if matches!(type_from, Null | UInt8 | UInt16 | UInt32 | UInt64) => {
338348
Some(type_into.clone())
339349
}
340-
Float32
350+
(Float32, _)
341351
if matches!(
342352
type_from,
343353
Null | Int8
@@ -353,7 +363,7 @@ fn coerced_from<'a>(
353363
{
354364
Some(type_into.clone())
355365
}
356-
Float64
366+
(Float64, _)
357367
if matches!(
358368
type_from,
359369
Null | Int8
@@ -371,31 +381,35 @@ fn coerced_from<'a>(
371381
{
372382
Some(type_into.clone())
373383
}
374-
Timestamp(TimeUnit::Nanosecond, None)
384+
(Timestamp(TimeUnit::Nanosecond, None), _)
375385
if matches!(
376386
type_from,
377387
Null | Timestamp(_, None) | Date32 | Utf8 | LargeUtf8
378388
) =>
379389
{
380390
Some(type_into.clone())
381391
}
382-
Interval(_) if matches!(type_from, Utf8 | LargeUtf8) => Some(type_into.clone()),
392+
(Interval(_), _) if matches!(type_from, Utf8 | LargeUtf8) => {
393+
Some(type_into.clone())
394+
}
383395
// Any type can be coerced into strings
384-
Utf8 | LargeUtf8 => Some(type_into.clone()),
385-
Null if can_cast_types(type_from, type_into) => Some(type_into.clone()),
396+
(Utf8 | LargeUtf8, _) => Some(type_into.clone()),
397+
(Null, _) if can_cast_types(type_from, type_into) => Some(type_into.clone()),
386398

387-
List(_) if matches!(type_from, FixedSizeList(_, _)) => Some(type_into.clone()),
399+
(List(_), _) if matches!(type_from, FixedSizeList(_, _)) => {
400+
Some(type_into.clone())
401+
}
388402

389403
// Only accept list and largelist with the same number of dimensions unless the type is Null.
390404
// List or LargeList with different dimensions should be handled in TypeSignature or other places before this
391-
List(_) | LargeList(_)
405+
(List(_) | LargeList(_), _)
392406
if datafusion_common::utils::base_type(type_from).eq(&Null)
393407
|| list_ndims(type_from) == list_ndims(type_into) =>
394408
{
395409
Some(type_into.clone())
396410
}
397411
// should be able to coerce wildcard fixed size list to non wildcard fixed size list
398-
FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD) => match type_from {
412+
(FixedSizeList(f_into, FIXED_SIZE_LIST_WILDCARD), _) => match type_from {
399413
FixedSizeList(f_from, size_from) => {
400414
match coerced_from(f_into.data_type(), f_from.data_type()) {
401415
Some(data_type) if &data_type != f_into.data_type() => {
@@ -410,7 +424,7 @@ fn coerced_from<'a>(
410424
_ => None,
411425
},
412426

413-
Timestamp(unit, Some(tz)) if tz.as_ref() == TIMEZONE_WILDCARD => {
427+
(Timestamp(unit, Some(tz)), _) if tz.as_ref() == TIMEZONE_WILDCARD => {
414428
match type_from {
415429
Timestamp(_, Some(from_tz)) => {
416430
Some(Timestamp(unit.clone(), Some(from_tz.clone())))
@@ -422,15 +436,14 @@ fn coerced_from<'a>(
422436
_ => None,
423437
}
424438
}
425-
Timestamp(_, Some(_))
439+
(Timestamp(_, Some(_)), _)
426440
if matches!(
427441
type_from,
428442
Null | Timestamp(_, _) | Date32 | Utf8 | LargeUtf8
429443
) =>
430444
{
431445
Some(type_into.clone())
432446
}
433-
434447
// More coerce rules.
435448
// Note that not all rules in `comparison_coercion` can be reused here.
436449
// For example, all numeric types can be coerced into Utf8 for comparison,

datafusion/sqllogictest/test_files/scalar.slt

Lines changed: 40 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1779,6 +1779,46 @@ SELECT COALESCE(NULL, 'test')
17791779
----
17801780
test
17811781

1782+
1783+
statement ok
1784+
create table test1 as values (arrow_cast('foo', 'Dictionary(Int32, Utf8)')), (null);
1785+
1786+
# test coercion string
1787+
query ?
1788+
select coalesce(column1, 'none_set') from test1;
1789+
----
1790+
foo
1791+
none_set
1792+
1793+
# test coercion Int
1794+
query I
1795+
select coalesce(34, arrow_cast(123, 'Dictionary(Int32, Int8)'));
1796+
----
1797+
34
1798+
1799+
# test with Int
1800+
query I
1801+
select coalesce(arrow_cast(123, 'Dictionary(Int32, Int8)'),34);
1802+
----
1803+
123
1804+
1805+
# test with null
1806+
query I
1807+
select coalesce(null, 34, arrow_cast(123, 'Dictionary(Int32, Int8)'));
1808+
----
1809+
34
1810+
1811+
# test with null
1812+
query T
1813+
select coalesce(null, column1, 'none_set') from test1;
1814+
----
1815+
foo
1816+
none_set
1817+
1818+
statement ok
1819+
drop table test1
1820+
1821+
17821822
statement ok
17831823
CREATE TABLE test(
17841824
c1 INT,
@@ -2162,5 +2202,3 @@ query I
21622202
select strpos('joséésoj', arrow_cast(null, 'Utf8'));
21632203
----
21642204
NULL
2165-
2166-

0 commit comments

Comments
 (0)