Skip to content

Commit e2a5c1e

Browse files
authored
Support MIN and MAX for DataType::List (#16025)
* Fix comparisons between lists that contain nulls
* Add support for lists in min/max agg functions
* Add sqllogictests
* Support lists in window frame target type
1 parent 828ee5a commit e2a5c1e

File tree

5 files changed

+200
-7
lines changed

5 files changed

+200
-7
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -605,6 +605,20 @@ fn partial_cmp_list(arr1: &dyn Array, arr2: &dyn Array) -> Option<Ordering> {
605605
let eq_res = arrow::compute::kernels::cmp::eq(&arr1_trimmed, &arr2_trimmed).ok()?;
606606

607607
for j in 0..lt_res.len() {
608+
// In Postgres, NULL values in lists are always considered to be greater than non-NULL values:
609+
//
610+
// $ SELECT ARRAY[NULL]::integer[] > ARRAY[1]
611+
// true
612+
//
613+
// These next two if statements are introduced for replicating Postgres behavior, as
614+
// arrow::compute does not account for this.
615+
if arr1_trimmed.is_null(j) && !arr2_trimmed.is_null(j) {
616+
return Some(Ordering::Greater);
617+
}
618+
if !arr1_trimmed.is_null(j) && arr2_trimmed.is_null(j) {
619+
return Some(Ordering::Less);
620+
}
621+
608622
if lt_res.is_valid(j) && lt_res.value(j) {
609623
return Some(Ordering::Less);
610624
}
@@ -4878,6 +4892,24 @@ mod tests {
48784892
])]),
48794893
));
48804894
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
4895+
4896+
let a =
4897+
ScalarValue::List(Arc::new(
4898+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4899+
None,
4900+
Some(2),
4901+
Some(3),
4902+
])]),
4903+
));
4904+
let b =
4905+
ScalarValue::List(Arc::new(
4906+
ListArray::from_iter_primitive::<Int64Type, _, _>(vec![Some(vec![
4907+
Some(1),
4908+
Some(2),
4909+
Some(3),
4910+
])]),
4911+
));
4912+
assert_eq!(a.partial_cmp(&b), Some(Ordering::Greater));
48814913
}
48824914

48834915
#[test]

datafusion/functions-aggregate/src/min_max.rs

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -616,7 +616,8 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
616616
min_binary_view
617617
)
618618
}
619-
DataType::Struct(_) => min_max_batch_struct(values, Ordering::Greater)?,
619+
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Greater)?,
620+
DataType::List(_) => min_max_batch_generic(values, Ordering::Greater)?,
620621
DataType::Dictionary(_, _) => {
621622
let values = values.as_any_dictionary().values();
622623
min_batch(values)?
@@ -625,7 +626,7 @@ fn min_batch(values: &ArrayRef) -> Result<ScalarValue> {
625626
})
626627
}
627628

