Skip to content

Commit 9f30214

Browse files
authored
kurtosis udaf (#4)
1 parent 94f2a11 commit 9f30214

File tree

4 files changed

+270
-0
lines changed

4 files changed

+270
-0
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,3 +86,4 @@ SELECT min_by(x, y) FROM VALUES (1, 10), (2, 5), (3, 15), (4, 8) as tab(x, y);
8686
- [x] `min_by(expression1, expression2) -> scalar` - Returns the value of `expression1` associated with the minimum value of `expression2`.
8787
- [x] `skewness(expression) -> scalar` - Computes the skewness value for `expression`.
8888
- [x] `kurtosis_pop(expression) -> scalar` - Computes the excess kurtosis (Fisher’s definition) without bias correction.
89+
- [x] `kurtosis(expression) -> scalar` - Computes the excess kurtosis (Fisher’s definition) with bias correction according to the sample size.

src/kurtosis.rs

Lines changed: 195 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,195 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{ArrayRef, Float64Array, UInt64Array};
19+
use arrow::datatypes::{DataType, Field};
20+
use datafusion::arrow;
21+
use std::any::Any;
22+
use std::fmt::Debug;
23+
24+
use datafusion::common::cast::as_float64_array;
25+
use datafusion::common::downcast_value;
26+
use datafusion::common::DataFusionError;
27+
use datafusion::error::Result;
28+
use datafusion::logical_expr::function::{AccumulatorArgs, StateFieldsArgs};
29+
use datafusion::logical_expr::{Accumulator, AggregateUDFImpl, Signature, Volatility};
30+
use datafusion::scalar::ScalarValue;
31+
32+
// Invokes the crate's `make_udaf_expr_and_func!` helper for `KurtosisFunction` with the names
// `kurtosis` and `kurtosis_udaf`. Presumably this generates the `kurtosis` expression builder and
// the `kurtosis_udaf()` accessor used from `lib.rs` (confirm against the `macros` module).
make_udaf_expr_and_func!(
33+
KurtosisFunction,
34+
kurtosis,
35+
x,
36+
"Calculates the excess kurtosis (Fisher’s definition) with bias correction according to the sample size.",
37+
kurtosis_udaf
38+
);
39+
40+
pub struct KurtosisFunction {
41+
signature: Signature,
42+
}
43+
44+
impl Debug for KurtosisFunction {
45+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
46+
f.debug_struct("KurtosisFunction")
47+
.field("signature", &self.signature)
48+
.finish()
49+
}
50+
}
51+
52+
impl Default for KurtosisFunction {
53+
fn default() -> Self {
54+
Self::new()
55+
}
56+
}
57+
58+
impl KurtosisFunction {
59+
pub fn new() -> Self {
60+
Self {
61+
signature: Signature::coercible(vec![DataType::Float64], Volatility::Immutable),
62+
}
63+
}
64+
}
65+
66+
impl AggregateUDFImpl for KurtosisFunction {
67+
fn as_any(&self) -> &dyn Any {
68+
self
69+
}
70+
71+
fn name(&self) -> &str {
72+
"kurtosis"
73+
}
74+
75+
fn signature(&self) -> &Signature {
76+
&self.signature
77+
}
78+
79+
fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
80+
Ok(DataType::Float64)
81+
}
82+
83+
fn accumulator(&self, _acc_args: AccumulatorArgs) -> Result<Box<dyn Accumulator>> {
84+
Ok(Box::new(KurtosisAccumulator::new()))
85+
}
86+
87+
fn state_fields(&self, _args: StateFieldsArgs) -> Result<Vec<Field>> {
88+
Ok(vec![
89+
Field::new("count", DataType::UInt64, true),
90+
Field::new("sum", DataType::Float64, true),
91+
Field::new("sum_sqr", DataType::Float64, true),
92+
Field::new("sum_cub", DataType::Float64, true),
93+
Field::new("sum_four", DataType::Float64, true),
94+
])
95+
}
96+
}
97+
98+
/// Running state for the excess kurtosis (Fisher’s definition) with bias correction according to
/// the sample size.
///
/// The approach follows DuckDB's kurtosis aggregate:
/// <https://github.com/duckdb/duckdb/blob/main/src/core_functions/aggregate/distributive/kurtosis.cpp>
#[derive(Debug, Default)]
pub struct KurtosisAccumulator {
    /// Number of non-null values accumulated so far.
    count: u64,
    /// Sum of the values.
    sum: f64,
    /// Sum of the squared values.
    sum_sqr: f64,
    /// Sum of the cubed values.
    sum_cub: f64,
    /// Sum of the values raised to the fourth power.
    sum_four: f64,
}
109+
110+
impl KurtosisAccumulator {
111+
pub fn new() -> Self {
112+
Self {
113+
count: 0,
114+
sum: 0.0,
115+
sum_sqr: 0.0,
116+
sum_cub: 0.0,
117+
sum_four: 0.0,
118+
}
119+
}
120+
}
121+
122+
impl Accumulator for KurtosisAccumulator {
123+
fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
124+
let array = as_float64_array(&values[0])?;
125+
for value in array.iter().flatten() {
126+
self.count += 1;
127+
self.sum += value;
128+
self.sum_sqr += value.powi(2);
129+
self.sum_cub += value.powi(3);
130+
self.sum_four += value.powi(4);
131+
}
132+
Ok(())
133+
}
134+
135+
fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
136+
let counts = downcast_value!(states[0], UInt64Array);
137+
let sums = downcast_value!(states[1], Float64Array);
138+
let sum_sqrs = downcast_value!(states[2], Float64Array);
139+
let sum_cubs = downcast_value!(states[3], Float64Array);
140+
let sum_fours = downcast_value!(states[4], Float64Array);
141+
142+
for i in 0..counts.len() {
143+
let c = counts.value(i);
144+
if c == 0 {
145+
continue;
146+
}
147+
self.count += c;
148+
self.sum += sums.value(i);
149+
self.sum_sqr += sum_sqrs.value(i);
150+
self.sum_cub += sum_cubs.value(i);
151+
self.sum_four += sum_fours.value(i);
152+
}
153+
154+
Ok(())
155+
}
156+
157+
fn evaluate(&mut self) -> Result<ScalarValue> {
158+
if self.count <= 3 {
159+
return Ok(ScalarValue::Float64(None));
160+
}
161+
162+
let count_64 = 1_f64 / self.count as f64;
163+
let m4 = count_64
164+
* (self.sum_four - 4.0 * self.sum_cub * self.sum * count_64
165+
+ 6.0 * self.sum_sqr * self.sum.powi(2) * count_64.powi(2)
166+
- 3.0 * self.sum.powi(4) * count_64.powi(3));
167+
168+
let m2 = (self.sum_sqr - self.sum.powi(2) * count_64) * count_64;
169+
if m2 <= 0.0 {
170+
return Ok(ScalarValue::Float64(None));
171+
}
172+
173+
let count = self.count as f64;
174+
let numerator = (count - 1.0) * ((count + 1.0) * m4 / m2.powi(2) - 3.0 * (count - 1.0));
175+
let denominator = (count - 2.0) * (count - 3.0);
176+
177+
let target = numerator / denominator;
178+
179+
Ok(ScalarValue::Float64(Some(target)))
180+
}
181+
182+
fn size(&self) -> usize {
183+
std::mem::size_of_val(self)
184+
}
185+
186+
fn state(&mut self) -> Result<Vec<ScalarValue>> {
187+
Ok(vec![
188+
ScalarValue::from(self.count),
189+
ScalarValue::from(self.sum),
190+
ScalarValue::from(self.sum_sqr),
191+
ScalarValue::from(self.sum_cub),
192+
ScalarValue::from(self.sum_four),
193+
])
194+
}
195+
}

