Skip to content

Commit efe5708

Browse files
jcsherin and alamb
authored
Convert BuiltInWindowFunction::{Lead, Lag} to a user defined window function (#12857)
* Move `lead-lag` to `functions-window` package * Builds with warnings * Adds `PartitionEvaluatorArgs` * Extracts `shift_offset` from input expressions * Computes shift offset * Get default value from input expression * Implements `partition_evaluator` * Fixes compiler warnings * Comments out failing tests * Fixes `cargo test` errors and warnings * Minor: taplo formatting * Delete code * Define `lead`, `lag` user-defined window functions * Fixes `cargo build` errors * Export udwf and expression public APIs * Mark result field as nullable * Delete `return_type` tests for `lead` and `lag` * Disables test: window function case insensitive * Fixes: lowercase name in logical plan * Reverts to old methods for computing `shift_offset`, `default_value` * Implements expression reversal * Fixes: lowercase name in logical plans * Fixes: doc test compilation errors Fixes: doc test build errors * Temporarily quiet clippy errors * Fixes proto definition * Minor: fixes formatting * Fixes: doc tests * Uses macro for defining `lag_udwf()` and `lead_udwf()` * Fixes: window fuzz test cases * Copies doc comments verbatim from `BuiltInWindowFunction` enum * Deletes from window function case insensitive test * Deletes `BuiltInWindowFunction` expression APIs * Delete from `create_built_in_window_expr` * Deletes proto serialization * Delete from `BuiltInWindowFunction` enum * Deletes test for finding built-in window function * Fixes build errors + deletes redundant code * Deletes more code * Delete unnecessary structs * Refactors shift offset computation * Passes range unit test * Fixes: clippy::get-first error * Rewrite unit tests for WindowUDF * Fixes: unit test for lag with default value * Consistent input expressions and data types in unit tests * Minor: fixes formatting * Restore original helper method for unit tests * Revert "Refactors shift offset computation" This reverts commit 000ceb7. 
* Moves helper functions into `functions-window-common` package * Uses common helper functions in `{lead, lag}` * Minor: formatting * Revert "Moves helper functions into `functions-window-common` package" This reverts commit ab8a83c. * Moves common functions to utils * Minor: formatting fixes * Update lowercase names in explain output * Adds doc for `lead()` and `lag()` expression functions * Add doc for `WindowShiftKind::shift_offset` * Remove `arrow` dev dependency * Minor: formatting * Update inner doc comment * Serialize 1 or more window function arguments * Adds logical plan roundtrip test cases * Refactor: readability of unit tests * Minor: rename variable bindings * Minor: copy edit * Revert "Remove `arrow` dev dependency" This reverts commit 3eb0985. * Move null argument handling helper to utils * Disable failing sqllogic tests for handling NULL input * Revert "Disable failing sqllogic tests for handling NULL input" This reverts commit 270a203. * Fixes: incorrect NULL handling in `lead`/`lag` window function * Adds more test cases --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 700b07f commit efe5708

File tree

24 files changed

+520
-407
lines changed

24 files changed

+520
-407
lines changed

datafusion-cli/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

datafusion/core/tests/fuzz_cases/window_fuzz.rs

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ use datafusion_physical_expr::{PhysicalExpr, PhysicalSortExpr};
4545
use test_utils::add_empty_batches;
4646

4747
use datafusion::functions_window::row_number::row_number_udwf;
48+
use datafusion_functions_window::lead_lag::{lag_udwf, lead_udwf};
4849
use datafusion_functions_window::rank::{dense_rank_udwf, rank_udwf};
4950
use hashbrown::HashMap;
5051
use rand::distributions::Alphanumeric;
@@ -197,7 +198,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
197198
// )
198199
(
199200
// Window function
200-
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lag),
201+
WindowFunctionDefinition::WindowUDF(lag_udwf()),
201202
// its name
202203
"LAG",
203204
// no argument
@@ -211,7 +212,7 @@ async fn bounded_window_causal_non_causal() -> Result<()> {
211212
// )
212213
(
213214
// Window function
214-
WindowFunctionDefinition::BuiltInWindowFunction(BuiltInWindowFunction::Lead),
215+
WindowFunctionDefinition::WindowUDF(lead_udwf()),
215216
// its name
216217
"LEAD",
217218
// no argument
@@ -393,9 +394,7 @@ fn get_random_function(
393394
window_fn_map.insert(
394395
"lead",
395396
(
396-
WindowFunctionDefinition::BuiltInWindowFunction(
397-
BuiltInWindowFunction::Lead,
398-
),
397+
WindowFunctionDefinition::WindowUDF(lead_udwf()),
399398
vec![
400399
arg.clone(),
401400
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),
@@ -406,9 +405,7 @@ fn get_random_function(
406405
window_fn_map.insert(
407406
"lag",
408407
(
409-
WindowFunctionDefinition::BuiltInWindowFunction(
410-
BuiltInWindowFunction::Lag,
411-
),
408+
WindowFunctionDefinition::WindowUDF(lag_udwf()),
412409
vec![
413410
arg.clone(),
414411
lit(ScalarValue::Int64(Some(rng.gen_range(1..10)))),

datafusion/expr/src/built_in_window_function.rs

Lines changed: 3 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ use std::str::FromStr;
2222

2323
use crate::type_coercion::functions::data_types;
2424
use crate::utils;
25-
use crate::{Signature, TypeSignature, Volatility};
25+
use crate::{Signature, Volatility};
2626
use datafusion_common::{plan_datafusion_err, plan_err, DataFusionError, Result};
2727

2828
use arrow::datatypes::DataType;
@@ -44,17 +44,7 @@ pub enum BuiltInWindowFunction {
4444
CumeDist,
4545
/// Integer ranging from 1 to the argument value, dividing the partition as equally as possible
4646
Ntile,
47-
/// Returns value evaluated at the row that is offset rows before the current row within the partition;
48-
/// If there is no such row, instead return default (which must be of the same type as value).
49-
/// Both offset and default are evaluated with respect to the current row.
50-
/// If omitted, offset defaults to 1 and default to null
51-
Lag,
52-
/// Returns value evaluated at the row that is offset rows after the current row within the partition;
53-
/// If there is no such row, instead return default (which must be of the same type as value).
54-
/// Both offset and default are evaluated with respect to the current row.
55-
/// If omitted, offset defaults to 1 and default to null
56-
Lead,
57-
/// Returns value evaluated at the row that is the first row of the window frame
47+
/// Returns value evaluated at the row that is the first row of the window frame
5848
FirstValue,
5949
/// Returns value evaluated at the row that is the last row of the window frame
6050
LastValue,
@@ -68,8 +58,6 @@ impl BuiltInWindowFunction {
6858
match self {
6959
CumeDist => "CUME_DIST",
7060
Ntile => "NTILE",
71-
Lag => "LAG",
72-
Lead => "LEAD",
7361
FirstValue => "first_value",
7462
LastValue => "last_value",
7563
NthValue => "NTH_VALUE",
@@ -83,8 +71,6 @@ impl FromStr for BuiltInWindowFunction {
8371
Ok(match name.to_uppercase().as_str() {
8472
"CUME_DIST" => BuiltInWindowFunction::CumeDist,
8573
"NTILE" => BuiltInWindowFunction::Ntile,
86-
"LAG" => BuiltInWindowFunction::Lag,
87-
"LEAD" => BuiltInWindowFunction::Lead,
8874
"FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
8975
"LAST_VALUE" => BuiltInWindowFunction::LastValue,
9076
"NTH_VALUE" => BuiltInWindowFunction::NthValue,
@@ -117,9 +103,7 @@ impl BuiltInWindowFunction {
117103
match self {
118104
BuiltInWindowFunction::Ntile => Ok(DataType::UInt64),
119105
BuiltInWindowFunction::CumeDist => Ok(DataType::Float64),
120-
BuiltInWindowFunction::Lag
121-
| BuiltInWindowFunction::Lead
122-
| BuiltInWindowFunction::FirstValue
106+
BuiltInWindowFunction::FirstValue
123107
| BuiltInWindowFunction::LastValue
124108
| BuiltInWindowFunction::NthValue => Ok(input_expr_types[0].clone()),
125109
}
@@ -130,16 +114,6 @@ impl BuiltInWindowFunction {
130114
// Note: The physical expression must accept the type returned by this function or the execution panics.
131115
match self {
132116
BuiltInWindowFunction::CumeDist => Signature::any(0, Volatility::Immutable),
133-
BuiltInWindowFunction::Lag | BuiltInWindowFunction::Lead => {
134-
Signature::one_of(
135-
vec![
136-
TypeSignature::Any(1),
137-
TypeSignature::Any(2),
138-
TypeSignature::Any(3),
139-
],
140-
Volatility::Immutable,
141-
)
142-
}
143117
BuiltInWindowFunction::FirstValue | BuiltInWindowFunction::LastValue => {
144118
Signature::any(1, Volatility::Immutable)
145119
}

datafusion/expr/src/expr.rs

Lines changed: 0 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -2560,30 +2560,6 @@ mod test {
25602560
Ok(())
25612561
}
25622562

2563-
#[test]
2564-
fn test_lead_return_type() -> Result<()> {
2565-
let fun = find_df_window_func("lead").unwrap();
2566-
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
2567-
assert_eq!(DataType::Utf8, observed);
2568-
2569-
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
2570-
assert_eq!(DataType::Float64, observed);
2571-
2572-
Ok(())
2573-
}
2574-
2575-
#[test]
2576-
fn test_lag_return_type() -> Result<()> {
2577-
let fun = find_df_window_func("lag").unwrap();
2578-
let observed = fun.return_type(&[DataType::Utf8], &[true], "")?;
2579-
assert_eq!(DataType::Utf8, observed);
2580-
2581-
let observed = fun.return_type(&[DataType::Float64], &[true], "")?;
2582-
assert_eq!(DataType::Float64, observed);
2583-
2584-
Ok(())
2585-
}
2586-
25872563
#[test]
25882564
fn test_nth_value_return_type() -> Result<()> {
25892565
let fun = find_df_window_func("nth_value").unwrap();
@@ -2621,8 +2597,6 @@ mod test {
26212597
let names = vec![
26222598
"cume_dist",
26232599
"ntile",
2624-
"lag",
2625-
"lead",
26262600
"first_value",
26272601
"last_value",
26282602
"nth_value",
@@ -2660,18 +2634,6 @@ mod test {
26602634
built_in_window_function::BuiltInWindowFunction::LastValue
26612635
))
26622636
);
2663-
assert_eq!(
2664-
find_df_window_func("LAG"),
2665-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
2666-
built_in_window_function::BuiltInWindowFunction::Lag
2667-
))
2668-
);
2669-
assert_eq!(
2670-
find_df_window_func("LEAD"),
2671-
Some(WindowFunctionDefinition::BuiltInWindowFunction(
2672-
built_in_window_function::BuiltInWindowFunction::Lead
2673-
))
2674-
);
26752637
assert_eq!(find_df_window_func("not_exist"), None)
26762638
}
26772639

datafusion/expr/src/udwf.rs

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,8 +34,10 @@ use crate::{
3434
Signature,
3535
};
3636
use datafusion_common::{not_impl_err, Result};
37+
use datafusion_functions_window_common::expr::ExpressionArgs;
3738
use datafusion_functions_window_common::field::WindowUDFFieldArgs;
3839
use datafusion_functions_window_common::partition::PartitionEvaluatorArgs;
40+
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
3941

4042
/// Logical representation of a user-defined window function (UDWF)
4143
/// A UDWF is different from a UDF in that it is stateful across batches.
@@ -149,6 +151,12 @@ impl WindowUDF {
149151
self.inner.simplify()
150152
}
151153

154+
/// Expressions that are passed to the [`PartitionEvaluator`].
155+
///
156+
/// See [`WindowUDFImpl::expressions`] for more details.
157+
pub fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
158+
self.inner.expressions(expr_args)
159+
}
152160
/// Return a `PartitionEvaluator` for evaluating this window function
153161
pub fn partition_evaluator_factory(
154162
&self,
@@ -302,6 +310,14 @@ pub trait WindowUDFImpl: Debug + Send + Sync {
302310
/// types are accepted and the function's Volatility.
303311
fn signature(&self) -> &Signature;
304312

313+
/// Returns the expressions that are passed to the [`PartitionEvaluator`].
314+
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
315+
expr_args
316+
.input_exprs()
317+
.first()
318+
.map_or(vec![], |expr| vec![Arc::clone(expr)])
319+
}
320+
305321
/// Invoke the function, returning the [`PartitionEvaluator`] instance
306322
fn partition_evaluator(
307323
&self,
@@ -480,6 +496,13 @@ impl WindowUDFImpl for AliasedWindowUDFImpl {
480496
self.inner.signature()
481497
}
482498

499+
fn expressions(&self, expr_args: ExpressionArgs) -> Vec<Arc<dyn PhysicalExpr>> {
500+
expr_args
501+
.input_exprs()
502+
.first()
503+
.map_or(vec![], |expr| vec![Arc::clone(expr)])
504+
}
505+
483506
fn partition_evaluator(
484507
&self,
485508
partition_evaluator_args: PartitionEvaluatorArgs,

datafusion/expr/src/window_function.rs

Lines changed: 0 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,6 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use datafusion_common::ScalarValue;
19-
2018
use crate::{expr::WindowFunction, BuiltInWindowFunction, Expr, Literal};
2119

2220
/// Create an expression to represent the `cume_dist` window function
@@ -29,38 +27,6 @@ pub fn ntile(arg: Expr) -> Expr {
2927
Expr::WindowFunction(WindowFunction::new(BuiltInWindowFunction::Ntile, vec![arg]))
3028
}
3129

32-
/// Create an expression to represent the `lag` window function
33-
pub fn lag(
34-
arg: Expr,
35-
shift_offset: Option<i64>,
36-
default_value: Option<ScalarValue>,
37-
) -> Expr {
38-
let shift_offset_lit = shift_offset
39-
.map(|v| v.lit())
40-
.unwrap_or(ScalarValue::Null.lit());
41-
let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
42-
Expr::WindowFunction(WindowFunction::new(
43-
BuiltInWindowFunction::Lag,
44-
vec![arg, shift_offset_lit, default_lit],
45-
))
46-
}
47-
48-
/// Create an expression to represent the `lead` window function
49-
pub fn lead(
50-
arg: Expr,
51-
shift_offset: Option<i64>,
52-
default_value: Option<ScalarValue>,
53-
) -> Expr {
54-
let shift_offset_lit = shift_offset
55-
.map(|v| v.lit())
56-
.unwrap_or(ScalarValue::Null.lit());
57-
let default_lit = default_value.unwrap_or(ScalarValue::Null).lit();
58-
Expr::WindowFunction(WindowFunction::new(
59-
BuiltInWindowFunction::Lead,
60-
vec![arg, shift_offset_lit, default_lit],
61-
))
62-
}
63-
6430
/// Create an expression to represent the `nth_value` window function
6531
pub fn nth_value(arg: Expr, n: i64) -> Expr {
6632
Expr::WindowFunction(WindowFunction::new(
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use datafusion_common::arrow::datatypes::DataType;
19+
use datafusion_physical_expr_common::physical_expr::PhysicalExpr;
20+
use std::sync::Arc;
21+
22+
/// Arguments passed to user-defined window function
23+
#[derive(Debug, Default)]
24+
pub struct ExpressionArgs<'a> {
25+
/// The expressions passed as arguments to the user-defined window
26+
/// function.
27+
input_exprs: &'a [Arc<dyn PhysicalExpr>],
28+
/// The corresponding data types of expressions passed as arguments
29+
/// to the user-defined window function.
30+
input_types: &'a [DataType],
31+
}
32+
33+
impl<'a> ExpressionArgs<'a> {
34+
/// Create an instance of [`ExpressionArgs`].
35+
///
36+
/// # Arguments
37+
///
38+
/// * `input_exprs` - The expressions passed as arguments
39+
/// to the user-defined window function.
40+
/// * `input_types` - The data types corresponding to the
41+
/// arguments to the user-defined window function.
42+
///
43+
pub fn new(
44+
input_exprs: &'a [Arc<dyn PhysicalExpr>],
45+
input_types: &'a [DataType],
46+
) -> Self {
47+
Self {
48+
input_exprs,
49+
input_types,
50+
}
51+
}
52+
53+
/// Returns the expressions passed as arguments to the user-defined
54+
/// window function.
55+
pub fn input_exprs(&self) -> &'a [Arc<dyn PhysicalExpr>] {
56+
self.input_exprs
57+
}
58+
59+
/// Returns the [`DataType`]s corresponding to the input expressions
60+
/// to the user-defined window function.
61+
pub fn input_types(&self) -> &'a [DataType] {
62+
self.input_types
63+
}
64+
}

datafusion/functions-window-common/src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,5 +18,6 @@
1818
//! Common user-defined window functionality for [DataFusion]
1919
//!
2020
//! [DataFusion]: <https://crates.io/crates/datafusion>
21+
pub mod expr;
2122
pub mod field;
2223
pub mod partition;

datafusion/functions-window/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ path = "src/lib.rs"
4141
datafusion-common = { workspace = true }
4242
datafusion-expr = { workspace = true }
4343
datafusion-functions-window-common = { workspace = true }
44+
datafusion-physical-expr = { workspace = true }
4445
datafusion-physical-expr-common = { workspace = true }
4546
log = { workspace = true }
4647
paste = "1.0.15"

0 commit comments

Comments
 (0)