Skip to content

Commit fa0e40f

Browse files
committed
Support User Defined Window Functions
1 parent 98669b0 commit fa0e40f

File tree

32 files changed

+1232
-31
lines changed

32 files changed

+1232
-31
lines changed

datafusion-examples/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ prost = { version = "0.11", default-features = false }
5656
prost-derive = { version = "0.11", default-features = false }
5757
serde = { version = "1.0.136", features = ["derive"] }
5858
serde_json = "1.0.82"
59+
tempfile = "3"
5960
tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
6061
tonic = "0.9"
6162
url = "2.2"

datafusion-examples/README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ cargo run --example csv_sql
5757
- [`rewrite_expr.rs`](examples/rewrite_expr.rs): Define and invoke a custom Query Optimizer pass
5858
- [`simple_udaf.rs`](examples/simple_udaf.rs): Define and invoke a User Defined Aggregate Function (UDAF)
5959
- [`simple_udf.rs`](examples/simple_udf.rs): Define and invoke a User Defined (scalar) Function (UDF)
60+
- [`simple_udwf.rs`](examples/simple_udwf.rs): Define and invoke a User Defined Window Function (UDWF)
6061

6162
## Distributed
6263

datafusion-examples/examples/rewrite_expr.rs

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ use datafusion_common::config::ConfigOptions;
2020
use datafusion_common::tree_node::{Transformed, TreeNode};
2121
use datafusion_common::{DataFusionError, Result, ScalarValue};
2222
use datafusion_expr::{
23-
AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource,
23+
AggregateUDF, Between, Expr, Filter, LogicalPlan, ScalarUDF, TableSource, WindowUDF,
2424
};
2525
use datafusion_optimizer::analyzer::{Analyzer, AnalyzerRule};
2626
use datafusion_optimizer::optimizer::Optimizer;
@@ -216,6 +216,10 @@ impl ContextProvider for MyContextProvider {
216216
None
217217
}
218218