src/lib.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ use datafusion::logical_expr::AggregateUDF;
2626
#[macro_use]
2727
pub mod macros;
2828
pub mod common;
29+
pub mod kurtosis;
2930
pub mod kurtosis_pop;
3031
pub mod max_min_by;
3132
pub mod mode;
3233
pub mod skewness;
3334
pub mod expr_extra_fn {
35+
pub use super::kurtosis::kurtosis;
3436
pub use super::kurtosis_pop::kurtosis_pop;
3537
pub use super::max_min_by::max_by;
3638
pub use super::max_min_by::min_by;
@@ -43,6 +45,7 @@ pub fn all_extra_aggregate_functions() -> Vec<Arc<AggregateUDF>> {
4345
mode_udaf(),
4446
max_min_by::max_by_udaf(),
4547
max_min_by::min_by_udaf(),
48+
kurtosis::kurtosis_udaf(),
4649
skewness::skewness_udaf(),
4750
kurtosis_pop::kurtosis_pop_udaf(),
4851
]

tests/main.rs

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,3 +364,74 @@ async fn test_skewness() {
364364
- +-------------------+
365365
"###);
366366
}
367+
368+
// End-to-end check of the `kurtosis` aggregate: each query is run through `TestExecution` and
// its formatted output is compared with an inline snapshot.
#[tokio::test]
369+
async fn test_kurtosis() {
370+
let mut execution = TestExecution::new().await.unwrap();
371+
372+
// Five float literals: the snapshot pins the numeric result.
let actual = execution
373+
.run_and_format("SELECT kurtosis(col) FROM VALUES (1.0), (10.0), (100.0), (10.0), (1.0) as tab(col);")
374+
.await;
375+
376+
insta::assert_yaml_snapshot!(actual, @r###"
377+
- +-------------------+
378+
- "| kurtosis(tab.col) |"
379+
- +-------------------+
380+
- "| 4.777292927667962 |"
381+
- +-------------------+
382+
"###);
383+
384+
// Same values written as string literals: the snapshot expects the same result as above.
let actual = execution
385+
.run_and_format("SELECT kurtosis(col) FROM VALUES ('1'), ('10'), ('100'), ('10'), ('1') as tab(col);")
386+
.await;
387+
388+
insta::assert_yaml_snapshot!(actual, @r###"
389+
- +-------------------+
390+
- "| kurtosis(tab.col) |"
391+
- +-------------------+
392+
- "| 4.777292927667962 |"
393+
- +-------------------+
394+
"###);
395+
396+
// Only three rows: the snapshot expects an empty cell (the accumulator yields NULL for count <= 3).
let actual = execution
397+
.run_and_format("SELECT kurtosis(col) FROM VALUES (1.0), (2.0), (3.0) as tab(col);")
398+
.await;
399+
400+
insta::assert_yaml_snapshot!(actual, @r###"
401+
- +-------------------+
402+
- "| kurtosis(tab.col) |"
403+
- +-------------------+
404+
- "| |"
405+
- +-------------------+
406+
"###);
407+
408+
// Single integer literal: an empty cell is expected.
let actual = execution.run_and_format("SELECT kurtosis(1);").await;
409+
410+
insta::assert_yaml_snapshot!(actual, @r###"
411+
- +--------------------+
412+
- "| kurtosis(Int64(1)) |"
413+
- +--------------------+
414+
- "| |"
415+
- +--------------------+
416+
"###);
417+
418+
// Single float literal: an empty cell is expected.
let actual = execution.run_and_format("SELECT kurtosis(1.0);").await;
419+
420+
insta::assert_yaml_snapshot!(actual, @r###"
421+
- +----------------------+
422+
- "| kurtosis(Float64(1)) |"
423+
- +----------------------+
424+
- "| |"
425+
- +----------------------+
426+
"###);
427+
428+
// NULL literal: an empty cell is expected.
let actual = execution.run_and_format("SELECT kurtosis(null);").await;
429+
430+
insta::assert_yaml_snapshot!(actual, @r###"
431+
- +----------------+
432+
- "| kurtosis(NULL) |"
433+
- +----------------+
434+
- "| |"
435+
- +----------------+
436+
"###);
437+
}

0 commit comments

Comments
 (0)