@@ -37,7 +37,8 @@ use arrow::array::Array;
37
37
use arrow:: array:: ArrowNativeTypeOp ;
38
38
use arrow:: datatypes:: { ArrowNativeType , ArrowPrimitiveType } ;
39
39
40
- use datafusion_common:: { DataFusionError , HashSet , Result , ScalarValue } ;
40
+ use datafusion_common:: { internal_err, DataFusionError , HashSet , Result , ScalarValue } ;
41
+ use datafusion_doc:: DocSection ;
41
42
use datafusion_expr:: function:: StateFieldsArgs ;
42
43
use datafusion_expr:: {
43
44
function:: AccumulatorArgs , utils:: format_state_name, Accumulator , AggregateUDFImpl ,
@@ -172,6 +173,45 @@ impl AggregateUDFImpl for Median {
172
173
}
173
174
}
174
175
176
+ fn groups_accumulator_supported ( & self , args : AccumulatorArgs ) -> bool {
177
+ !args. is_distinct
178
+ }
179
+
180
+ fn create_groups_accumulator (
181
+ & self ,
182
+ args : AccumulatorArgs ,
183
+ ) -> Result < Box < dyn GroupsAccumulator > > {
184
+ let num_args = args. exprs . len ( ) ;
185
+ if num_args != 1 {
186
+ return internal_err ! (
187
+ "median should only have 1 arg, but found num args:{}" ,
188
+ args. exprs. len( )
189
+ ) ;
190
+ }
191
+
192
+ let dt = args. exprs [ 0 ] . data_type ( args. schema ) ?;
193
+
194
+ macro_rules! helper {
195
+ ( $t: ty, $dt: expr) => {
196
+ Ok ( Box :: new( MedianGroupsAccumulator :: <$t>:: new( $dt) ) )
197
+ } ;
198
+ }
199
+
200
+ downcast_integer ! {
201
+ dt => ( helper, dt) ,
202
+ DataType :: Float16 => helper!( Float16Type , dt) ,
203
+ DataType :: Float32 => helper!( Float32Type , dt) ,
204
+ DataType :: Float64 => helper!( Float64Type , dt) ,
205
+ DataType :: Decimal128 ( _, _) => helper!( Decimal128Type , dt) ,
206
+ DataType :: Decimal256 ( _, _) => helper!( Decimal256Type , dt) ,
207
+ _ => Err ( DataFusionError :: NotImplemented ( format!(
208
+ "MedianGroupsAccumulator not supported for {} with {}" ,
209
+ args. name,
210
+ dt,
211
+ ) ) ) ,
212
+ }
213
+ }
214
+
175
215
fn aliases ( & self ) -> & [ String ] {
176
216
& [ ]
177
217
}
0 commit comments