Skip to content

Commit f48e0b2

Browse files
authored
Impl convert_to_state for GroupsAccumulatorAdapter (faster median for high cardinality aggregates) (#11827)
* make a draft for `convert_to_state` in `GroupsAccumulatorAdapter`. * tmp * use filter nulls to impl quick filter for some arrays. * add unique group by test for `median`, `approx_median`, `approx_distinct`. * add normal cases & nullable cases for `median`, `approx_median`, `approx_distinct`. * add filter cases for `median`, `approx_median`, `approx_distinct`. * fix clippy. * fix fmt. * add todo. * fix comments. * fall back to the filter kernel for the general case. * remove unused imports. * remove unused Array.
1 parent 6519f8e commit f48e0b2

File tree

2 files changed

+269
-6
lines changed

2 files changed

+269
-6
lines changed

datafusion/functions-aggregate-common/src/aggregate/groups_accumulator.rs

Lines changed: 49 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -207,8 +207,11 @@ impl GroupsAccumulatorAdapter {
207207
let state = &mut self.states[group_idx];
208208
sizes_pre += state.size();
209209

210-
let values_to_accumulate =
211-
slice_and_maybe_filter(&values, opt_filter.as_ref(), offsets)?;
210+
let values_to_accumulate = slice_and_maybe_filter(
211+
&values,
212+
opt_filter.as_ref().map(|f| f.as_boolean()),
213+
offsets,
214+
)?;
212215
(f)(state.accumulator.as_mut(), &values_to_accumulate)?;
213216

214217
// clear out the state so they are empty for next
@@ -290,6 +293,7 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
290293
result
291294
}
292295

296+
// filtered_null_mask(opt_filter, &values);
293297
fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
294298
let vec_size_pre = self.states.allocated_size();
295299
let states = emit_to.take_needed(&mut self.states);
@@ -348,6 +352,46 @@ impl GroupsAccumulator for GroupsAccumulatorAdapter {
348352
/// Returns the adapter's `allocation_bytes` field unchanged.
fn size(&self) -> usize {
349353
self.allocation_bytes
350354
}
355+
356+
/// Converts each input row directly into accumulator state, treating every
/// row as its own group.
///
/// For each row a fresh accumulator is built with `self.factory`, fed the
/// single-row slice of `values` (after applying `opt_filter`, if any) via
/// `update_batch`, and asked for its `state()`. The per-row state values are
/// gathered column-wise and each column is turned into one array.
///
/// Returns one array per state column, or the first error raised by the
/// factory, the slicing/filtering, the accumulator, or the array building.
///
/// NOTE(review): `values[0]` panics if `values` is empty — presumably callers
/// always pass at least one input column; confirm.
/// NOTE(review): when the input has zero rows the loop below never runs, so
/// `results` stays empty and an empty `Vec` is returned instead of one empty
/// array per state column — confirm callers tolerate this.
fn convert_to_state(
357+
&self,
358+
values: &[ArrayRef],
359+
opt_filter: Option<&BooleanArray>,
360+
) -> Result<Vec<ArrayRef>> {
361+
let num_rows = values[0].len();
362+
363+
// Every input row is handled as its own group, so one state row is
// produced per input row. `results[col]` collects the values of state
// column `col` across all rows.
364+
let mut results = vec![];
365+
for row_idx in 0..num_rows {
366+
// Create a fresh, empty accumulator for converting this row
367+
let mut converted_accumulator = (self.factory)()?;
368+
369+
// Convert the row to states: slice out just this row (offsets
// `[row_idx, row_idx + 1]`) and apply the optional filter to it
370+
let values_to_accumulate =
371+
slice_and_maybe_filter(values, opt_filter, &[row_idx, row_idx + 1])?;
372+
converted_accumulator.update_batch(&values_to_accumulate)?;
373+
let states = converted_accumulator.state()?;
374+
375+
// Resize results to have enough columns according to the converted states.
// On the first row this creates one Vec per state column; on later rows
// it changes nothing provided the accumulator always emits the same
// number of state values.
376+
results.resize_with(states.len(), || Vec::with_capacity(num_rows));
377+
378+
// Add the states to results, one value per state column
379+
for (idx, state_val) in states.into_iter().enumerate() {
380+
results[idx].push(state_val);
381+
}
382+
}
383+
384+
// Build one array per state column from the collected per-row values
let arrays = results
385+
.into_iter()
386+
.map(ScalarValue::iter_to_array)
387+
.collect::<Result<Vec<_>>>()?;
388+
389+
Ok(arrays)
390+
}
391+
392+
/// Always returns `true`.
fn supports_convert_to_state(&self) -> bool {
393+
true
394+
}
351395
}
352396

353397
/// Extension trait for [`Vec`] to account for allocations.
@@ -384,7 +428,7 @@ fn get_filter_at_indices(
384428
// Copied from physical-plan
385429
pub(crate) fn slice_and_maybe_filter(
386430
aggr_array: &[ArrayRef],
387-
filter_opt: Option<&ArrayRef>,
431+
filter_opt: Option<&BooleanArray>,
388432
offsets: &[usize],
389433
) -> Result<Vec<ArrayRef>> {
390434
let (offset, length) = (offsets[0], offsets[1] - offsets[0]);
@@ -394,13 +438,12 @@ pub(crate) fn slice_and_maybe_filter(
394438
.collect();
395439

396440
if let Some(f) = filter_opt {
397-
let filter_array = f.slice(offset, length);
398-
let filter_array = filter_array.as_boolean();
441+
let filter = f.slice(offset, length);
399442

400443
sliced_arrays
401444
.iter()
402445
.map(|array| {
403-
compute::filter(array, filter_array).map_err(|e| arrow_datafusion_err!(e))
446+
compute::filter(&array, &filter).map_err(|e| arrow_datafusion_err!(e))
404447
})
405448
.collect()
406449
} else {

datafusion/sqllogictest/test_files/aggregate_skip_partial.slt

Lines changed: 220 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,51 @@ GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
133133
-2117946883 d -2117946883 NULL NULL NULL
134134
-2098805236 c -2098805236 NULL NULL NULL
135135

136+
query ITIIII
137+
SELECT c5, c1,
138+
MEDIAN(c5),
139+
MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
140+
MEDIAN(c5) FILTER (WHERE c1 = 'b'),
141+
MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
142+
FROM aggregate_test_100
143+
GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
144+
----
145+
-2141999138 c -2141999138 NULL NULL NULL
146+
-2141451704 a -2141451704 -2141451704 NULL NULL
147+
-2138770630 b -2138770630 NULL -2138770630 NULL
148+
-2117946883 d -2117946883 NULL NULL NULL
149+
-2098805236 c -2098805236 NULL NULL NULL
150+
151+
query ITIIII
152+
SELECT c5, c1,
153+
APPROX_MEDIAN(c5),
154+
APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
155+
APPROX_MEDIAN(c5) FILTER (WHERE c1 = 'b'),
156+
APPROX_MEDIAN(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
157+
FROM aggregate_test_100
158+
GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
159+
----
160+
-2141999138 c -2141999138 NULL NULL NULL
161+
-2141451704 a -2141451704 -2141451704 NULL NULL
162+
-2138770630 b -2138770630 NULL -2138770630 NULL
163+
-2117946883 d -2117946883 NULL NULL NULL
164+
-2098805236 c -2098805236 NULL NULL NULL
165+
166+
query ITIIII
167+
SELECT c5, c1,
168+
APPROX_DISTINCT(c5),
169+
APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END),
170+
APPROX_DISTINCT(c5) FILTER (WHERE c1 = 'b'),
171+
APPROX_DISTINCT(CASE WHEN c1 = 'a' THEN c5 ELSE NULL END) FILTER (WHERE c1 = 'b')
172+
FROM aggregate_test_100
173+
GROUP BY 1, 2 ORDER BY 1 LIMIT 5;
174+
----
175+
-2141999138 c 1 0 0 0
176+
-2141451704 a 1 1 0 0
177+
-2138770630 b 1 0 1 0
178+
-2117946883 d 1 0 0 0
179+
-2098805236 c 1 0 0 0
180+
136181
# FIXME: add bool_and(v3) column when issue fixed
137182
# ISSUE https://github.com/apache/datafusion/issues/11846
138183
query TBBB rowsort
@@ -222,6 +267,36 @@ SELECT c2, sum(c5), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
222267
4 16155718643 9.531112968922
223268
5 6449337880 7.074412226677
224269

270+
# Test median for int / float
271+
query IIR
272+
SELECT c2, median(c5), median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
273+
----
274+
1 23971150 0.5922606
275+
2 -562486880 0.43422085
276+
3 240273900 0.40199697
277+
4 762932956 0.48515016
278+
5 604973998 0.49842384
279+
280+
# Test approx_median for int / float
281+
query IIR
282+
SELECT c2, approx_median(c5), approx_median(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
283+
----
284+
1 191655437 0.59926736
285+
2 -587831330 0.43230486
286+
3 240273900 0.40199697
287+
4 762932956 0.48515016
288+
5 593204320 0.5156586
289+
290+
# Test approx_distinct for varchar / int
291+
query III
292+
SELECT c2, approx_distinct(c1), approx_distinct(c5) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
293+
----
294+
1 5 22
295+
2 5 22
296+
3 5 19
297+
4 5 23
298+
5 5 14
299+
225300
# Test count with nullable fields
226301
query III
227302
SELECT c2, count(c3), count(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
@@ -252,6 +327,36 @@ SELECT c2, sum(c3), sum(c11) FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
252327
4 29 9.531112968922
253328
5 -194 7.074412226677
254329

330+
# Test median with nullable fields
331+
query IIR
332+
SELECT c2, median(c3), median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
333+
----
334+
1 12 0.6067944
335+
2 1 0.46076488
336+
3 14 0.40154034
337+
4 -17 0.48515016
338+
5 -35 0.5536642
339+
340+
# Test approx_median with nullable fields
341+
query IIR
342+
SELECT c2, approx_median(c3), approx_median(c11) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
343+
----
344+
1 12 0.6067944
345+
2 1 0.46076488
346+
3 14 0.40154034
347+
4 -7 0.48515016
348+
5 -39 0.5536642
349+
350+
# Test approx_distinct with nullable fields
351+
query II
352+
SELECT c2, approx_distinct(c3) FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
353+
----
354+
1 19
355+
2 16
356+
3 13
357+
4 16
358+
5 12
359+
255360
# Test avg for tinyint / float
256361
query TRR
257362
SELECT
@@ -338,6 +443,48 @@ FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
338443
4 417
339444
5 284
340445

446+
# Test approx_distinct with filter
447+
query III
448+
SELECT
449+
c2,
450+
approx_distinct(c3) FILTER (WHERE c3 > 0),
451+
approx_distinct(c3) FILTER (WHERE c11 > 10)
452+
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
453+
----
454+
1 13 0
455+
2 12 0
456+
3 11 0
457+
4 13 0
458+
5 5 0
459+
460+
# Test median with filter
461+
query III
462+
SELECT
463+
c2,
464+
median(c3) FILTER (WHERE c3 > 0),
465+
median(c3) FILTER (WHERE c3 < 0)
466+
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
467+
----
468+
1 57 -56
469+
2 52 -60
470+
3 71 -74
471+
4 65 -69
472+
5 64 -59
473+
474+
# Test approx_median with filter
475+
query III
476+
SELECT
477+
c2,
478+
approx_median(c3) FILTER (WHERE c3 > 0),
479+
approx_median(c3) FILTER (WHERE c3 < 0)
480+
FROM aggregate_test_100 GROUP BY c2 ORDER BY c2;
481+
----
482+
1 57 -56
483+
2 52 -60
484+
3 71 -76
485+
4 65 -64
486+
5 64 -59
487+
341488
# Test count with nullable fields and filter
342489
query III
343490
SELECT c2,
@@ -421,6 +568,79 @@ FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
421568
4 -171 56 2.10740506649 1.939846396446
422569
5 -86 -76 1.8741710186 1.600569307804
423570

571+
# Test approx_distinct with nullable fields and filter
572+
query II
573+
SELECT c2,
574+
approx_distinct(c3) FILTER (WHERE c5 > 0)
575+
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
576+
----
577+
1 11
578+
2 6
579+
3 6
580+
4 11
581+
5 8
582+
583+
# Test approx_distinct with nullable fields and nullable filter
584+
query II
585+
SELECT c2,
586+
approx_distinct(c3) FILTER (WHERE c11 > 0.5)
587+
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
588+
----
589+
1 10
590+
2 6
591+
3 3
592+
4 3
593+
5 6
594+
595+
# Test median with nullable fields and filter
596+
query IIR
597+
SELECT c2,
598+
median(c3) FILTER (WHERE c5 > 0),
599+
median(c11) FILTER (WHERE c5 < 0)
600+
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
601+
----
602+
1 -5 0.6623719
603+
2 15 0.52930677
604+
3 13 0.32792538
605+
4 -38 0.49774808
606+
5 -18 0.49842384
607+
608+
# Test median with nullable fields and nullable filter
609+
query II
610+
SELECT c2,
611+
median(c3) FILTER (WHERE c11 > 0.5)
612+
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
613+
----
614+
1 33
615+
2 -29
616+
3 22
617+
4 -90
618+
5 -22
619+
620+
# Test approx_median with nullable fields and filter
621+
query IIR
622+
SELECT c2,
623+
approx_median(c3) FILTER (WHERE c5 > 0),
624+
approx_median(c11) FILTER (WHERE c5 < 0)
625+
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
626+
----
627+
1 -5 0.6623719
628+
2 12 0.52930677
629+
3 13 0.32792538
630+
4 -38 0.49774808
631+
5 -21 0.47652745
632+
633+
# Test approx_median with nullable fields and nullable filter
634+
query II
635+
SELECT c2,
636+
approx_median(c3) FILTER (WHERE c11 > 0.5)
637+
FROM aggregate_test_100_null GROUP BY c2 ORDER BY c2;
638+
----
639+
1 35
640+
2 -29
641+
3 22
642+
4 -90
643+
5 -32
424644

425645
statement ok
426646
DROP TABLE aggregate_test_100_null;

0 commit comments

Comments
 (0)