628-
fn min_max_batch_struct(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
629+
fn min_max_batch_generic(array: &ArrayRef, ordering: Ordering) -> Result<ScalarValue> {
629630
if array.len() == array.null_count() {
630631
return ScalarValue::try_from(array.data_type());
631632
}
@@ -649,7 +650,7 @@ fn min_max_batch_struct(array: &ArrayRef, ordering: Ordering) -> Result<ScalarVa
649650
Ok(extreme.force_clone())
650651
}
651652

652-
macro_rules! min_max_struct {
653+
macro_rules! min_max_generic {
653654
($VALUE:expr, $DELTA:expr, $OP:ident) => {{
654655
if $VALUE.is_null() {
655656
$DELTA.clone()
@@ -703,7 +704,8 @@ pub fn max_batch(values: &ArrayRef) -> Result<ScalarValue> {
703704
max_binary
704705
)
705706
}
706-
DataType::Struct(_) => min_max_batch_struct(values, Ordering::Less)?,
707+
DataType::Struct(_) => min_max_batch_generic(values, Ordering::Less)?,
708+
DataType::List(_) => min_max_batch_generic(values, Ordering::Less)?,
707709
DataType::Dictionary(_, _) => {
708710
let values = values.as_any_dictionary().values();
709711
max_batch(values)?
@@ -983,7 +985,14 @@ macro_rules! min_max {
983985
lhs @ ScalarValue::Struct(_),
984986
rhs @ ScalarValue::Struct(_),
985987
) => {
986-
min_max_struct!(lhs, rhs, $OP)
988+
min_max_generic!(lhs, rhs, $OP)
989+
}
990+
991+
(
992+
lhs @ ScalarValue::List(_),
993+
rhs @ ScalarValue::List(_),
994+
) => {
995+
min_max_generic!(lhs, rhs, $OP)
987996
}
988997
e => {
989998
return internal_err!(

datafusion/optimizer/src/analyzer/type_coercion.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,7 @@ fn coerce_frame_bound(
718718
fn extract_window_frame_target_type(col_type: &DataType) -> Result<DataType> {
719719
if col_type.is_numeric()
720720
|| is_utf8_or_utf8view_or_large_utf8(col_type)
721+
|| matches!(col_type, DataType::List(_))
721722
|| matches!(col_type, DataType::Null)
722723
|| matches!(col_type, DataType::Boolean)
723724
{

datafusion/sqllogictest/test_files/aggregate.slt

Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6997,4 +6997,151 @@ VALUES
69976997
----
69986998
{a: 1, b: 2, c: 3} {a: 1, b: 2, c: 4}
69996999

7000+
# Min/Max with list over integers
7001+
query ??
7002+
SELECT MIN(column1), MAX(column1) FROM VALUES
7003+
([1, 2, 3]),
7004+
([1, 2]);
7005+
----
7006+
[1, 2] [1, 2, 3]
7007+
7008+
# Min/Max with lists over strings
7009+
query ??
7010+
SELECT MIN(column1), MAX(column1) FROM VALUES
7011+
(['a', 'b', 'c']),
7012+
(['a', 'b']);
7013+
----
7014+
[a, b] [a, b, c]
7015+
7016+
# Min/Max with list over booleans
7017+
query ??
7018+
SELECT MIN(column1), MAX(column1) FROM VALUES
7019+
([true, true, false]),
7020+
([false, true]);
7021+
----
7022+
[false, true] [true, true, false]
7023+
7024+
# Min/Max with list over nullable integers
7025+
query ??
7026+
SELECT MIN(column1), MAX(column1) FROM VALUES
7027+
([NULL, 1, 2]),
7028+
([1, 2]);
7029+
----
7030+
[1, 2] [NULL, 1, 2]
7031+
7032+
# Min/Max list with different lengths and nulls
7033+
query ??
7034+
SELECT MIN(column1), MAX(column1) FROM VALUES
7035+
([1, NULL, 3]),
7036+
([1, 2, 3, 4]),
7037+
([1, 2]);
7038+
----
7039+
[1, 2] [1, NULL, 3]
7040+
7041+
# Min/Max list with only NULLs
7042+
query ??
7043+
SELECT MIN(column1), MAX(column1) FROM VALUES
7044+
([NULL, NULL]),
7045+
([NULL]);
7046+
----
7047+
[NULL] [NULL, NULL]
7048+
7049+
# Min/Max list with empty lists
7050+
query ??
7051+
SELECT MIN(column1), MAX(column1) FROM VALUES
7052+
([]),
7053+
([1]),
7054+
([]);
7055+
----
7056+
[] [1]
7057+
7058+
# Min/Max over same-length lists with NULLs in different positions
7059+
query ??
7060+
SELECT MIN(column1), MAX(column1) FROM VALUES
7061+
([1, 2, 3]),
7062+
([NULL, 2, 3]),
7063+
([1, 2, NULL]);
7064+
----
7065+
[1, 2, 3] [NULL, 2, 3]
7066+
7067+
# Min/Max list grouped by key with NULLs and differing lengths
7068+
query I?? rowsort
7069+
SELECT column1, MIN(column2), MAX(column2) FROM VALUES
7070+
(0, [1, NULL, 3]),
7071+
(0, [1, 2, 3, 4]),
7072+
(1, [1, 2]),
7073+
(1, [NULL, 5]),
7074+
(1, [])
7075+
GROUP BY column1;
7076+
----
7077+
0 [1, 2, 3, 4] [1, NULL, 3]
7078+
1 [] [NULL, 5]
7079+
7080+
# Min/Max list grouped by key with only-NULL lists of differing lengths
7081+
query I?? rowsort
7082+
SELECT column1, MIN(column2), MAX(column2) FROM VALUES
7083+
(0, [NULL]),
7084+
(0, [NULL, NULL]),
7085+
(1, [NULL])
7086+
GROUP BY column1;
7087+
----
7088+
0 [NULL] [NULL, NULL]
7089+
1 [NULL] [NULL]
7090+
7091+
# Min/Max grouped list with empty and non-empty
7092+
query I?? rowsort
7093+
SELECT column1, MIN(column2), MAX(column2) FROM VALUES
7094+
(0, []),
7095+
(0, [1]),
7096+
(0, []),
7097+
(1, [5, 6]),
7098+
(1, [])
7099+
GROUP BY column1;
7100+
----
7101+
0 [] [1]
7102+
1 [] [5, 6]
7103+
7104+
# Min/Max over lists with a window function
7105+
query ?
7106+
SELECT min(column1) OVER (ORDER BY column1) FROM VALUES
7107+
([1, 2, 3]),
7108+
([1, 2, 3]),
7109+
([2, 3])
7110+
----
7111+
[1, 2, 3]
7112+
[1, 2, 3]
7113+
[1, 2, 3]
7114+
7115+
# Min/Max over lists with a window function and nulls
7116+
query ?
7117+
SELECT min(column1) OVER (ORDER BY column1) FROM VALUES
7118+
(NULL),
7119+
([4, 5]),
7120+
([2, 3])
7121+
----
7122+
[2, 3]
7123+
[2, 3]
7124+
[2, 3]
7125+
7126+
# Min/Max over lists with a window function, nulls and ROWS BETWEEN statement
7127+
query ?
7128+
SELECT min(column1) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES
7129+
(NULL),
7130+
([4, 5]),
7131+
([2, 3])
7132+
----
7133+
[2, 3]
7134+
[2, 3]
7135+
[4, 5]
7136+
7137+
# Min/Max over lists with a window function using a different column
7138+
query ?
7139+
SELECT max(column2) OVER (ORDER BY column1 ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING) FROM VALUES
7140+
([1, 2, 3], [4, 5]),
7141+
([2, 3], [2, 3]),
7142+
([1, 2, 3], NULL)
7143+
----
7144+
[4, 5]
7145+
[4, 5]
7146+
[2, 3]
70007147

datafusion/sqllogictest/test_files/array_query.slt

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -108,11 +108,15 @@ SELECT * FROM data WHERE column2 is not distinct from null;
108108
# Aggregates
109109
###########
110110

111-
query error Internal error: Min/Max accumulator not implemented for type List
111+
query ?
112112
SELECT min(column1) FROM data;
113+
----
114+
[1, 2, 3]
113115

114-
query error Internal error: Min/Max accumulator not implemented for type List
116+
query ?
115117
SELECT max(column1) FROM data;
118+
----
119+
[2, 3]
116120

117121
query I
118122
SELECT count(column1) FROM data;

0 commit comments

Comments
 (0)