219+
/// Look up a window UDF by name. This example provider has no window
/// functions, so the name is ignored and every lookup returns `None`.
fn get_window_meta(&self, _name: &str) -> Option<Arc<WindowUDF>> {
220+
None
221+
}
222+
219223
fn options(&self) -> &ConfigOptions {
220224
&self.options
221225
}
datafusion-examples/examples/simple_udwf.rs

Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use std::sync::Arc;
19+
20+
use arrow::{
21+
array::{ArrayRef, AsArray, Float64Array},
22+
datatypes::Float64Type,
23+
};
24+
use arrow_schema::DataType;
25+
use datafusion::datasource::file_format::options::CsvReadOptions;
26+
27+
use datafusion::error::Result;
28+
use datafusion::prelude::*;
29+
use datafusion_common::{DataFusionError, ScalarValue};
30+
use datafusion_expr::{
31+
PartitionEvaluator, Signature, Volatility, WindowFrame, WindowUDF,
32+
};
33+
34+
// Create a local execution context with `cars.csv` registered as a table named `cars`.
35+
async fn create_context() -> Result<SessionContext> {
36+
// Declare a new context. In the Spark API, this corresponds to a new SparkSession.
37+
let ctx = SessionContext::new();
38+
39+
// Register the CSV file (with a header row) as the table `cars`. The working directory is printed first because the path below is relative.
40+
println!("pwd: {}", std::env::current_dir().unwrap().display());
41+
let csv_path = "datafusion/core/tests/data/cars.csv".to_string();
42+
let read_options = CsvReadOptions::default().has_header(true);
43+
44+
ctx.register_csv("cars", &csv_path, read_options).await?;
45+
Ok(ctx)
46+
}
47+
48+
/// In this example we declare a user defined window function that computes a moving average, then run it using SQL and the DataFrame API
49+
#[tokio::main]
50+
async fn main() -> Result<()> {
51+
let ctx = create_context().await?;
52+
53+
// register the window function with DataFusion so we can call it
54+
ctx.register_udwf(smooth_it());
55+
56+
// First, use SQL to show the raw contents of the `cars` table
57+
let df = ctx.sql("SELECT * from cars").await?;
58+
// print the results
59+
df.show().await?;
60+
61+
// Use SQL to run the new window function:
62+
//
63+
// `PARTITION BY car`: each distinct value of car (red and green)
64+
// should be treated as a separate partition (and will result in
65+
// creating a new `PartitionEvaluator`)
66+
//
67+
// `ORDER BY time`: within each partition ('green' or 'red') the
68+
// rows will be ordered by the value in the `time` column
69+
//
70+
// `evaluate` is invoked with a window defined by the
71+
// SQL. In this case:
72+
//
73+
// The first invocation will be passed row 0, the first row in the
74+
// partition.
75+
//
76+
// The second invocation will be passed rows 0 and 1, the first
77+
// two rows in the partition.
78+
//
79+
// etc.
80+
let df = ctx
81+
.sql(
82+
"SELECT \
83+
car, \
84+
speed, \
85+
smooth_it(speed) OVER (PARTITION BY car ORDER BY time),\
86+
time \
87+
from cars \
88+
ORDER BY \
89+
car",
90+
)
91+
.await?;
92+
// print the results
93+
df.show().await?;
94+
95+
// this time, call the new window function with an explicit
96+
// window frame so `evaluate` will be invoked with each window.
97+
//
98+
// `ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING`: each invocation
99+
// sees at most 3 rows: the row before, the current row, and the
100+
// row afterward.
101+
let df = ctx.sql(
102+
"SELECT \
103+
car, \
104+
speed, \
105+
smooth_it(speed) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 1 PRECEDING AND 1 FOLLOWING),\
106+
time \
107+
from cars \
108+
ORDER BY \
109+
car",
110+
).await?;
111+
// print the results
112+
df.show().await?;
113+
114+
// Now, run the function using the DataFrame API:
115+
let window_expr = smooth_it().call(
116+
vec![col("speed")], // smooth_it(speed)
117+
vec![col("car")], // PARTITION BY car
118+
vec![col("time").sort(true, true)], // ORDER BY time ASC
119+
WindowFrame::new(false),
120+
);
121+
let df = ctx.table("cars").await?.window(vec![window_expr])?;
122+
123+
// print the results
124+
df.show().await?;
125+
126+
Ok(())
127+
}
128+
129+
/// Build the `smooth_it` window UDF, which takes a single `Float64` column
/// and is evaluated by `MyPartitionEvaluator`.
fn smooth_it() -> WindowUDF {
130+
WindowUDF {
131+
name: String::from("smooth_it"),
132+
// It takes exactly one argument -- the column to smooth. The type must be
// `Float64` because `evaluate` downcasts its input to a `Float64Array`
// and produces a `ScalarValue::Float64`.
133+
signature: Signature::exact(vec![DataType::Float64], Volatility::Immutable),
134+
return_type: Arc::new(return_type),
135+
partition_evaluator_factory: Arc::new(make_partition_evaluator),
136+
}
137+
}
138+
139+
/// Compute the return type of the `smooth_it` window function given
140+
/// arguments of `arg_types`.
///
/// Returns the type of the single argument unchanged, or a
/// `DataFusionError::Plan` if the number of arguments is not exactly one.
141+
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
142+
if arg_types.len() != 1 {
143+
return Err(DataFusionError::Plan(format!(
144+
"smooth_it expects 1 argument, got {}: {:?}",
145+
arg_types.len(),
146+
arg_types
147+
)));
148+
}
149+
Ok(Arc::new(arg_types[0].clone()))
150+
}
151+
152+
/// Create a `PartitionEvaluator` to evaluate this function on a new
153+
/// partition.
154+
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
155+
Ok(Box::new(MyPartitionEvaluator::new()))
156+
}
157+
158+
/// This implements the lowest level of evaluation for a window function.
159+
///
160+
/// It handles calculating the value of the window function for each
161+
/// distinct value of `PARTITION BY` (each car type in our example).
162+
#[derive(Clone, Debug)]
163+
struct MyPartitionEvaluator {}
164+
165+
impl MyPartitionEvaluator {
166+
/// Create a new evaluator; it has no fields, so there is no state to set up.
fn new() -> Self {
167+
Self {}
168+
}
169+
}
170+
171+
/// Different evaluation methods are called depending on the various
172+
/// settings of `WindowUDF`. This example uses the simplest and most
173+
/// general, `evaluate`. See `PartitionEvaluator` for the other more
174+
/// advanced uses.
175+
impl PartitionEvaluator for MyPartitionEvaluator {
176+
/// Tell DataFusion the window function varies based on the value
177+
/// of the window frame.
178+
fn uses_window_frame(&self) -> bool {
179+
true
180+
}
181+
182+
/// This function is called once per input row.
183+
///
184+
/// `range` specifies which indexes of `values` should be
185+
/// considered for the calculation.
186+
///
187+
/// Note this is the SLOWEST, but simplest, way to evaluate a
188+
/// window function. It is much faster to implement
189+
/// `evaluate_all` or `evaluate_all_with_rank`, if possible.
190+
fn evaluate(
191+
&mut self,
192+
values: &[ArrayRef],
193+
range: &std::ops::Range<usize>,
194+
) -> Result<ScalarValue> {
195+
//println!("evaluate(). range: {range:#?}, values: {values:#?}");
196+
197+
// The single input argument is treated as an array of floating
198+
// point numbers over which to calculate a moving average.
// NOTE(review): the downcast presumably panics if the input is not Float64 -- confirm against arrow's `as_primitive`.
199+
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
200+
201+
let range_len = range.end - range.start;
202+
203+
// our smoothing function will average all the values in the
// window `range`; an empty range yields a NULL result.
// NOTE(review): `values()` reads the raw value buffer, so null slots do not appear to be skipped -- confirm inputs are non-null.
204+
let output = if range_len > 0 {
205+
let sum: f64 = arr.values().iter().skip(range.start).take(range_len).sum();
206+
Some(sum / range_len as f64)
207+
} else {
208+
None
209+
};
210+
211+
Ok(ScalarValue::Float64(output))
212+
}
213+
}

