Skip to content

Commit 06bcf33

Browse files
authored
Update REVERSE scalar function to support Utf8View (#11973)
1 parent 4baa901 commit 06bcf33

File tree

3 files changed

+88
-51
lines changed

3 files changed

+88
-51
lines changed

datafusion/functions/src/unicode/reverse.rs

Lines changed: 57 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -18,12 +18,14 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
21+
use arrow::array::{
22+
Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray,
23+
OffsetSizeTrait,
24+
};
2225
use arrow::datatypes::DataType;
23-
24-
use datafusion_common::cast::as_generic_string_array;
2526
use datafusion_common::{exec_err, Result};
2627
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
28+
use DataType::{LargeUtf8, Utf8, Utf8View};
2729

2830
use crate::utils::{make_scalar_function, utf8_to_str_type};
2931

@@ -44,7 +46,7 @@ impl ReverseFunc {
4446
Self {
4547
signature: Signature::uniform(
4648
1,
47-
vec![Utf8, LargeUtf8],
49+
vec![Utf8View, Utf8, LargeUtf8],
4850
Volatility::Immutable,
4951
),
5052
}
@@ -70,8 +72,8 @@ impl ScalarUDFImpl for ReverseFunc {
7072

7173
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
7274
match args[0].data_type() {
73-
DataType::Utf8 => make_scalar_function(reverse::<i32>, vec![])(args),
74-
DataType::LargeUtf8 => make_scalar_function(reverse::<i64>, vec![])(args),
75+
Utf8 | Utf8View => make_scalar_function(reverse::<i32>, vec![])(args),
76+
LargeUtf8 => make_scalar_function(reverse::<i64>, vec![])(args),
7577
other => {
7678
exec_err!("Unsupported data type {other:?} for function reverse")
7779
}
@@ -83,10 +85,17 @@ impl ScalarUDFImpl for ReverseFunc {
8385
/// reverse('abcde') = 'edcba'
8486
/// The implementation uses UTF-8 code points as characters
8587
pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
86-
let string_array = as_generic_string_array::<T>(&args[0])?;
88+
if args[0].data_type() == &Utf8View {
89+
reverse_impl::<T, _>(args[0].as_string_view())
90+
} else {
91+
reverse_impl::<T, _>(args[0].as_string::<T>())
92+
}
93+
}
8794

88-
let result = string_array
89-
.iter()
95+
fn reverse_impl<'a, T: OffsetSizeTrait, V: ArrayAccessor<Item = &'a str>>(
96+
string_array: V,
97+
) -> Result<ArrayRef> {
98+
let result = ArrayIter::new(string_array)
9099
.map(|string| string.map(|string: &str| string.chars().rev().collect::<String>()))
91100
.collect::<GenericStringArray<T>>();
92101

@@ -95,59 +104,58 @@ pub fn reverse<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
95104

96105
#[cfg(test)]
97106
mod tests {
98-
use arrow::array::{Array, StringArray};
99-
use arrow::datatypes::DataType::Utf8;
107+
use arrow::array::{Array, LargeStringArray, StringArray};
108+
use arrow::datatypes::DataType::{LargeUtf8, Utf8};
100109

101110
use datafusion_common::{Result, ScalarValue};
102111
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
103112

104113
use crate::unicode::reverse::ReverseFunc;
105114
use crate::utils::test::test_function;
106115

116+
macro_rules! test_reverse {
117+
($INPUT:expr, $EXPECTED:expr) => {
118+
test_function!(
119+
ReverseFunc::new(),
120+
&[ColumnarValue::Scalar(ScalarValue::Utf8($INPUT))],
121+
$EXPECTED,
122+
&str,
123+
Utf8,
124+
StringArray
125+
);
126+
127+
test_function!(
128+
ReverseFunc::new(),
129+
&[ColumnarValue::Scalar(ScalarValue::LargeUtf8($INPUT))],
130+
$EXPECTED,
131+
&str,
132+
LargeUtf8,
133+
LargeStringArray
134+
);
135+
136+
test_function!(
137+
ReverseFunc::new(),
138+
&[ColumnarValue::Scalar(ScalarValue::Utf8View($INPUT))],
139+
$EXPECTED,
140+
&str,
141+
Utf8,
142+
StringArray
143+
);
144+
};
145+
}
146+
107147
#[test]
108148
fn test_functions() -> Result<()> {
109-
test_function!(
110-
ReverseFunc::new(),
111-
&[ColumnarValue::Scalar(ScalarValue::from("abcde"))],
112-
Ok(Some("edcba")),
113-
&str,
114-
Utf8,
115-
StringArray
116-
);
117-
test_function!(
118-
ReverseFunc::new(),
119-
&[ColumnarValue::Scalar(ScalarValue::from("loẅks"))],
120-
Ok(Some("sk̈wol")),
121-
&str,
122-
Utf8,
123-
StringArray
124-
);
125-
test_function!(
126-
ReverseFunc::new(),
127-
&[ColumnarValue::Scalar(ScalarValue::from("loẅks"))],
128-
Ok(Some("sk̈wol")),
129-
&str,
130-
Utf8,
131-
StringArray
132-
);
133-
test_function!(
134-
ReverseFunc::new(),
135-
&[ColumnarValue::Scalar(ScalarValue::Utf8(None))],
136-
Ok(None),
137-
&str,
138-
Utf8,
139-
StringArray
140-
);
149+
test_reverse!(Some("abcde".into()), Ok(Some("edcba")));
150+
test_reverse!(Some("loẅks".into()), Ok(Some("sk̈wol")));
151+
test_reverse!(Some("loẅks".into()), Ok(Some("sk̈wol")));
152+
test_reverse!(None, Ok(None));
141153
#[cfg(not(feature = "unicode_expressions"))]
142-
test_function!(
143-
ReverseFunc::new(),
144-
&[ColumnarValue::Scalar(ScalarValue::from("abcde"))],
154+
test_reverse!(
155+
Some("abcde".into()),
145156
internal_err!(
146157
"function reverse requires compilation with feature flag: unicode_expressions."
147158
),
148-
&str,
149-
Utf8,
150-
StringArray
151159
);
152160

153161
Ok(())

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,6 +234,16 @@ SELECT reverse('abcde')
234234
----
235235
edcba
236236

237+
query T
238+
SELECT reverse(arrow_cast('abcde', 'LargeUtf8'))
239+
----
240+
edcba
241+
242+
query T
243+
SELECT reverse(arrow_cast('abcde', 'Utf8View'))
244+
----
245+
edcba
246+
237247
query T
238248
SELECT reverse(arrow_cast('abcde', 'Dictionary(Int32, Utf8)'))
239249
----
@@ -244,11 +254,31 @@ SELECT reverse('loẅks')
244254
----
245255
sk̈wol
246256

257+
query T
258+
SELECT reverse(arrow_cast('loẅks', 'LargeUtf8'))
259+
----
260+
sk̈wol
261+
262+
query T
263+
SELECT reverse(arrow_cast('loẅks', 'Utf8View'))
264+
----
265+
sk̈wol
266+
247267
query T
248268
SELECT reverse(NULL)
249269
----
250270
NULL
251271

272+
query T
273+
SELECT reverse(arrow_cast(NULL, 'LargeUtf8'))
274+
----
275+
NULL
276+
277+
query T
278+
SELECT reverse(arrow_cast(NULL, 'Utf8View'))
279+
----
280+
NULL
281+
252282
query T
253283
SELECT right('abcde', -2)
254284
----

datafusion/sqllogictest/test_files/string_view.slt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -890,14 +890,13 @@ logical_plan
890890
03)----TableScan: test projection=[column1_utf8view, column2_utf8view]
891891

892892
## Ensure no casts for REVERSE
893-
## TODO file ticket
894893
query TT
895894
EXPLAIN SELECT
896895
REVERSE(column1_utf8view) as c1
897896
FROM test;
898897
----
899898
logical_plan
900-
01)Projection: reverse(CAST(test.column1_utf8view AS Utf8)) AS c1
899+
01)Projection: reverse(test.column1_utf8view) AS c1
901900
02)--TableScan: test projection=[column1_utf8view]
902901

903902

0 commit comments

Comments
 (0)