Skip to content

Commit 0be56a3

Browse files
committed
Add Utf8View support for the user-defined plan test
1 parent 83fe677 commit 0be56a3

File tree

3 files changed

+35
-14
lines changed

3 files changed

+35
-14
lines changed

datafusion/core/tests/user_defined/user_defined_plan.rs

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ use std::hash::Hash;
6363
use std::task::{Context, Poll};
6464
use std::{any::Any, collections::BTreeMap, fmt, sync::Arc};
6565

66+
use arrow::array::{Array, ArrayRef, StringViewArray};
6667
use arrow::{
6768
array::{Int64Array, StringArray},
6869
datatypes::SchemaRef,
@@ -100,6 +101,7 @@ use datafusion_optimizer::AnalyzerRule;
100101
use datafusion_physical_plan::execution_plan::{Boundedness, EmissionType};
101102

102103
use async_trait::async_trait;
104+
use datafusion_common::cast::as_string_view_array;
103105
use futures::{Stream, StreamExt};
104106

105107
/// Execute the specified SQL and return the resulting record batches
@@ -796,22 +798,30 @@ fn accumulate_batch(
796798
k: &usize,
797799
) -> BTreeMap<i64, String> {
798800
let num_rows = input_batch.num_rows();
801+
799802
// Assuming the input columns are
800-
// column[0]: customer_id / UTF8
803+
// column[0]: customer_id: Utf8 or Utf8View
801804
// column[1]: revenue: Int64
802-
let customer_id =
803-
as_string_array(input_batch.column(0)).expect("Column 0 is not customer_id");
804805

806+
let customer_id_column = input_batch.column(0);
805807
let revenue = as_int64_array(input_batch.column(1)).unwrap();
806808

807809
for row in 0..num_rows {
808-
add_row(
809-
&mut top_values,
810-
customer_id.value(row),
811-
revenue.value(row),
812-
k,
813-
);
810+
let customer_id = match customer_id_column.data_type() {
811+
arrow::datatypes::DataType::Utf8 => {
812+
let array = as_string_array(customer_id_column).unwrap();
813+
array.value(row)
814+
}
815+
arrow::datatypes::DataType::Utf8View => {
816+
let array = as_string_view_array(customer_id_column).unwrap();
817+
array.value(row)
818+
}
819+
_ => panic!("Unsupported customer_id type"),
820+
};
821+
822+
add_row(&mut top_values, customer_id, revenue.value(row), k);
814823
}
824+
815825
top_values
816826
}
817827

@@ -843,11 +853,22 @@ impl Stream for TopKReader {
843853
self.state.iter().rev().unzip();
844854

845855
let customer: Vec<&str> = customer.iter().map(|&s| &**s).collect();
856+
857+
let customer_array: ArrayRef = match schema.field(0).data_type() {
858+
arrow::datatypes::DataType::Utf8 => {
859+
Arc::new(StringArray::from(customer))
860+
}
861+
arrow::datatypes::DataType::Utf8View => {
862+
Arc::new(StringViewArray::from(customer))
863+
}
864+
other => panic!("Unsupported customer_id output type: {:?}", other),
865+
};
866+
846867
Poll::Ready(Some(
847868
RecordBatch::try_new(
848869
schema,
849870
vec![
850-
Arc::new(StringArray::from(customer)),
871+
Arc::new(customer_array),
851872
Arc::new(Int64Array::from(revenue)),
852873
],
853874
)

datafusion/sql/src/planner.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ impl ParserOptions {
7272
parse_float_as_decimal: false,
7373
enable_ident_normalization: true,
7474
support_varchar_with_length: true,
75-
map_varchar_to_utf8view: false,
75+
map_varchar_to_utf8view: true,
7676
enable_options_value_normalization: false,
7777
collect_spans: false,
7878
}

datafusion/sql/tests/sql_integration.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3360,7 +3360,7 @@ fn parse_decimals_parser_options() -> ParserOptions {
33603360
parse_float_as_decimal: true,
33613361
enable_ident_normalization: false,
33623362
support_varchar_with_length: false,
3363-
map_varchar_to_utf8view: false,
3363+
map_varchar_to_utf8view: true,
33643364
enable_options_value_normalization: false,
33653365
collect_spans: false,
33663366
}
@@ -3371,7 +3371,7 @@ fn ident_normalization_parser_options_no_ident_normalization() -> ParserOptions
33713371
parse_float_as_decimal: true,
33723372
enable_ident_normalization: false,
33733373
support_varchar_with_length: false,
3374-
map_varchar_to_utf8view: false,
3374+
map_varchar_to_utf8view: true,
33753375
enable_options_value_normalization: false,
33763376
collect_spans: false,
33773377
}
@@ -3382,7 +3382,7 @@ fn ident_normalization_parser_options_ident_normalization() -> ParserOptions {
33823382
parse_float_as_decimal: true,
33833383
enable_ident_normalization: true,
33843384
support_varchar_with_length: false,
3385-
map_varchar_to_utf8view: false,
3385+
map_varchar_to_utf8view: true,
33863386
enable_options_value_normalization: false,
33873387
collect_spans: false,
33883388
}

0 commit comments

Comments
 (0)