Skip to content

Commit 7ead2ad

Browse files
Kev1n8 and alamb authored
Improve StringView support for SUBSTR (#12044)
* operate stringview instead of generating string in SUBSTR * treat Utf8View as Text in sqllogictests output * add bench to see enhancement of utf8view against utf8 and large_utf8 * fix a tiny bug * make clippy happy * add tests to cover stringview larger than 12B and correct the code * better comments * fix lint * correct feature setting * avoid expensive utf8 and some other checks * fix lint * remove unnecessary indirection * add optimized_utf8_to_str_type * Simplify type check * Use ByteView * update datafusion-cli.lock * Remove duration override * format toml * refactor the code, using append_view_u128 from arrow * manually collect the views and nulls * remove bench file and fix some comments * fix tiny mistake * Update Cargo.lock --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 5756f39 commit 7ead2ad

File tree

1 file changed

+216
-26
lines changed

1 file changed

+216
-26
lines changed

datafusion/functions/src/unicode/substr.rs

Lines changed: 216 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -19,18 +19,18 @@ use std::any::Any;
1919
use std::cmp::max;
2020
use std::sync::Arc;
2121

22+
use crate::utils::{make_scalar_function, utf8_to_str_type};
2223
use arrow::array::{
23-
ArrayAccessor, ArrayIter, ArrayRef, AsArray, GenericStringArray, OffsetSizeTrait,
24+
make_view, Array, ArrayAccessor, ArrayIter, ArrayRef, AsArray, ByteView,
25+
GenericStringArray, OffsetSizeTrait, StringViewArray,
2426
};
2527
use arrow::datatypes::DataType;
26-
28+
use arrow_buffer::{NullBufferBuilder, ScalarBuffer};
2729
use datafusion_common::cast::as_int64_array;
2830
use datafusion_common::{exec_datafusion_err, exec_err, Result};
2931
use datafusion_expr::TypeSignature::Exact;
3032
use datafusion_expr::{ColumnarValue, ScalarUDFImpl, Signature, Volatility};
3133

32-
use crate::utils::{make_scalar_function, utf8_to_str_type};
33-
3434
#[derive(Debug)]
3535
pub struct SubstrFunc {
3636
signature: Signature,
@@ -77,7 +77,11 @@ impl ScalarUDFImpl for SubstrFunc {
7777
}
7878

7979
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
80-
utf8_to_str_type(&arg_types[0], "substr")
80+
if arg_types[0] == DataType::Utf8View {
81+
Ok(DataType::Utf8View)
82+
} else {
83+
utf8_to_str_type(&arg_types[0], "substr")
84+
}
8185
}
8286

8387
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
@@ -89,29 +93,188 @@ impl ScalarUDFImpl for SubstrFunc {
8993
}
9094
}
9195

96+
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
97+
/// substr('alphabet', 3) = 'phabet'
98+
/// substr('alphabet', 3, 2) = 'ph'
99+
/// The implementation uses UTF-8 code points as characters
92100
pub fn substr(args: &[ArrayRef]) -> Result<ArrayRef> {
93101
match args[0].data_type() {
94102
DataType::Utf8 => {
95103
let string_array = args[0].as_string::<i32>();
96-
calculate_substr::<_, i32>(string_array, &args[1..])
104+
string_substr::<_, i32>(string_array, &args[1..])
97105
}
98106
DataType::LargeUtf8 => {
99107
let string_array = args[0].as_string::<i64>();
100-
calculate_substr::<_, i64>(string_array, &args[1..])
108+
string_substr::<_, i64>(string_array, &args[1..])
101109
}
102110
DataType::Utf8View => {
103111
let string_array = args[0].as_string_view();
104-
calculate_substr::<_, i32>(string_array, &args[1..])
112+
string_view_substr(string_array, &args[1..])
105113
}
106-
other => exec_err!("Unsupported data type {other:?} for function substr"),
114+
other => exec_err!(
115+
"Unsupported data type {other:?} for function substr,\
116+
expected Utf8View, Utf8 or LargeUtf8."
117+
),
107118
}
108119
}
109120