datafusion/core/src/dataframe.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -218,6 +218,14 @@ impl DataFrame {
218218
Ok(DataFrame::new(self.session_state, plan))
219219
}
220220

221+
/// Apply one or more window functions ([`Expr::WindowFunction`]) to extend the schema
///
/// Consumes this `DataFrame` and returns a new one wrapping the resulting
/// plan; an error from adding the window expressions or building the plan
/// is returned as `Err`.
222+
pub fn window(self, window_exprs: Vec<Expr>) -> Result<DataFrame> {
223+
let plan = LogicalPlanBuilder::from(self.plan)
224+
.window(window_exprs)?
225+
.build()?;
226+
Ok(DataFrame::new(self.session_state, plan))
227+
}
228+
221229
/// Limit the number of rows returned from this DataFrame.
222230
///
223231
/// `skip` - Number of rows to skip before fetch any row

datafusion/core/src/execution/context.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ use datafusion_common::alias::AliasGenerator;
3232
use datafusion_execution::registry::SerializerRegistry;
3333
use datafusion_expr::{
3434
logical_plan::{DdlStatement, Statement},
35-
DescribeTable, StringifiedPlan, UserDefinedLogicalNode,
35+
DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
3636
};
3737
pub use datafusion_physical_expr::execution_props::ExecutionProps;
3838
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -786,6 +786,20 @@ impl SessionContext {
786786
.insert(f.name.clone(), Arc::new(f));
787787
}
788788

789+
/// Registers a window UDF within this context.
790+
///
791+
/// Note in SQL queries, window function names are looked up using
792+
/// lowercase unless the query uses quotes. For example,
793+
///
794+
/// - `SELECT MY_UDWF(x)...` will look for a window function named `"my_udwf"`
795+
/// - `SELECT "my_UDWF"(x)` will look for a window function named `"my_UDWF"`
796+
pub fn register_udwf(&self, f: WindowUDF) {
797+
self.state
798+
.write()
799+
.window_functions
800+
.insert(f.name.clone(), Arc::new(f));
801+
}
802+
789803
/// Creates a [`DataFrame`] for reading a data source.
790804
///
791805
/// For more control such as reading multiple files, you can use
@@ -1279,6 +1293,10 @@ impl FunctionRegistry for SessionContext {
12791293
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
12801294
self.state.read().udaf(name)
12811295
}
1296+
1297+
/// Look up a window UDF by name, delegating to the session state's `udwf`.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
1298+
self.state.read().udwf(name)
1299+
}
12821300
}
12831301

12841302
/// A planner used to add extensions to DataFusion logical and physical plans.
@@ -1329,6 +1347,8 @@ pub struct SessionState {
13291347
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
13301348
/// Aggregate functions registered in the context
13311349
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
1350+
/// Window functions registered in the context
1351+
window_functions: HashMap<String, Arc<WindowUDF>>,
13321352
/// Deserializer registry for extensions.
13331353
serializer_registry: Arc<dyn SerializerRegistry>,
13341354
/// Session configuration
@@ -1423,6 +1443,7 @@ impl SessionState {
14231443
catalog_list,
14241444
scalar_functions: HashMap::new(),
14251445
aggregate_functions: HashMap::new(),
1446+
window_functions: HashMap::new(),
14261447
serializer_registry: Arc::new(EmptySerializerRegistry),
14271448
config,
14281449
execution_props: ExecutionProps::new(),
@@ -1899,6 +1920,11 @@ impl SessionState {
18991920
&self.aggregate_functions
19001921
}
19011922

1923+
/// Return a reference to the map of window functions registered in this state
1924+
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
1925+
&self.window_functions
1926+
}
1927+
19021928
/// Return [SerializerRegistry] for extensions
19031929
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
19041930
self.serializer_registry.clone()
@@ -1932,6 +1958,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
19321958
self.state.aggregate_functions().get(name).cloned()
19331959
}
19341960

1961+
/// Look up a window UDF by exact name in the session state; returns `None` if no such function is registered.
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
1962+
self.state.window_functions().get(name).cloned()
1963+
}
1964+
19351965
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
19361966
if variable_names.is_empty() {
19371967
return None;
@@ -1979,6 +2009,16 @@ impl FunctionRegistry for SessionState {
19792009
))
19802010
})
19812011
}
2012+
2013+
/// Look up a window UDF by exact name, returning a `DataFusionError::Plan`
/// that names the missing function if none is registered.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
2014+
let result = self.window_functions.get(name);
2015+
2016+
result.cloned().ok_or_else(|| {
2017+
DataFusionError::Plan(format!(
2018+
"There is no UDWF named \"{name}\" in the registry"
2019+
))
2020+
})
2021+
}
19822022
}
19832023

19842024
impl OptimizerConfig for SessionState {
@@ -2012,6 +2052,7 @@ impl From<&SessionState> for TaskContext {
20122052
state.config.clone(),
20132053
state.scalar_functions.clone(),
20142054
state.aggregate_functions.clone(),
2055+
state.window_functions.clone(),
20152056
state.runtime_env.clone(),
20162057
)
20172058
}

0 commit comments

Comments
 (0)