Skip to content

Commit 10fe3c1

Browse files
authored
Implement native support StringView for find in set (#11970)
* Implement native support StringView for find in set Signed-off-by: Chojan Shang <[email protected]> * Add more tests Signed-off-by: Chojan Shang <[email protected]> * Minor update --------- Signed-off-by: Chojan Shang <[email protected]>
1 parent 3c477bf commit 10fe3c1

File tree

3 files changed

+69
-31
lines changed

3 files changed

+69
-31
lines changed

datafusion/functions/src/unicode/find_in_set.rs

Lines changed: 41 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,11 +19,11 @@ use std::any::Any;
1919
use std::sync::Arc;
2020

2121
use arrow::array::{
22-
ArrayRef, ArrowPrimitiveType, GenericStringArray, OffsetSizeTrait, PrimitiveArray,
22+
ArrayAccessor, ArrayIter, ArrayRef, ArrowPrimitiveType, AsArray, OffsetSizeTrait,
23+
PrimitiveArray,
2324
};
2425
use arrow::datatypes::{ArrowNativeType, DataType, Int32Type, Int64Type};
2526

26-
use datafusion_common::cast::as_generic_string_array;
2727
use datafusion_common::{exec_err, Result};
2828
use datafusion_expr::TypeSignature::Exact;
2929
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
@@ -46,7 +46,11 @@ impl FindInSetFunc {
4646
use DataType::*;
4747
Self {
4848
signature: Signature::one_of(
49-
vec![Exact(vec![Utf8, Utf8]), Exact(vec![LargeUtf8, LargeUtf8])],
49+
vec![
50+
Exact(vec![Utf8View, Utf8View]),
51+
Exact(vec![Utf8, Utf8]),
52+
Exact(vec![LargeUtf8, LargeUtf8]),
53+
],
5054
Volatility::Immutable,
5155
),
5256
}
@@ -71,41 +75,52 @@ impl ScalarUDFImpl for FindInSetFunc {
7175
}
7276

7377
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
74-
match args[0].data_type() {
75-
DataType::Utf8 => {
76-
make_scalar_function(find_in_set::<Int32Type>, vec![])(args)
77-
}
78-
DataType::LargeUtf8 => {
79-
make_scalar_function(find_in_set::<Int64Type>, vec![])(args)
80-
}
81-
other => {
82-
exec_err!("Unsupported data type {other:?} for function find_in_set")
83-
}
84-
}
78+
make_scalar_function(find_in_set, vec![])(args)
8579
}
8680
}
8781

8882
///Returns a value in the range of 1 to N if the string str is in the string list strlist consisting of N substrings
8983
///A string list is a string composed of substrings separated by , characters.
90-
pub fn find_in_set<T: ArrowPrimitiveType>(args: &[ArrayRef]) -> Result<ArrayRef>
91-
where
92-
T::Native: OffsetSizeTrait,
93-
{
84+
fn find_in_set(args: &[ArrayRef]) -> Result<ArrayRef> {
9485
if args.len() != 2 {
9586
return exec_err!(
9687
"find_in_set was called with {} arguments. It requires 2.",
9788
args.len()
9889
);
9990
}
91+
match args[0].data_type() {
92+
DataType::Utf8 => {
93+
let string_array = args[0].as_string::<i32>();
94+
let str_list_array = args[1].as_string::<i32>();
95+
find_in_set_general::<Int32Type, _>(string_array, str_list_array)
96+
}
97+
DataType::LargeUtf8 => {
98+
let string_array = args[0].as_string::<i64>();
99+
let str_list_array = args[1].as_string::<i64>();
100+
find_in_set_general::<Int64Type, _>(string_array, str_list_array)
101+
}
102+
DataType::Utf8View => {
103+
let string_array = args[0].as_string_view();
104+
let str_list_array = args[1].as_string_view();
105+
find_in_set_general::<Int32Type, _>(string_array, str_list_array)
106+
}
107+
other => {
108+
exec_err!("Unsupported data type {other:?} for function find_in_set")
109+
}
110+
}
111+
}
100112

101-
let str_array: &GenericStringArray<T::Native> =
102-
as_generic_string_array::<T::Native>(&args[0])?;
103-
let str_list_array: &GenericStringArray<T::Native> =
104-
as_generic_string_array::<T::Native>(&args[1])?;
105-
106-
let result = str_array
107-
.iter()
108-
.zip(str_list_array.iter())
113+
pub fn find_in_set_general<'a, T: ArrowPrimitiveType, V: ArrayAccessor<Item = &'a str>>(
114+
string_array: V,
115+
str_list_array: V,
116+
) -> Result<ArrayRef>
117+
where
118+
T::Native: OffsetSizeTrait,
119+
{
120+
let string_iter = ArrayIter::new(string_array);
121+
let str_list_iter = ArrayIter::new(str_list_array);
122+
let result = string_iter
123+
.zip(str_list_iter)
109124
.map(|(string, str_list)| match (string, str_list) {
110125
(Some(string), Some(str_list)) => {
111126
let mut res = 0;

datafusion/sqllogictest/test_files/functions.slt

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1092,7 +1092,7 @@ docs.apache.com docs com
10921092
community.influxdata.com community com
10931093
arrow.apache.org arrow org
10941094

1095-
1095+
# find_in_set tests
10961096
query I
10971097
SELECT find_in_set('b', 'a,b,c,d')
10981098
----
@@ -1136,6 +1136,23 @@ SELECT find_in_set(NULL, NULL)
11361136
----
11371137
NULL
11381138

1139+
# find_in_set tests with utf8view
1140+
query I
1141+
SELECT find_in_set(arrow_cast('b', 'Utf8View'), 'a,b,c,d')
1142+
----
1143+
2
1144+
1145+
1146+
query I
1147+
SELECT find_in_set('a', arrow_cast('a,b,c,d,a', 'Utf8View'))
1148+
----
1149+
1
1150+
1151+
query I
1152+
SELECT find_in_set(arrow_cast('', 'Utf8View'), arrow_cast('a,b,c,d,a', 'Utf8View'))
1153+
----
1154+
0
1155+
11391156
# Verify that multiple calls to volatile functions like `random()` are not combined / optimized away
11401157
query B
11411158
SELECT r FROM (SELECT r1 == r2 r, r1, r2 FROM (SELECT random()+1 r1, random()+1 r2) WHERE r1 > 0 AND r2 > 0)

datafusion/sqllogictest/test_files/string_view.slt

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -992,18 +992,24 @@ logical_plan
992992
02)--TableScan: test projection=[column1_utf8view]
993993

994994
## Ensure no casts for FIND_IN_SET
995-
## TODO file ticket
996995
query TT
997996
EXPLAIN SELECT
998997
FIND_IN_SET(column1_utf8view, 'a,b,c,d') as c
999998
FROM test;
1000999
----
10011000
logical_plan
1002-
01)Projection: find_in_set(CAST(test.column1_utf8view AS Utf8), Utf8("a,b,c,d")) AS c
1001+
01)Projection: find_in_set(test.column1_utf8view, Utf8View("a,b,c,d")) AS c
10031002
02)--TableScan: test projection=[column1_utf8view]
10041003

1005-
1006-
1004+
query I
1005+
SELECT
1006+
FIND_IN_SET(column1_utf8view, 'a,b,c,d') as c
1007+
FROM test;
1008+
----
1009+
0
1010+
0
1011+
0
1012+
NULL
10071013

10081014
statement ok
10091015
drop table test;

0 commit comments

Comments
 (0)