Skip to content

Commit d9d8ddd

Browse files
authored
feat: Support array_sort(list_sort) (#8279)
* Minor: Improve the document format of JoinHashMap * list sort * fix: example doc * fix: ci * fix: doc error * fix pb * like DuckDB function semantics * fix ci * fix pb * fix: doc * add table test * fix: not as expected * fix: return null * resolve conflicts * doc * merge
1 parent fa8a0d9 commit d9d8ddd

File tree

11 files changed

+194
-12
lines changed

11 files changed

+194
-12
lines changed

datafusion/expr/src/built_in_function.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,8 @@ pub enum BuiltinScalarFunction {
130130
// array functions
131131
/// array_append
132132
ArrayAppend,
133+
/// array_sort
134+
ArraySort,
133135
/// array_concat
134136
ArrayConcat,
135137
/// array_has
@@ -398,6 +400,7 @@ impl BuiltinScalarFunction {
398400
BuiltinScalarFunction::Tanh => Volatility::Immutable,
399401
BuiltinScalarFunction::Trunc => Volatility::Immutable,
400402
BuiltinScalarFunction::ArrayAppend => Volatility::Immutable,
403+
BuiltinScalarFunction::ArraySort => Volatility::Immutable,
401404
BuiltinScalarFunction::ArrayConcat => Volatility::Immutable,
402405
BuiltinScalarFunction::ArrayEmpty => Volatility::Immutable,
403406
BuiltinScalarFunction::ArrayHasAll => Volatility::Immutable,
@@ -545,6 +548,7 @@ impl BuiltinScalarFunction {
545548
Ok(data_type)
546549
}
547550
BuiltinScalarFunction::ArrayAppend => Ok(input_expr_types[0].clone()),
551+
BuiltinScalarFunction::ArraySort => Ok(input_expr_types[0].clone()),
548552
BuiltinScalarFunction::ArrayConcat => {
549553
let mut expr_type = Null;
550554
let mut max_dims = 0;
@@ -909,6 +913,9 @@ impl BuiltinScalarFunction {
909913
// for now, the list is small, as we do not have many built-in functions.
910914
match self {
911915
BuiltinScalarFunction::ArrayAppend => Signature::any(2, self.volatility()),
916+
BuiltinScalarFunction::ArraySort => {
917+
Signature::variadic_any(self.volatility())
918+
}
912919
BuiltinScalarFunction::ArrayPopFront => Signature::any(1, self.volatility()),
913920
BuiltinScalarFunction::ArrayPopBack => Signature::any(1, self.volatility()),
914921
BuiltinScalarFunction::ArrayConcat => {
@@ -1558,6 +1565,7 @@ impl BuiltinScalarFunction {
15581565
"array_push_back",
15591566
"list_push_back",
15601567
],
1568+
BuiltinScalarFunction::ArraySort => &["array_sort", "list_sort"],
15611569
BuiltinScalarFunction::ArrayConcat => {
15621570
&["array_concat", "array_cat", "list_concat", "list_cat"]
15631571
}

datafusion/expr/src/expr_fn.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -583,6 +583,8 @@ scalar_expr!(
583583
"appends an element to the end of an array."
584584
);
585585

586+
scalar_expr!(ArraySort, array_sort, array desc null_first, "returns sorted array.");
587+
586588
scalar_expr!(
587589
ArrayPopBack,
588590
array_pop_back,
@@ -1184,6 +1186,7 @@ mod test {
11841186
test_scalar_expr!(FromUnixtime, from_unixtime, unixtime);
11851187

11861188
test_scalar_expr!(ArrayAppend, array_append, array, element);
1189+
test_scalar_expr!(ArraySort, array_sort, array, desc, null_first);
11871190
test_scalar_expr!(ArrayPopFront, array_pop_front, array);
11881191
test_scalar_expr!(ArrayPopBack, array_pop_back, array);
11891192
test_unary_scalar_expr!(ArrayDims, array_dims);

datafusion/physical-expr/src/array_expressions.rs

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ use arrow::datatypes::{DataType, Field, UInt64Type};
2929
use arrow::row::{RowConverter, SortField};
3030
use arrow_buffer::NullBuffer;
3131

32-
use arrow_schema::FieldRef;
32+
use arrow_schema::{FieldRef, SortOptions};
3333
use datafusion_common::cast::{
3434
as_generic_list_array, as_generic_string_array, as_int64_array, as_list_array,
3535
as_null_array, as_string_array,
@@ -693,7 +693,7 @@ fn general_append_and_prepend(
693693
/// # Arguments
694694
///
695695
/// * `args` - An array of 1 to 3 ArrayRefs representing start, stop, and step(step value can not be zero.) values.
696-
///
696+
///
697697
/// # Examples
698698
///
699699
/// gen_range(3) => [0, 1, 2]
@@ -777,6 +777,85 @@ pub fn array_append(args: &[ArrayRef]) -> Result<ArrayRef> {
777777
Ok(res)
778778
}
779779

780+
/// Array_sort SQL function
781+
pub fn array_sort(args: &[ArrayRef]) -> Result<ArrayRef> {
782+
let sort_option = match args.len() {
783+
1 => None,
784+
2 => {
785+
let sort = as_string_array(&args[1])?.value(0);
786+
Some(SortOptions {
787+
descending: order_desc(sort)?,
788+
nulls_first: true,
789+
})
790+
}
791+
3 => {
792+
let sort = as_string_array(&args[1])?.value(0);
793+
let nulls_first = as_string_array(&args[2])?.value(0);
794+
Some(SortOptions {
795+
descending: order_desc(sort)?,
796+
nulls_first: order_nulls_first(nulls_first)?,
797+
})
798+
}
799+
_ => return internal_err!("array_sort expects 1 to 3 arguments"),
800+
};
801+
802+
let list_array = as_list_array(&args[0])?;
803+
let row_count = list_array.len();
804+
805+
let mut array_lengths = vec![];
806+
let mut arrays = vec![];
807+
let mut valid = BooleanBufferBuilder::new(row_count);
808+
for i in 0..row_count {
809+
if list_array.is_null(i) {
810+
array_lengths.push(0);
811+
valid.append(false);
812+
} else {
813+
let arr_ref = list_array.value(i);
814+
let arr_ref = arr_ref.as_ref();
815+
816+
let sorted_array = compute::sort(arr_ref, sort_option)?;
817+
array_lengths.push(sorted_array.len());
818+
arrays.push(sorted_array);
819+
valid.append(true);
820+
}
821+
}
822+
823+
// Assume all arrays have the same data type
824+
let data_type = list_array.value_type();
825+
let buffer = valid.finish();
826+
827+
let elements = arrays
828+
.iter()
829+
.map(|a| a.as_ref())
830+
.collect::<Vec<&dyn Array>>();
831+
832+
let list_arr = ListArray::new(
833+
Arc::new(Field::new("item", data_type, true)),
834+
OffsetBuffer::from_lengths(array_lengths),
835+
Arc::new(compute::concat(elements.as_slice())?),
836+
Some(NullBuffer::new(buffer)),
837+
);
838+
Ok(Arc::new(list_arr))
839+
}
840+
841+
fn order_desc(modifier: &str) -> Result<bool> {
842+
match modifier.to_uppercase().as_str() {
843+
"DESC" => Ok(true),
844+
"ASC" => Ok(false),
845+
_ => internal_err!("the second parameter of array_sort expects DESC or ASC"),
846+
}
847+
}
848+
849+
fn order_nulls_first(modifier: &str) -> Result<bool> {
850+
match modifier.to_uppercase().as_str() {
851+
"NULLS FIRST" => Ok(true),
852+
"NULLS LAST" => Ok(false),
853+
_ => internal_err!(
854+
"the third parameter of array_sort expects NULLS FIRST or NULLS LAST"
855+
),
856+
}
857+
}
858+
780859
/// Array_prepend SQL function
781860
pub fn array_prepend(args: &[ArrayRef]) -> Result<ArrayRef> {
782861
let list_array = as_list_array(&args[1])?;

datafusion/physical-expr/src/functions.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -329,6 +329,9 @@ pub fn create_physical_fun(
329329
BuiltinScalarFunction::ArrayAppend => {
330330
Arc::new(|args| make_scalar_function(array_expressions::array_append)(args))
331331
}
332+
BuiltinScalarFunction::ArraySort => {
333+
Arc::new(|args| make_scalar_function(array_expressions::array_sort)(args))
334+
}
332335
BuiltinScalarFunction::ArrayConcat => {
333336
Arc::new(|args| make_scalar_function(array_expressions::array_concat)(args))
334337
}

datafusion/proto/proto/datafusion.proto

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -644,6 +644,7 @@ enum ScalarFunction {
644644
Levenshtein = 125;
645645
SubstrIndex = 126;
646646
FindInSet = 127;
647+
ArraySort = 128;
647648
}
648649

649650
message ScalarFunctionNode {

datafusion/proto/src/generated/pbjson.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/generated/prost.rs

Lines changed: 3 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/proto/src/logical_plan/from_proto.rs

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -44,10 +44,11 @@ use datafusion_expr::{
4444
array_except, array_has, array_has_all, array_has_any, array_intersect, array_length,
4545
array_ndims, array_position, array_positions, array_prepend, array_remove,
4646
array_remove_all, array_remove_n, array_repeat, array_replace, array_replace_all,
47-
array_replace_n, array_slice, array_to_string, arrow_typeof, ascii, asin, asinh,
48-
atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil, character_length,
49-
chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot, current_date,
50-
current_time, date_bin, date_part, date_trunc, decode, degrees, digest, encode, exp,
47+
array_replace_n, array_slice, array_sort, array_to_string, arrow_typeof, ascii, asin,
48+
asinh, atan, atan2, atanh, bit_length, btrim, cardinality, cbrt, ceil,
49+
character_length, chr, coalesce, concat_expr, concat_ws_expr, cos, cosh, cot,
50+
current_date, current_time, date_bin, date_part, date_trunc, decode, degrees, digest,
51+
encode, exp,
5152
expr::{self, InList, Sort, WindowFunction},
5253
factorial, find_in_set, flatten, floor, from_unixtime, gcd, gen_range, isnan, iszero,
5354
lcm, left, levenshtein, ln, log, log10, log2,
@@ -463,6 +464,7 @@ impl From<&protobuf::ScalarFunction> for BuiltinScalarFunction {
463464
ScalarFunction::Rtrim => Self::Rtrim,
464465
ScalarFunction::ToTimestamp => Self::ToTimestamp,
465466
ScalarFunction::ArrayAppend => Self::ArrayAppend,
467+
ScalarFunction::ArraySort => Self::ArraySort,
466468
ScalarFunction::ArrayConcat => Self::ArrayConcat,
467469
ScalarFunction::ArrayEmpty => Self::ArrayEmpty,
468470
ScalarFunction::ArrayExcept => Self::ArrayExcept,
@@ -1343,6 +1345,11 @@ pub fn parse_expr(
13431345
parse_expr(&args[0], registry)?,
13441346
parse_expr(&args[1], registry)?,
13451347
)),
1348+
ScalarFunction::ArraySort => Ok(array_sort(
1349+
parse_expr(&args[0], registry)?,
1350+
parse_expr(&args[1], registry)?,
1351+
parse_expr(&args[2], registry)?,
1352+
)),
13461353
ScalarFunction::ArrayPopFront => {
13471354
Ok(array_pop_front(parse_expr(&args[0], registry)?))
13481355
}

datafusion/proto/src/logical_plan/to_proto.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1502,6 +1502,7 @@ impl TryFrom<&BuiltinScalarFunction> for protobuf::ScalarFunction {
15021502
BuiltinScalarFunction::Rtrim => Self::Rtrim,
15031503
BuiltinScalarFunction::ToTimestamp => Self::ToTimestamp,
15041504
BuiltinScalarFunction::ArrayAppend => Self::ArrayAppend,
1505+
BuiltinScalarFunction::ArraySort => Self::ArraySort,
15051506
BuiltinScalarFunction::ArrayConcat => Self::ArrayConcat,
15061507
BuiltinScalarFunction::ArrayEmpty => Self::ArrayEmpty,
15071508
BuiltinScalarFunction::ArrayExcept => Self::ArrayExcept,

datafusion/sqllogictest/test_files/array.slt

Lines changed: 44 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1052,6 +1052,44 @@ select make_array(['a','b'], null);
10521052
----
10531053
[[a, b], ]
10541054

1055+
## array_sort (aliases: `list_sort`)
1056+
query ???
1057+
select array_sort(make_array(1, 3, null, 5, NULL, -5)), array_sort(make_array(1, 3, null, 2), 'ASC'), array_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST');
1058+
----
1059+
[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1]
1060+
1061+
query ?
1062+
select array_sort(column1, 'DESC', 'NULLS LAST') from arrays_values;
1063+
----
1064+
[10, 9, 8, 7, 6, 5, 4, 3, 2, ]
1065+
[20, 18, 17, 16, 15, 14, 13, 12, 11, ]
1066+
[30, 29, 28, 27, 26, 25, 23, 22, 21, ]
1067+
[40, 39, 38, 37, 35, 34, 33, 32, 31, ]
1068+
NULL
1069+
[50, 49, 48, 47, 46, 45, 44, 43, 42, 41]
1070+
[60, 59, 58, 57, 56, 55, 54, 52, 51, ]
1071+
[70, 69, 68, 67, 66, 65, 64, 63, 62, 61]
1072+
1073+
query ?
1074+
select array_sort(column1, 'ASC', 'NULLS FIRST') from arrays_values;
1075+
----
1076+
[, 2, 3, 4, 5, 6, 7, 8, 9, 10]
1077+
[, 11, 12, 13, 14, 15, 16, 17, 18, 20]
1078+
[, 21, 22, 23, 25, 26, 27, 28, 29, 30]
1079+
[, 31, 32, 33, 34, 35, 37, 38, 39, 40]
1080+
NULL
1081+
[41, 42, 43, 44, 45, 46, 47, 48, 49, 50]
1082+
[, 51, 52, 54, 55, 56, 57, 58, 59, 60]
1083+
[61, 62, 63, 64, 65, 66, 67, 68, 69, 70]
1084+
1085+
1086+
## list_sort (aliases: `array_sort`)
1087+
query ???
1088+
select list_sort(make_array(1, 3, null, 5, NULL, -5)), list_sort(make_array(1, 3, null, 2), 'ASC'), list_sort(make_array(1, 3, null, 2), 'desc', 'NULLS FIRST');
1089+
----
1090+
[, , -5, 1, 3, 5] [, 1, 2, 3] [, 3, 2, 1]
1091+
1092+
10551093
## array_append (aliases: `list_append`, `array_push_back`, `list_push_back`)
10561094

10571095
# TODO: array_append with NULLs
@@ -1224,7 +1262,7 @@ select array_prepend(make_array(1, 11, 111), column1), array_prepend(column2, ma
12241262

12251263
# array_repeat scalar function #1
12261264
query ????????
1227-
select
1265+
select
12281266
array_repeat(1, 5),
12291267
array_repeat(3.14, 3),
12301268
array_repeat('l', 4),
@@ -1257,7 +1295,7 @@ AS VALUES
12571295
(0, 3, 3.3, 'datafusion', make_array(8, 9));
12581296

12591297
query ??????
1260-
select
1298+
select
12611299
array_repeat(column2, column1),
12621300
array_repeat(column3, column1),
12631301
array_repeat(column4, column1),
@@ -1272,7 +1310,7 @@ from array_repeat_table;
12721310
[] [] [] [] [3, 3, 3] []
12731311

12741312
statement ok
1275-
drop table array_repeat_table;
1313+
drop table array_repeat_table;
12761314

12771315
## array_concat (aliases: `array_cat`, `list_concat`, `list_cat`)
12781316

@@ -2188,7 +2226,7 @@ select array_remove(make_array(1, 2, 2, 1, 1), 2), array_remove(make_array(1.0,
21882226
[1, 2, 1, 1] [2.0, 2.0, 1.0, 1.0] [h, e, l, o]
21892227

21902228
query ???
2191-
select
2229+
select
21922230
array_remove(make_array(1, null, 2, 3), 2),
21932231
array_remove(make_array(1.1, null, 2.2, 3.3), 1.1),
21942232
array_remove(make_array('a', null, 'bc'), 'a');
@@ -2887,7 +2925,7 @@ from array_intersect_table_3D;
28872925
query ??????
28882926
SELECT array_intersect(make_array(1,2,3), make_array(2,3,4)),
28892927
array_intersect(make_array(1,3,5), make_array(2,4,6)),
2890-
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
2928+
array_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
28912929
array_intersect(make_array(true, false), make_array(true)),
28922930
array_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)),
28932931
array_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4]))
@@ -2918,7 +2956,7 @@ NULL
29182956
query ??????
29192957
SELECT list_intersect(make_array(1,2,3), make_array(2,3,4)),
29202958
list_intersect(make_array(1,3,5), make_array(2,4,6)),
2921-
list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
2959+
list_intersect(make_array('aa','bb','cc'), make_array('cc','aa','dd')),
29222960
list_intersect(make_array(true, false), make_array(true)),
29232961
list_intersect(make_array(1.1, 2.2, 3.3), make_array(2.2, 3.3, 4.4)),
29242962
list_intersect(make_array([1, 1], [2, 2], [3, 3]), make_array([2, 2], [3, 3], [4, 4]))

0 commit comments

Comments
 (0)