110-
/// Extracts the substring of string starting at the start'th character, and extending for count characters if that is specified. (Same as substring(string from start for count).)
111-
/// substr('alphabet', 3) = 'phabet'
112-
/// substr('alphabet', 3, 2) = 'ph'
113-
/// The implementation uses UTF-8 code points as characters
114-
fn calculate_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
121+
/// Returns the byte range `[start, end)` of `input` that covers `count`
/// characters beginning at the zero-based character index `start`.
///
/// * A negative `count` (callers pass `-1`) means "no limit": the range then
///   runs to the end of `input`.
/// * A `count` of `0` yields an empty range positioned at the start offset.
/// * When `start` is at or beyond the number of characters in `input`, the
///   empty range `(input.len(), input.len())` is returned.
///
/// Both offsets always fall on UTF-8 character boundaries, so the result can
/// be used directly to slice `input`.
fn get_true_start_end(input: &str, start: usize, count: i64) -> (usize, usize) {
    let len = input.len();
    // Byte offset at which each successive character begins.
    let mut boundaries = input.char_indices().map(|(byte_idx, _)| byte_idx);

    let st = match boundaries.nth(start) {
        Some(byte_idx) => byte_idx,
        None => return (len, len),
    };

    let ed = if count < 0 {
        len
    } else if count == 0 {
        st
    } else {
        // `nth(start)` already consumed the first selected character, so the
        // character that terminates the range is `count - 1` items further.
        // A string holds at most `len` characters, so clamping to `len`
        // keeps the cast lossless without changing the result.
        let remaining = (count - 1).min(len as i64) as usize;
        boundaries.nth(remaining).unwrap_or(len)
    };

    (st, ed)
}
145+
146+
/// Make a `u128` based on the given substr, start(offset to view.offset), and
147+
/// push into to the given buffers
148+
fn make_and_append_view(
149+
views_buffer: &mut Vec<u128>,
150+
null_builder: &mut NullBufferBuilder,
151+
raw: &u128,
152+
substr: &str,
153+
start: u32,
154+
) {
155+
let substr_len = substr.len();
156+
if substr_len == 0 {
157+
null_builder.append_null();
158+
views_buffer.push(0);
159+
} else {
160+
let sub_view = if substr_len > 12 {
161+
let view = ByteView::from(*raw);
162+
make_view(substr.as_bytes(), view.buffer_index, view.offset + start)
163+
} else {
164+
// inline value does not need block id or offset
165+
make_view(substr.as_bytes(), 0, 0)
166+
};
167+
views_buffer.push(sub_view);
168+
null_builder.append_non_null();
169+
}
170+
}
171+
172+
// The decoding process refs the trait at: arrow/arrow-data/src/byte_view.rs:44
173+
// From<u128> for ByteView
174+
fn string_view_substr(
175+
string_view_array: &StringViewArray,
176+
args: &[ArrayRef],
177+
) -> Result<ArrayRef> {
178+
let mut views_buf = Vec::with_capacity(string_view_array.len());
179+
let mut null_builder = NullBufferBuilder::new(string_view_array.len());
180+
181+
let start_array = as_int64_array(&args[0])?;
182+
183+
match args.len() {
184+
1 => {
185+
for (idx, (raw, start)) in string_view_array
186+
.views()
187+
.iter()
188+
.zip(start_array.iter())
189+
.enumerate()
190+
{
191+
if let Some(start) = start {
192+
let start = (start - 1).max(0) as usize;
193+
194+
// Safety:
195+
// idx is always smaller or equal to string_view_array.views.len()
196+
unsafe {
197+
let str = string_view_array.value_unchecked(idx);
198+
let (start, end) = get_true_start_end(str, start, -1);
199+
let substr = &str[start..end];
200+
201+
make_and_append_view(
202+
&mut views_buf,
203+
&mut null_builder,
204+
raw,
205+
substr,
206+
start as u32,
207+
);
208+
}
209+
} else {
210+
null_builder.append_null();
211+
views_buf.push(0);
212+
}
213+
}
214+
}
215+
2 => {
216+
let count_array = as_int64_array(&args[1])?;
217+
for (idx, ((raw, start), count)) in string_view_array
218+
.views()
219+
.iter()
220+
.zip(start_array.iter())
221+
.zip(count_array.iter())
222+
.enumerate()
223+
{
224+
if let (Some(start), Some(count)) = (start, count) {
225+
let start = (start - 1).max(0) as usize;
226+
if count < 0 {
227+
return exec_err!(
228+
"negative substring length not allowed: substr(<str>, {start}, {count})"
229+
);
230+
} else {
231+
// Safety:
232+
// idx is always smaller or equal to string_view_array.views.len()
233+
unsafe {
234+
let str = string_view_array.value_unchecked(idx);
235+
let (start, end) = get_true_start_end(str, start, count);
236+
let substr = &str[start..end];
237+
238+
make_and_append_view(
239+
&mut views_buf,
240+
&mut null_builder,
241+
raw,
242+
substr,
243+
start as u32,
244+
);
245+
}
246+
}
247+
} else {
248+
null_builder.append_null();
249+
views_buf.push(0);
250+
}
251+
}
252+
}
253+
other => {
254+
return exec_err!(
255+
"substr was called with {other} arguments. It requires 2 or 3."
256+
)
257+
}
258+
}
259+
260+
let views_buf = ScalarBuffer::from(views_buf);
261+
let nulls_buf = null_builder.finish();
262+
263+
// Safety:
264+
// (1) The blocks of the given views are all provided
265+
// (2) Each of the range `view.offset+start..end` of view in views_buf is within
266+
// the bounds of each of the blocks
267+
unsafe {
268+
let array = StringViewArray::new_unchecked(
269+
views_buf,
270+
string_view_array.data_buffers().to_vec(),
271+
nulls_buf,
272+
);
273+
Ok(Arc::new(array) as ArrayRef)
274+
}
275+
}
276+
277+
fn string_substr<'a, V, T>(string_array: V, args: &[ArrayRef]) -> Result<ArrayRef>
115278
where
116279
V: ArrayAccessor<Item = &'a str>,
117280
T: OffsetSizeTrait,
@@ -174,8 +337,8 @@ where
174337

