Skip to content

Commit c9444da

Browse files
author
jatin
committed
kurtosis udaf
1 parent f89f200 commit c9444da

File tree

2 files changed

+206
-1
lines changed

2 files changed

+206
-1
lines changed

src/kurtosis.rs

Lines changed: 201 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,201 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::any::Any;
19+
use std::fmt::Debug;
20+
use arrow::array::{ArrayRef, Float64Array, UInt64Array};
21+
use arrow::datatypes::{DataType, Field};
22+
use datafusion::arrow;
23+
24+
use datafusion::common::cast::as_float64_array;
25+
use datafusion::common::downcast_value;
26+
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
27+
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
28+
use datafusion::scalar::ScalarValue;
29+
use datafusion::error::Result;
30+
use datafusion::common::DataFusionError;
31+
32+
33+
make_udaf_expr_and_func!(
34+
KurtosisFunction,
35+
kurtosis,
36+
x,
37+
"Calculates the excess kurtosis (Fisher’s definition) with bias correction according to the sample size.",
38+
kurtosis_udaf
39+
);
40+
41+
pub struct KurtosisFunction {
42+
signature: Signature,
43+
}
44+
45+
impl Debug for KurtosisFunction {
46+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
47+
f.debug_struct("KurtosisFunction")
48+
.field("signature", &self.signature)
49+
.finish()
50+
}
51+
}
52+
53+
impl Default for KurtosisFunction {
54+
fn default() -> Self {
55+
Self::new()
56+
}
57+
}
58+
59+
impl KurtosisFunction {
60+
pub fn new() -> Self {
61+
Self {
62+
signature: Signature::coercible(
63+
vec![DataType::Float64],
64+
Volatility::Immutable,
65+
),
66+
}
67+
}
68+
}
69+
70+
impl AggregateUDFImpl for KurtosisFunction {
71+
fn as_any(&self) -> &dyn Any {
72+
self
73+
}
74+
75+
fn name(&self) -> &str {
76+
"kurtosis"
77+
}
78+
79+
fn signature(&self) -> &Signature {
80+
&self.signature
81+
}
82+
83+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
84+
Ok(DataType::Float64)
85+
}
86+
87+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
88+
Ok(Box::new(KurtosisAccumulator::new()))
89+
}
90+
91+
// TODO
92+
// check from here
93+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
94+
Ok(vec![
95+
Field::new("count", DataType::UInt64, true),
96+
Field::new("sum", DataType::Float64, true),
97+
Field::new("sum_sqr", DataType::Float64, true),
98+
Field::new("sum_cub", DataType::Float64, true),
99+
Field::new("sum_four", DataType::Float64, true),
100+
])
101+
}
102+
}
103+
104+
/// Accumulator for calculating the excess kurtosis (Fisher’s definition) with bias correction according to the sample size.
105+
/// This implementation follows the [DuckDB implementation]:
106+
/// <https://github.com/duckdb/duckdb/blob/main/src/core_functions/aggregate/distributive/kurtosis.cpp>
107+
#[derive(Debug, Default)]
108+
pub struct KurtosisAccumulator {
109+
count: u64,
110+
sum: f64,
111+
sum_sqr: f64,
112+
sum_cub: f64,
113+
sum_four: f64,
114+
}
115+
116+
impl KurtosisAccumulator {
117+
pub fn new() -> Self {
118+
Self {
119+
count: 0,
120+
sum: 0.0,
121+
sum_sqr: 0.0,
122+
sum_cub: 0.0,
123+
sum_four: 0.0,
124+
}
125+
}
126+
}
127+
128+
impl Accumulator for KurtosisAccumulator {
129+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
130+
let array = as_float64_array(&values[0])?;
131+
for value in array.iter().flatten() {
132+
self.count += 1;
133+
self.sum += value;
134+
self.sum_sqr += value.powi(2);
135+
self.sum_cub += value.powi(3);
136+
self.sum_four += value.powi(4);
137+
}
138+
Ok(())
139+
}
140+
141+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
142+
let counts = downcast_value!(states[0], UInt64Array);
143+
let sums = downcast_value!(states[1], Float64Array);
144+
let sum_sqrs = downcast_value!(states[2], Float64Array);
145+
let sum_cubs = downcast_value!(states[3], Float64Array);
146+
let sum_fours = downcast_value!(states[4], Float64Array);
147+
148+
for i in 0..counts.len() {
149+
let c = counts.value(i);
150+
if c == 0 {
151+
continue;
152+
}
153+
self.count += c;
154+
self.sum += sums.value(i);
155+
self.sum_sqr += sum_sqrs.value(i);
156+
self.sum_cub += sum_cubs.value(i);
157+
self.sum_four += sum_fours.value(i);
158+
}
159+
160+
Ok(())
161+
}
162+
163+
fn evaluate(&mut self) -> Result<ScalarValue> {
164+
if self.count <= 3 {
165+
return Ok(ScalarValue::Float64(None));
166+
}
167+
168+
let count_64 = 1_f64 / self.count as f64;
169+
let m4 = count_64
170+
* (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64
171+
+ 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2)
172+
- 3.0 * self.sum.powi(4) * count_64.powi(3));
173+
174+
let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64;
175+
if m2 <= 0.0 {
176+
return Ok(ScalarValue::Float64(None));
177+
}
178+
179+
let count = self.count as f64;
180+
let numerator = (count - 1.0) * ((count + 1.0) * m4 / m2.powi(2) - 3.0 * (count - 1.0));
181+
let denominator = (count - 2.0 ) * (count - 3.0);
182+
183+
let target = numerator/denominator;
184+
185+
Ok(ScalarValue::Float64(Some(target)))
186+
}
187+
188+
fn size(&self) -> usize {
189+
std::mem::size_of_val(self)
190+
}
191+
192+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
193+
Ok(vec![
194+
ScalarValue::from(self.count),
195+
ScalarValue::from(self.sum),
196+
ScalarValue::from(self.sum_sqr),
197+
ScalarValue::from(self.sum_cub),
198+
ScalarValue::from(self.sum_four),
199+
])
200+
}
201+
}

src/lib.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,13 +27,17 @@ use datafusion::logical_expr::AggregateUDF;
2727
pub mod macros;
2828
pub mod common;
2929
pub mod mode;
30+
pub mod kurtosis;
3031

3132
pub mod expr_extra_fn {
3233
pub use super::mode::mode;
3334
}
3435

3536
pub fn all_extra_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
36-
vec![mode_udaf()]
37+
vec![
38+
mode_udaf(),
39+
kurtosis::kurtosis_udaf(),
40+
]
3741
}
3842

3943
/// Registers all enabled packages with a [`FunctionRegistry`]

0 commit comments

Comments
 (0)