Skip to content

Commit 2eb38bd

Browse files
authored
Minor: Move group accumulator for aggregate function to physical-expr-common, and add ahash physical-expr-common (#10574)
* ahash workspace Signed-off-by: jayzhan211 <[email protected]> * move other utils Signed-off-by: jayzhan211 <[email protected]> * move NullState Signed-off-by: jayzhan211 <[email protected]> * move PrimitiveGroupsAccumulator Signed-off-by: jayzhan211 <[email protected]> * move boolop Signed-off-by: jayzhan211 <[email protected]> * move deciamlavg Signed-off-by: jayzhan211 <[email protected]> * add comment Signed-off-by: jayzhan211 <[email protected]> * fix doc Signed-off-by: jayzhan211 <[email protected]> --------- Signed-off-by: jayzhan211 <[email protected]>
1 parent 65e281a commit 2eb38bd

File tree

17 files changed

+222
-220
lines changed

17 files changed

+222
-220
lines changed

Cargo.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,9 @@ version = "38.0.0"
5959
# for the inherited dependency but cannot do the reverse (override from true to false).
6060
#
6161
# See for more details: https://github.com/rust-lang/cargo/issues/11329
62+
ahash = { version = "0.8", default-features = false, features = [
63+
"runtime-rng",
64+
] }
6265
arrow = { version = "51.0.0", features = ["prettyprint"] }
6366
arrow-array = { version = "51.0.0", default-features = false, features = ["chrono-tz"] }
6467
arrow-buffer = { version = "51.0.0", default-features = false }

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/common/Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,7 @@ backtrace = []
4141
pyarrow = ["pyo3", "arrow/pyarrow", "parquet"]
4242

4343
[dependencies]
44-
ahash = { version = "0.8", default-features = false, features = [
45-
"runtime-rng",
46-
] }
44+
ahash = { workspace = true }
4745
apache-avro = { version = "0.16", default-features = false, features = [
4846
"bzip",
4947
"snappy",

datafusion/core/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ unicode_expressions = [
7777
]
7878

7979
[dependencies]
80-
ahash = { version = "0.8", default-features = false, features = ["runtime-rng"] }
80+
ahash = { workspace = true }
8181
apache-avro = { version = "0.16", optional = true }
8282
arrow = { workspace = true }
8383
arrow-array = { workspace = true }

datafusion/expr/Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ path = "src/lib.rs"
3838
[features]
3939

4040
[dependencies]
41-
ahash = { version = "0.8", default-features = false, features = [
42-
"runtime-rng",
43-
] }
41+
ahash = { workspace = true }
4442
arrow = { workspace = true }
4543
arrow-array = { workspace = true }
4644
chrono = { workspace = true }

datafusion/physical-expr-common/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,3 +39,4 @@ path = "src/lib.rs"
3939
arrow = { workspace = true }
4040
datafusion-common = { workspace = true, default-features = true }
4141
datafusion-expr = { workspace = true }
42+
rand = { workspace = true }

datafusion/physical-expr/src/aggregate/groups_accumulator/accumulate.rs renamed to datafusion/physical-expr-common/src/aggregate/groups_accumulator/accumulate.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,9 @@
1919
//!
2020
//! [`GroupsAccumulator`]: datafusion_expr::GroupsAccumulator
2121
22+
use arrow::array::{Array, BooleanArray, BooleanBufferBuilder, PrimitiveArray};
23+
use arrow::buffer::{BooleanBuffer, NullBuffer};
2224
use arrow::datatypes::ArrowPrimitiveType;
23-
use arrow_array::{Array, BooleanArray, PrimitiveArray};
24-
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder, NullBuffer};
2525

2626
use datafusion_expr::EmitTo;
2727
/// Track the accumulator null state per row: if any values for that
@@ -462,9 +462,9 @@ fn initialize_builder(
462462
mod test {
463463
use super::*;
464464

465-
use arrow_array::UInt32Array;
466-
use hashbrown::HashSet;
465+
use arrow::array::UInt32Array;
467466
use rand::{rngs::ThreadRng, Rng};
467+
use std::collections::HashSet;
468468

469469
#[test]
470470
fn accumulate() {

datafusion/physical-expr/src/aggregate/groups_accumulator/bool_op.rs renamed to datafusion/physical-expr-common/src/aggregate/groups_accumulator/bool_op.rs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,8 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow::array::AsArray;
21-
use arrow_array::{ArrayRef, BooleanArray};
22-
use arrow_buffer::{BooleanBuffer, BooleanBufferBuilder};
20+
use arrow::array::{ArrayRef, AsArray, BooleanArray, BooleanBufferBuilder};
21+
use arrow::buffer::BooleanBuffer;
2322
use datafusion_common::Result;
2423
use datafusion_expr::{EmitTo, GroupsAccumulator};
2524

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
//! Utilities for implementing GroupsAccumulator
19+
20+
pub mod accumulate;
21+
pub mod bool_op;
22+
pub mod prim_op;

datafusion/physical-expr/src/aggregate/groups_accumulator/prim_op.rs renamed to datafusion/physical-expr-common/src/aggregate/groups_accumulator/prim_op.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717

1818
use std::sync::Arc;
1919

20-
use arrow::{array::AsArray, datatypes::ArrowPrimitiveType};
21-
use arrow_array::{ArrayRef, BooleanArray, PrimitiveArray};
22-
use arrow_schema::DataType;
20+
use arrow::array::{ArrayRef, AsArray, BooleanArray, PrimitiveArray};
21+
use arrow::datatypes::ArrowPrimitiveType;
22+
use arrow::datatypes::DataType;
2323
use datafusion_common::Result;
2424
use datafusion_expr::{EmitTo, GroupsAccumulator};
2525

datafusion/physical-expr-common/src/aggregate/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
pub mod groups_accumulator;
1819
pub mod stats;
1920
pub mod utils;
2021

datafusion/physical-expr-common/src/aggregate/utils.rs

Lines changed: 161 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,18 @@
1717

1818
use std::{any::Any, sync::Arc};
1919

20+
use arrow::datatypes::ArrowNativeType;
2021
use arrow::{
22+
array::{ArrayRef, ArrowNativeTypeOp, AsArray},
2123
compute::SortOptions,
22-
datatypes::{DataType, Field},
24+
datatypes::{
25+
DataType, Decimal128Type, DecimalType, Field, TimeUnit, TimestampMicrosecondType,
26+
TimestampMillisecondType, TimestampNanosecondType, TimestampSecondType,
27+
ToByteSlice,
28+
},
2329
};
30+
use datafusion_common::{exec_err, DataFusionError, Result};
31+
use datafusion_expr::Accumulator;
2432

2533
use crate::sort_expr::PhysicalSortExpr;
2634

@@ -43,6 +51,60 @@ pub fn down_cast_any_ref(any: &dyn Any) -> &dyn Any {
4351
}
4452
}
4553

54+
/// Convert scalar values from an accumulator into arrays.
55+
pub fn get_accum_scalar_values_as_arrays(
56+
accum: &mut dyn Accumulator,
57+
) -> Result<Vec<ArrayRef>> {
58+
accum
59+
.state()?
60+
.iter()
61+
.map(|s| s.to_array_of_size(1))
62+
.collect()
63+
}
64+
65+
/// Adjust array type metadata if needed
66+
///
67+
/// Since `Decimal128Arrays` created from `Vec<NativeType>` have
68+
/// default precision and scale, this function adjusts the output to
69+
/// match `data_type`, if necessary
70+
pub fn adjust_output_array(data_type: &DataType, array: ArrayRef) -> Result<ArrayRef> {
71+
let array = match data_type {
72+
DataType::Decimal128(p, s) => Arc::new(
73+
array
74+
.as_primitive::<Decimal128Type>()
75+
.clone()
76+
.with_precision_and_scale(*p, *s)?,
77+
) as ArrayRef,
78+
DataType::Timestamp(TimeUnit::Nanosecond, tz) => Arc::new(
79+
array
80+
.as_primitive::<TimestampNanosecondType>()
81+
.clone()
82+
.with_timezone_opt(tz.clone()),
83+
),
84+
DataType::Timestamp(TimeUnit::Microsecond, tz) => Arc::new(
85+
array
86+
.as_primitive::<TimestampMicrosecondType>()
87+
.clone()
88+
.with_timezone_opt(tz.clone()),
89+
),
90+
DataType::Timestamp(TimeUnit::Millisecond, tz) => Arc::new(
91+
array
92+
.as_primitive::<TimestampMillisecondType>()
93+
.clone()
94+
.with_timezone_opt(tz.clone()),
95+
),
96+
DataType::Timestamp(TimeUnit::Second, tz) => Arc::new(
97+
array
98+
.as_primitive::<TimestampSecondType>()
99+
.clone()
100+
.with_timezone_opt(tz.clone()),
101+
),
102+
// no adjustment needed for other arrays
103+
_ => array,
104+
};
105+
Ok(array)
106+
}
107+
46108
/// Construct corresponding fields for lexicographical ordering requirement expression
47109
pub fn ordering_fields(
48110
ordering_req: &[PhysicalSortExpr],
@@ -67,3 +129,101 @@ pub fn ordering_fields(
67129
pub fn get_sort_options(ordering_req: &[PhysicalSortExpr]) -> Vec<SortOptions> {
68130
ordering_req.iter().map(|item| item.options).collect()
69131
}
132+
133+
/// A wrapper around a type to provide hash for floats
134+
#[derive(Copy, Clone, Debug)]
135+
pub struct Hashable<T>(pub T);
136+
137+
impl<T: ToByteSlice> std::hash::Hash for Hashable<T> {
138+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
139+
self.0.to_byte_slice().hash(state)
140+
}
141+
}
142+
143+
impl<T: ArrowNativeTypeOp> PartialEq for Hashable<T> {
144+
fn eq(&self, other: &Self) -> bool {
145+
self.0.is_eq(other.0)
146+
}
147+
}
148+
149+
impl<T: ArrowNativeTypeOp> Eq for Hashable<T> {}
150+
151+
/// Computes averages for `Decimal128`/`Decimal256` values, checking for overflow
152+
///
153+
/// This is needed because different precisions for Decimal128/Decimal256 can
154+
/// store different ranges of values and thus sum/count may not fit in
155+
/// the target type.
156+
///
157+
/// For example, if the precision is 3, the max value is `999` and the min
157+
/// value is `-999`
159+
pub struct DecimalAverager<T: DecimalType> {
160+
/// scale factor for sum values (10^sum_scale)
161+
sum_mul: T::Native,
162+
/// scale factor for target (10^target_scale)
163+
target_mul: T::Native,
164+
/// the output precision
165+
target_precision: u8,
166+
}
167+
168+
impl<T: DecimalType> DecimalAverager<T> {
169+
/// Create a new `DecimalAverager`:
170+
///
171+
/// * sum_scale: the scale of `sum` values passed to [`Self::avg`]
172+
/// * target_precision: the output precision
173+
/// * target_scale: the output scale
174+
///
175+
/// Errors if the resulting data can not be stored
176+
pub fn try_new(
177+
sum_scale: i8,
178+
target_precision: u8,
179+
target_scale: i8,
180+
) -> Result<Self> {
181+
let sum_mul = T::Native::from_usize(10_usize)
182+
.map(|b| b.pow_wrapping(sum_scale as u32))
183+
.ok_or(DataFusionError::Internal(
184+
"Failed to compute sum_mul in DecimalAverager".to_string(),
185+
))?;
186+
187+
let target_mul = T::Native::from_usize(10_usize)
188+
.map(|b| b.pow_wrapping(target_scale as u32))
189+
.ok_or(DataFusionError::Internal(
190+
"Failed to compute target_mul in DecimalAverager".to_string(),
191+
))?;
192+
193+
if target_mul >= sum_mul {
194+
Ok(Self {
195+
sum_mul,
196+
target_mul,
197+
target_precision,
198+
})
199+
} else {
200+
// can't convert the lit decimal to the returned data type
201+
exec_err!("Arithmetic Overflow in AvgAccumulator")
202+
}
203+
}
204+
205+
/// Returns `sum`/`count` as an i128/i256 Decimal128/Decimal256 value with
206+
/// target_scale and target_precision, reporting overflow as an error.
207+
///
208+
/// * sum: The total sum value stored as Decimal128 with sum_scale
209+
/// (passed to `Self::try_new`)
210+
/// * count: total count, stored as an i128/i256 (*NOT* a Decimal128/Decimal256 value)
211+
#[inline(always)]
212+
pub fn avg(&self, sum: T::Native, count: T::Native) -> Result<T::Native> {
213+
if let Ok(value) = sum.mul_checked(self.target_mul.div_wrapping(self.sum_mul)) {
214+
let new_value = value.div_wrapping(count);
215+
216+
let validate =
217+
T::validate_decimal_precision(new_value, self.target_precision);
218+
219+
if validate.is_ok() {
220+
Ok(new_value)
221+
} else {
222+
exec_err!("Arithmetic Overflow in AvgAccumulator")
223+
}
224+
} else {
225+
// can't convert the lit decimal to the returned data type
226+
exec_err!("Arithmetic Overflow in AvgAccumulator")
227+
}
228+
}
229+
}

datafusion/physical-expr/Cargo.toml

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ encoding_expressions = ["base64", "hex"]
4444
regex_expressions = ["regex"]
4545

4646
[dependencies]
47-
ahash = { version = "0.8", default-features = false, features = [
48-
"runtime-rng",
49-
] }
47+
ahash = { workspace = true }
5048
arrow = { workspace = true }
5149
arrow-array = { workspace = true }
5250
arrow-buffer = { workspace = true }

datafusion/physical-expr/src/aggregate/groups_accumulator/mod.rs

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,19 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
pub(crate) mod accumulate;
1918
mod adapter;
20-
pub use accumulate::NullState;
2119
pub use adapter::GroupsAccumulatorAdapter;
2220

23-
pub(crate) mod bool_op;
24-
pub(crate) mod prim_op;
21+
// Backward compatibility
22+
pub(crate) mod accumulate {
23+
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::{accumulate_indices, NullState};
24+
}
25+
26+
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::accumulate::NullState;
27+
28+
pub(crate) mod bool_op {
29+
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::bool_op::BooleanGroupsAccumulator;
30+
}
31+
pub(crate) mod prim_op {
32+
pub use datafusion_physical_expr_common::aggregate::groups_accumulator::prim_op::PrimitiveGroupsAccumulator;
33+
}

datafusion/physical-expr/src/aggregate/mod.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,12 @@ pub(crate) mod variance;
5454

5555
pub mod build_in;
5656
pub mod moving_min_max;
57-
pub mod utils;
57+
pub mod utils {
58+
pub use datafusion_physical_expr_common::aggregate::utils::{
59+
adjust_output_array, down_cast_any_ref, get_accum_scalar_values_as_arrays,
60+
get_sort_options, ordering_fields, DecimalAverager, Hashable,
61+
};
62+
}
5863

5964
/// Checks whether the given aggregate expression is order-sensitive.
6065
/// For instance, a `SUM` aggregation doesn't depend on the order of its inputs.

0 commit comments

Comments
 (0)