@@ -37,7 +37,7 @@ use arrow::array::Array;
37
37
use arrow:: array:: ArrowNativeTypeOp ;
38
38
use arrow:: datatypes:: { ArrowNativeType , ArrowPrimitiveType } ;
39
39
40
- use datafusion_common:: { DataFusionError , HashSet , Result , ScalarValue } ;
40
+ use datafusion_common:: { internal_err , DataFusionError , HashSet , Result , ScalarValue } ;
41
41
use datafusion_doc:: DocSection ;
42
42
use datafusion_expr:: function:: StateFieldsArgs ;
43
43
use datafusion_expr:: {
@@ -173,6 +173,45 @@ impl AggregateUDFImpl for Median {
173
173
}
174
174
}
175
175
176
+ fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
177
+ !args. is_distinct
178
+ }
179
+
180
+ fn create_groups_accumulator (
181
+ & self ,
182
+ args : AccumulatorArgs ,
183
+ ) -> Result < Box < dyn GroupsAccumulator > > {
184
+ let num_args = args. exprs . len ( ) ;
185
+ if num_args != 1 {
186
+ return internal_err ! (
187
+ "median should only have 1 arg, but found num args:{}" ,
188
+ args. exprs. len( )
189
+ ) ;
190
+ }
191
+
192
+ let dt = args. exprs [ 0 ] . data_type ( args. schema ) ?;
193
+
194
+ macro_rules! helper {
195
+ ( $t: ty, $dt: expr) => {
196
+ Ok ( Box :: new( MedianGroupsAccumulator :: <$t>:: new( $dt) ) )
197
+ } ;
198
+ }
199
+
200
+ downcast_integer ! {
201
+ dt => ( helper, dt) ,
202
+ DataType :: Float16 => helper!( Float16Type , dt) ,
203
+ DataType :: Float32 => helper!( Float32Type , dt) ,
204
+ DataType :: Float64 => helper!( Float64Type , dt) ,
205
+ DataType :: Decimal128 ( _, _) => helper!( Decimal128Type , dt) ,
206
+ DataType :: Decimal256 ( _, _) => helper!( Decimal256Type , dt) ,
207
+ _ => Err ( DataFusionError :: NotImplemented ( format!(
208
+ "MedianGroupsAccumulator not supported for {} with {}" ,
209
+ args. name,
210
+ dt,
211
+ ) ) ) ,
212
+ }
213
+ }
214
+
176
215
fn aliases ( & self ) -> & [ String ] {
177
216
& [ ]
178
217
}
0 commit comments