175338
#[cfg(test)]
176339
mod tests {
177-
use arrow::array::{Array, StringArray};
178-
use arrow::datatypes::DataType::Utf8;
340+
use arrow::array::{Array, StringArray, StringViewArray};
341+
use arrow::datatypes::DataType::{Utf8, Utf8View};
179342

180343
use datafusion_common::{exec_err, Result, ScalarValue};
181344
use datafusion_expr::{ColumnarValue, ScalarUDFImpl};
@@ -193,8 +356,8 @@ mod tests {
193356
],
194357
Ok(None),
195358
&str,
196-
Utf8,
197-
StringArray
359+
Utf8View,
360+
StringViewArray
198361
);
199362
test_function!(
200363
SubstrFunc::new(),
@@ -206,8 +369,35 @@ mod tests {
206369
],
207370
Ok(Some("alphabet")),
208371
&str,
209-
Utf8,
210-
StringArray
372+
Utf8View,
373+
StringViewArray
374+
);
375+
test_function!(
376+
SubstrFunc::new(),
377+
&[
378+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
379+
"this és longer than 12B"
380+
)))),
381+
ColumnarValue::Scalar(ScalarValue::from(5i64)),
382+
ColumnarValue::Scalar(ScalarValue::from(2i64)),
383+
],
384+
Ok(Some(" é")),
385+
&str,
386+
Utf8View,
387+
StringViewArray
388+
);
389+
test_function!(
390+
SubstrFunc::new(),
391+
&[
392+
ColumnarValue::Scalar(ScalarValue::Utf8View(Some(String::from(
393+
"this is longer than 12B"
394+
)))),
395+
ColumnarValue::Scalar(ScalarValue::from(5i64)),
396+
],
397+
Ok(Some(" is longer than 12B")),
398+
&str,
399+
Utf8View,
400+
StringViewArray
211401
);
212402
test_function!(
213403
SubstrFunc::new(),
@@ -219,8 +409,8 @@ mod tests {
219409
],
220410
Ok(Some("ésoj")),
221411
&str,
222-
Utf8,
223-
StringArray
412+
Utf8View,
413+
StringViewArray
224414
);
225415
test_function!(
226416
SubstrFunc::new(),
@@ -233,8 +423,8 @@ mod tests {
233423
],
234424
Ok(Some("ph")),
235425
&str,
236-
Utf8,
237-
StringArray
426+
Utf8View,
427+
StringViewArray
238428
);
239429
test_function!(
240430
SubstrFunc::new(),
@@ -247,8 +437,8 @@ mod tests {
247437
],
248438
Ok(Some("phabet")),
249439
&str,
250-
Utf8,
251-
StringArray
440+
Utf8View,
441+
StringViewArray
252442
);
253443
test_function!(
254444
SubstrFunc::new(),

0 commit comments

Comments
 (0)