Skip to content

Commit f98f8a9

Browse files
tlm365 and alamb authored
Implement native StringView support for REPEAT (#11962)
* Implement native support StringView for REPEAT Signed-off-by: Tai Le Manh <[email protected]> * cargo fmt --------- Signed-off-by: Tai Le Manh <[email protected]> Co-authored-by: Andrew Lamb <[email protected]>
1 parent e4be013 commit f98f8a9

File tree

2 files changed

+73
-14
lines changed

2 files changed

+73
-14
lines changed

datafusion/functions/src/string/repeat.rs

Lines changed: 72 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,12 @@
1818
use std::any::Any;
1919
use std::sync::Arc;
2020

21-
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait};
21+
use arrow::array::{ArrayRef, GenericStringArray, OffsetSizeTrait, StringArray};
2222
use arrow::datatypes::DataType;
2323

24-
use datafusion_common::cast::{as_generic_string_array, as_int64_array};
24+
use datafusion_common::cast::{
25+
as_generic_string_array, as_int64_array, as_string_view_array,
26+
};
2527
use datafusion_common::{exec_err, Result};
2628
use datafusion_expr::TypeSignature::*;
2729
use datafusion_expr::{ColumnarValue, Volatility};
@@ -45,7 +47,14 @@ impl RepeatFunc {
4547
use DataType::*;
4648
Self {
4749
signature: Signature::one_of(
48-
vec![Exact(vec![Utf8, Int64]), Exact(vec![LargeUtf8, Int64])],
50+
vec![
51+
// Planner attempts coercion to the target type starting with the most preferred candidate.
52+
// For example, given input `(Utf8View, Int64)`, it first tries coercing to `(Utf8View, Int64)`.
53+
// If that fails, it proceeds to `(Utf8, Int64)`.
54+
Exact(vec![Utf8View, Int64]),
55+
Exact(vec![Utf8, Int64]),
56+
Exact(vec![LargeUtf8, Int64]),
57+
],
4958
Volatility::Immutable,
5059
),
5160
}
@@ -71,9 +80,10 @@ impl ScalarUDFImpl for RepeatFunc {
7180

7281
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
7382
match args[0].data_type() {
83+
DataType::Utf8View => make_scalar_function(repeat_utf8view, vec![])(args),
7484
DataType::Utf8 => make_scalar_function(repeat::<i32>, vec![])(args),
7585
DataType::LargeUtf8 => make_scalar_function(repeat::<i64>, vec![])(args),
76-
other => exec_err!("Unsupported data type {other:?} for function repeat"),
86+
other => exec_err!("Unsupported data type {other:?} for function repeat. Expected Utf8, Utf8View or LargeUtf8"),
7787
}
7888
}
7989
}
@@ -87,18 +97,35 @@ fn repeat<T: OffsetSizeTrait>(args: &[ArrayRef]) -> Result<ArrayRef> {
8797
let result = string_array
8898
.iter()
8999
.zip(number_array.iter())
90-
.map(|(string, number)| match (string, number) {
91-
(Some(string), Some(number)) if number >= 0 => {
92-
Some(string.repeat(number as usize))
93-
}
94-
(Some(_), Some(_)) => Some("".to_string()),
95-
_ => None,
96-
})
100+
.map(|(string, number)| repeat_common(string, number))
97101
.collect::<GenericStringArray<T>>();
98102

99103
Ok(Arc::new(result) as ArrayRef)
100104
}
101105

106+
fn repeat_utf8view(args: &[ArrayRef]) -> Result<ArrayRef> {
107+
let string_view_array = as_string_view_array(&args[0])?;
108+
let number_array = as_int64_array(&args[1])?;
109+
110+
let result = string_view_array
111+
.iter()
112+
.zip(number_array.iter())
113+
.map(|(string, number)| repeat_common(string, number))
114+
.collect::<StringArray>();
115+
116+
Ok(Arc::new(result) as ArrayRef)
117+
}
118+
119+
/// Per-row `REPEAT` logic shared by the array kernels.
///
/// * Returns `None` when either `string` or `number` is `None`.
/// * Returns an empty string when `number` is negative.
/// * Otherwise returns `string` concatenated with itself `number` times.
fn repeat_common(string: Option<&str>, number: Option<i64>) -> Option<String> {
    // Both inputs must be present; a missing one short-circuits to `None`.
    let (text, count) = string.zip(number)?;

    // A negative repeat count yields an empty string rather than `None`.
    if count < 0 {
        return Some(String::new());
    }

    Some(text.repeat(count as usize))
}
128+
102129
#[cfg(test)]
103130
mod tests {
104131
use arrow::array::{Array, StringArray};
@@ -124,7 +151,6 @@ mod tests {
124151
Utf8,
125152
StringArray
126153
);
127-
128154
test_function!(
129155
RepeatFunc::new(),
130156
&[
@@ -148,6 +174,40 @@ mod tests {
148174
StringArray
149175
);
150176

177+
test_function!(
178+
RepeatFunc::new(),
179+
&[
180+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
181+
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
182+
],
183+
Ok(Some("PgPgPgPg")),
184+
&str,
185+
Utf8,
186+
StringArray
187+
);
188+
test_function!(
189+
RepeatFunc::new(),
190+
&[
191+
ColumnarValue::Scalar(ScalarValue::Utf8View(None)),
192+
ColumnarValue::Scalar(ScalarValue::Int64(Some(4))),
193+
],
194+
Ok(None),
195+
&str,
196+
Utf8,
197+
StringArray
198+
);
199+
test_function!(
200+
RepeatFunc::new(),
201+
&[
202+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from("Pg")))),
203+
ColumnarValue::Scalar(ScalarValue::Int64(None)),
204+
],
205+
Ok(None),
206+
&str,
207+
Utf8,
208+
StringArray
209+
);
210+
151211
Ok(())
152212
}
153213
}

datafusion/sqllogictest/test_files/string_view.slt

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -860,14 +860,13 @@ logical_plan
860860

861861

862862
## Ensure no casts for REPEAT
863-
## TODO file ticket
864863
query TT
865864
EXPLAIN SELECT
866865
REPEAT(column1_utf8view, 2) as c1
867866
FROM test;
868867
----
869868
logical_plan
870-
01)Projection: repeat(CAST(test.column1_utf8view AS Utf8), Int64(2)) AS c1
869+
01)Projection: repeat(test.column1_utf8view, Int64(2)) AS c1
871870
02)--TableScan: test projection=[column1_utf8view]
872871

873872
## Ensure no casts for REPLACE

0 commit comments

Comments
 (0)