@@ -63,6 +63,7 @@ use std::hash::Hash;
63
63
use std:: task:: { Context , Poll } ;
64
64
use std:: { any:: Any , collections:: BTreeMap , fmt, sync:: Arc } ;
65
65
66
+ use arrow:: array:: { Array , ArrayRef , StringViewArray } ;
66
67
use arrow:: {
67
68
array:: { Int64Array , StringArray } ,
68
69
datatypes:: SchemaRef ,
@@ -100,6 +101,7 @@ use datafusion_optimizer::AnalyzerRule;
100
101
use datafusion_physical_plan:: execution_plan:: { Boundedness , EmissionType } ;
101
102
102
103
use async_trait:: async_trait;
104
+ use datafusion_common:: cast:: as_string_view_array;
103
105
use futures:: { Stream , StreamExt } ;
104
106
105
107
/// Execute the specified sql and return the resulting record batches
@@ -796,22 +798,30 @@ fn accumulate_batch(
796
798
k : & usize ,
797
799
) -> BTreeMap < i64 , String > {
798
800
let num_rows = input_batch. num_rows ( ) ;
801
+
799
802
// Assuming the input columns are
800
- // column[0]: customer_id / UTF8
803
+ // column[0]: customer_id / UTF8 or UTF8View
801
804
// column[1]: revenue: Int64
802
- let customer_id =
803
- as_string_array ( input_batch. column ( 0 ) ) . expect ( "Column 0 is not customer_id" ) ;
804
805
806
+ let customer_id_column = input_batch. column ( 0 ) ;
805
807
let revenue = as_int64_array ( input_batch. column ( 1 ) ) . unwrap ( ) ;
806
808
807
809
for row in 0 ..num_rows {
808
- add_row (
809
- & mut top_values,
810
- customer_id. value ( row) ,
811
- revenue. value ( row) ,
812
- k,
813
- ) ;
810
+ let customer_id = match customer_id_column. data_type ( ) {
811
+ arrow:: datatypes:: DataType :: Utf8 => {
812
+ let array = as_string_array ( customer_id_column) . unwrap ( ) ;
813
+ array. value ( row)
814
+ }
815
+ arrow:: datatypes:: DataType :: Utf8View => {
816
+ let array = as_string_view_array ( customer_id_column) . unwrap ( ) ;
817
+ array. value ( row)
818
+ }
819
+ _ => panic ! ( "Unsupported customer_id type" ) ,
820
+ } ;
821
+
822
+ add_row ( & mut top_values, customer_id, revenue. value ( row) , k) ;
814
823
}
824
+
815
825
top_values
816
826
}
817
827
@@ -843,11 +853,22 @@ impl Stream for TopKReader {
843
853
self . state . iter ( ) . rev ( ) . unzip ( ) ;
844
854
845
855
let customer: Vec < & str > = customer. iter ( ) . map ( |& s| & * * s) . collect ( ) ;
856
+
857
+ let customer_array: ArrayRef = match schema. field ( 0 ) . data_type ( ) {
858
+ arrow:: datatypes:: DataType :: Utf8 => {
859
+ Arc :: new ( StringArray :: from ( customer) )
860
+ }
861
+ arrow:: datatypes:: DataType :: Utf8View => {
862
+ Arc :: new ( StringViewArray :: from ( customer) )
863
+ }
864
+ other => panic ! ( "Unsupported customer_id output type: {:?}" , other) ,
865
+ } ;
866
+
846
867
Poll :: Ready ( Some (
847
868
RecordBatch :: try_new (
848
869
schema,
849
870
vec ! [
850
- Arc :: new( StringArray :: from ( customer ) ) ,
871
+ Arc :: new( customer_array ) ,
851
872
Arc :: new( Int64Array :: from( revenue) ) ,
852
873
] ,
853
874
)
0 commit comments