Skip to content

Commit 62e23a2

Browse files
authored
bug: Fix edge cases in array_slice (#14489)
This commit fixes the following edge cases in the array_slice function so that its semantics match DuckDB: - When begin < 0 and -begin > length, begin is clamped to the beginning of the list. - When step < 0 and begin = end, the result is a list containing the single element found at index begin/end. Fixes #10548
1 parent 5239d1a commit 62e23a2

File tree

2 files changed

+37
-5
lines changed

2 files changed

+37
-5
lines changed

datafusion/functions-nested/src/extract.rs

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -482,7 +482,17 @@ where
482482
// 0 ~ len - 1
483483
let adjusted_zero_index = if index < 0 {
484484
if let Ok(index) = index.try_into() {
485-
index + len
485+
// When index < 0 and -index > length, index is clamped to the beginning of the list.
486+
// Otherwise, when index < 0, the index is counted from the end of the list.
487+
//
488+
// Note, we actually test the contrapositive, index < -length, because negating a
489+
// negative will panic if the negative is equal to the smallest representable value
490+
// while negating a positive is always safe.
491+
if index < (O::zero() - O::one()) * len {
492+
O::zero()
493+
} else {
494+
index + len
495+
}
486496
} else {
487497
return exec_err!("array_slice got invalid index: {}", index);
488498
}
@@ -570,7 +580,7 @@ where
570580
"array_slice got invalid stride: {:?}, it cannot be 0",
571581
stride
572582
);
573-
} else if (from <= to && stride.is_negative())
583+
} else if (from < to && stride.is_negative())
574584
|| (from > to && stride.is_positive())
575585
{
576586
// return empty array
@@ -582,7 +592,7 @@ where
582592
internal_datafusion_err!("array_slice got invalid stride: {}", stride)
583593
})?;
584594

585-
if from <= to {
595+
if from <= to && stride > O::zero() {
586596
assert!(start + to <= end);
587597
if stride.eq(&O::one()) {
588598
// stride is default to 1

datafusion/sqllogictest/test_files/array.slt

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1941,12 +1941,12 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -4
19411941
query ??
19421942
select array_slice(make_array(1, 2, 3, 4, 5), -7, -2), array_slice(make_array('h', 'e', 'l', 'l', 'o'), -7, -3);
19431943
----
1944-
[] []
1944+
[1, 2, 3, 4] [h, e, l]
19451945

19461946
query ??
19471947
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -7, -2), array_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -7, -3);
19481948
----
1949-
[] []
1949+
[1, 2, 3, 4] [h, e, l]
19501950

19511951
# array_slice scalar function #20 (with negative indexes; nested array)
19521952
query ??
@@ -1993,6 +1993,28 @@ select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2,
19931993
----
19941994
[2, 3, 4] [h, e]
19951995

1996+
# array_slice scalar function #24 (with first negative index larger than len)
1997+
query ??
1998+
select array_slice(make_array(1, 2, 3, 4, 5), -2147483648, 1), list_slice(make_array('h', 'e', 'l', 'l', 'o'), -2147483648, 1);
1999+
----
2000+
[1] [h]
2001+
2002+
query ??
2003+
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), -9223372036854775808, 1), list_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), -9223372036854775808, 1);
2004+
----
2005+
[1] [h]
2006+
2007+
# array_slice scalar function #25 (with negative step and equal indexes)
2008+
query ??
2009+
select array_slice(make_array(1, 2, 3, 4, 5), 2, 2, -1), list_slice(make_array('h', 'e', 'l', 'l', 'o'), 2, 2, -1);
2010+
----
2011+
[2] [e]
2012+
2013+
query ??
2014+
select array_slice(arrow_cast(make_array(1, 2, 3, 4, 5), 'LargeList(Int64)'), 2, 2, -1), list_slice(arrow_cast(make_array('h', 'e', 'l', 'l', 'o'), 'LargeList(Utf8)'), 2, 2, -1);
2015+
----
2016+
[2] [e]
2017+
19962018
# array_slice with columns
19972019
query ?
19982020
select array_slice(column1, column2, column3) from slices;

0 commit comments

Comments
 (0)