Skip to content

Commit 1bc2d6e

Browse files
committed
Updates and get example compiling
1 parent edf0afc commit 1bc2d6e

File tree

24 files changed

+330
-83
lines changed

24 files changed

+330
-83
lines changed

datafusion-examples/examples/simple_udwf.rs

Lines changed: 144 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,21 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18+
use std::sync::Arc;
19+
20+
use arrow::{
21+
array::{AsArray, Float64Array},
22+
datatypes::Float64Type,
23+
};
24+
use arrow_schema::DataType;
1825
use datafusion::datasource::file_format::options::CsvReadOptions;
1926

2027
use datafusion::error::Result;
2128
use datafusion::prelude::*;
29+
use datafusion_common::DataFusionError;
30+
use datafusion_expr::{
31+
partition_evaluator::PartitionEvaluator, Signature, Volatility, WindowUDF,
32+
};
2233

2334
// create local execution context with `cars.csv` registered as a table named `cars`
2435
async fn create_context() -> Result<SessionContext> {
@@ -39,6 +50,9 @@ async fn create_context() -> Result<SessionContext> {
3950
async fn main() -> Result<()> {
4051
let ctx = create_context().await?;
4152

53+
// register the window function with DataFusion so we can call it
54+
ctx.register_udwf(my_average());
55+
4256
// Use SQL to run the new window function
4357
let df = ctx.sql("SELECT * from cars").await?;
4458
// print the results
@@ -52,23 +66,145 @@ async fn main() -> Result<()> {
5266
"SELECT car, \
5367
speed, \
5468
lag(speed, 1) OVER (PARTITION BY car ORDER BY time),\
69+
my_average(speed) OVER (PARTITION BY car ORDER BY time),\
5570
time \
5671
from cars",
5772
)
5873
.await?;
5974
// print the results
6075
df.show().await?;
6176

62-
// ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window functon so that each invocation only sees 5 rows: the 2 before and 2 after) using
63-
let df = ctx.sql("SELECT car, \
64-
speed, \
65-
lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
66-
time \
67-
from cars").await?;
68-
// print the results
69-
df.show().await?;
77+
// // ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING: Run the window function so that each invocation only sees 5 rows (the 2 before, the current row, and the 2 after)
78+
// let df = ctx.sql("SELECT car, \
79+
// speed, \
80+
// lag(speed, 1) OVER (PARTITION BY car ORDER BY time ROWS BETWEEN 2 PRECEDING AND 2 FOLLOWING),\
81+
// time \
82+
// from cars").await?;
83+
// // print the results
84+
// df.show().await?;
7085

7186
// todo show how to run dataframe API as well
7287

7388
Ok(())
7489
}
90+
91+
// TODO make a helper function like `create_udf` that helps to make these signatures
92+
93+
fn my_average() -> WindowUDF {
94+
WindowUDF {
95+
name: String::from("my_average"),
96+
// it will take 2 arguments -- the column and the window size
97+
signature: Signature::exact(vec![DataType::Int32], Volatility::Immutable),
98+
return_type: Arc::new(return_type),
99+
partition_evaluator: Arc::new(make_partition_evaluator),
100+
}
101+
}
102+
103+
/// Compute the return type of the function given the argument types
104+
fn return_type(arg_types: &[DataType]) -> Result<Arc<DataType>> {
105+
if arg_types.len() != 1 {
106+
return Err(DataFusionError::Plan(format!(
107+
"my_udwf expects 1 argument, got {}: {:?}",
108+
arg_types.len(),
109+
arg_types
110+
)));
111+
}
112+
Ok(Arc::new(arg_types[0].clone()))
113+
}
114+
115+
/// Create a partition evaluator for this argument
116+
fn make_partition_evaluator() -> Result<Box<dyn PartitionEvaluator>> {
117+
Ok(Box::new(MyPartitionEvaluator::new()))
118+
}
119+
120+
/// The lowest-level evaluator for the example window function.
///
/// It computes the value of the window function for each distinct value of
/// `PARTITION BY` (each car type in our example). It carries no state.
#[derive(Clone, Debug)]
struct MyPartitionEvaluator {}

impl MyPartitionEvaluator {
    /// Creates a new, stateless evaluator.
    fn new() -> Self {
        MyPartitionEvaluator {}
    }
}
132+
133+
/// These different evaluation methods are called depending on the various settings of WindowUDF
134+
impl PartitionEvaluator for MyPartitionEvaluator {
135+
fn get_range(&self, _idx: usize, _n_rows: usize) -> Result<std::ops::Range<usize>> {
136+
Err(DataFusionError::NotImplemented(
137+
"get_range is not implemented for this window function".to_string(),
138+
))
139+
}
140+
141+
/// This function is given the values of each partition
142+
fn evaluate(
143+
&self,
144+
values: &[arrow::array::ArrayRef],
145+
_num_rows: usize,
146+
) -> Result<arrow::array::ArrayRef> {
147+
// datafusion has handled ensuring we get the correct input argument
148+
assert_eq!(values.len(), 1);
149+
150+
// For this example, we convert convert the input argument to an
151+
// array of floating point numbers to calculate a moving average
152+
let arr: &Float64Array = values[0].as_ref().as_primitive::<Float64Type>();
153+
154+
// implement a simple moving average by averaging the current
155+
// value with the previous value
156+
//
157+
// value | avg
158+
// ------+------
159+
// 10 | 10
160+
// 20 | 15
161+
// 30 | 25
162+
// 30 | 30
163+
//
164+
let mut previous_value = None;
165+
let new_values: Float64Array = arr
166+
.values()
167+
.iter()
168+
.map(|&value| {
169+
let new_value = previous_value
170+
.map(|previous_value| (value + previous_value) / 2.0)
171+
.unwrap_or(value);
172+
previous_value = Some(value);
173+
new_value
174+
})
175+
.collect();
176+
177+
Ok(Arc::new(new_values))
178+
}
179+
180+
fn evaluate_stateful(
181+
&mut self,
182+
_values: &[arrow::array::ArrayRef],
183+
) -> Result<datafusion_common::ScalarValue> {
184+
Err(DataFusionError::NotImplemented(
185+
"evaluate_stateful is not implemented by default".into(),
186+
))
187+
}
188+
189+
fn evaluate_with_rank(
190+
&self,
191+
_num_rows: usize,
192+
_ranks_in_partition: &[std::ops::Range<usize>],
193+
) -> Result<arrow::array::ArrayRef> {
194+
Err(DataFusionError::NotImplemented(
195+
"evaluate_partition_with_rank is not implemented by default".into(),
196+
))
197+
}
198+
199+
fn evaluate_inside_range(
200+
&self,
201+
_values: &[arrow::array::ArrayRef],
202+
_range: &std::ops::Range<usize>,
203+
) -> Result<datafusion_common::ScalarValue> {
204+
Err(DataFusionError::NotImplemented(
205+
"evaluate_inside_range is not implemented by default".into(),
206+
))
207+
}
208+
}
209+
210+
// TODO show how to use other evaluate methods

datafusion/core/src/execution/context.rs

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ use crate::{
3434
use datafusion_execution::registry::SerializerRegistry;
3535
use datafusion_expr::{
3636
logical_plan::{DdlStatement, Statement},
37-
DescribeTable, StringifiedPlan, UserDefinedLogicalNode,
37+
DescribeTable, StringifiedPlan, UserDefinedLogicalNode, WindowUDF,
3838
};
3939
pub use datafusion_physical_expr::execution_props::ExecutionProps;
4040
use datafusion_physical_expr::var_provider::is_system_variables;
@@ -797,6 +797,20 @@ impl SessionContext {
797797
.insert(f.name.clone(), Arc::new(f));
798798
}
799799

800+
/// Registers an window UDF within this context.
801+
///
802+
/// Note in SQL queries, window function names are looked up using
803+
/// lowercase unless the query uses quotes. For example,
804+
///
805+
/// - `SELECT MY_UDAF(x)...` will look for an aggregate named `"my_udaf"`
806+
/// - `SELECT "my_UDAF"(x)` will look for an aggregate named `"my_UDAF"`
807+
pub fn register_udwf(&self, f: WindowUDF) {
808+
self.state
809+
.write()
810+
.window_functions
811+
.insert(f.name.clone(), Arc::new(f));
812+
}
813+
800814
/// Creates a [`DataFrame`] for reading a data source.
801815
///
802816
/// For more control such as reading multiple files, you can use
@@ -1290,6 +1304,10 @@ impl FunctionRegistry for SessionContext {
12901304
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>> {
12911305
self.state.read().udaf(name)
12921306
}
1307+
1308+
/// Looks up the window UDF named `name` by delegating to the session state.
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
    let state = self.state.read();
    state.udwf(name)
}
12931311
}
12941312

12951313
/// A planner used to add extensions to DataFusion logical and physical plans.
@@ -1340,6 +1358,8 @@ pub struct SessionState {
13401358
scalar_functions: HashMap<String, Arc<ScalarUDF>>,
13411359
/// Aggregate functions registered in the context
13421360
aggregate_functions: HashMap<String, Arc<AggregateUDF>>,
1361+
/// Window functions registered in the context
1362+
window_functions: HashMap<String, Arc<WindowUDF>>,
13431363
/// Deserializer registry for extensions.
13441364
serializer_registry: Arc<dyn SerializerRegistry>,
13451365
/// Session configuration
@@ -1483,6 +1503,7 @@ impl SessionState {
14831503
catalog_list,
14841504
scalar_functions: HashMap::new(),
14851505
aggregate_functions: HashMap::new(),
1506+
window_functions: HashMap::new(),
14861507
serializer_registry: Arc::new(EmptySerializerRegistry),
14871508
config,
14881509
execution_props: ExecutionProps::new(),
@@ -1959,6 +1980,11 @@ impl SessionState {
19591980
&self.aggregate_functions
19601981
}
19611982

1983+
/// Return a reference to the window functions held by this state, keyed by function name
1984+
pub fn window_functions(&self) -> &HashMap<String, Arc<WindowUDF>> {
1985+
&self.window_functions
1986+
}
1987+
19621988
/// Return [SerializerRegistry] for extensions
19631989
pub fn serializer_registry(&self) -> Arc<dyn SerializerRegistry> {
19641990
self.serializer_registry.clone()
@@ -1992,6 +2018,10 @@ impl<'a> ContextProvider for SessionContextProvider<'a> {
19922018
self.state.aggregate_functions().get(name).cloned()
19932019
}
19942020

2021+
/// Returns the window UDF registered under `name`, or `None` if there is none.
fn get_window_meta(&self, name: &str) -> Option<Arc<WindowUDF>> {
    let registered = self.state.window_functions();
    registered.get(name).map(Arc::clone)
}
2024+
19952025
fn get_variable_type(&self, variable_names: &[String]) -> Option<DataType> {
19962026
if variable_names.is_empty() {
19972027
return None;
@@ -2039,6 +2069,16 @@ impl FunctionRegistry for SessionState {
20392069
))
20402070
})
20412071
}
2072+
2073+
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>> {
2074+
let result = self.window_functions.get(name);
2075+
2076+
result.cloned().ok_or_else(|| {
2077+
DataFusionError::Plan(format!(
2078+
"There is no UDWF named \"{name}\" in the registry"
2079+
))
2080+
})
2081+
}
20422082
}
20432083

20442084
impl OptimizerConfig for SessionState {
@@ -2068,6 +2108,7 @@ impl From<&SessionState> for TaskContext {
20682108
state.config.clone(),
20692109
state.scalar_functions.clone(),
20702110
state.aggregate_functions.clone(),
2111+
state.window_functions.clone(),
20712112
state.runtime_env.clone(),
20722113
)
20732114
}

datafusion/core/src/physical_plan/windows/mod.rs

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,11 @@ fn create_udwf_window_expr(
198198
name: String,
199199
) -> Result<Arc<dyn BuiltInWindowFunctionExpr>> {
200200
// need to get the types into an owned vec for some reason
201-
let input_types: Vec<_> = input_schema.fields().iter().map(|f| f.data_type().clone()).collect();
201+
let input_types: Vec<_> = args
202+
.iter()
203+
.map(|arg| arg.data_type(input_schema).map(|dt| dt.clone()))
204+
.collect::<Result<_>>()?;
205+
202206
// figure out the output type
203207
let data_type = (fun.return_type)(&input_types)?;
204208
Ok(Arc::new(WindowUDFExpr {
@@ -227,15 +231,23 @@ impl BuiltInWindowFunctionExpr for WindowUDFExpr {
227231

228232
fn field(&self) -> Result<Field> {
229233
let nullable = false;
230-
Ok(Field::new(&self.name, self.data_type.as_ref().clone(), nullable))
234+
Ok(Field::new(
235+
&self.name,
236+
self.data_type.as_ref().clone(),
237+
nullable,
238+
))
231239
}
232240

233241
fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
234242
self.args.clone()
235243
}
236244

237245
fn create_evaluator(&self) -> Result<Box<dyn PartitionEvaluator>> {
238-
todo!()
246+
(self.fun.partition_evaluator)()
247+
}
248+
249+
fn name(&self) -> &str {
250+
&self.name
239251
}
240252
}
241253

datafusion/core/tests/data/cars.csv

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,3 @@ green,15.1,1996-04-12T12:05:11.000000000
2424
green,15.2,1996-04-12T12:05:12.000000000
2525
green,8.0,1996-04-12T12:05:13.000000000
2626
green,2.0,1996-04-12T12:05:14.000000000
27-
green,0.0,1996-04-12T12:05:15.000000000

datafusion/execution/src/registry.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
//! FunctionRegistry trait
1919
2020
use datafusion_common::Result;
21-
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode};
21+
use datafusion_expr::{AggregateUDF, ScalarUDF, UserDefinedLogicalNode, WindowUDF};
2222
use std::{collections::HashSet, sync::Arc};
2323

2424
/// A registry knows how to build logical expressions out of user-defined functions' names
@@ -31,6 +31,9 @@ pub trait FunctionRegistry {
3131

3232
/// Returns a reference to the udaf named `name`.
3333
fn udaf(&self, name: &str) -> Result<Arc<AggregateUDF>>;
34+
35+
/// Returns a reference to the udwf named `name`.
36+
fn udwf(&self, name: &str) -> Result<Arc<WindowUDF>>;
3437
}
3538

3639
/// Serializer and deserializer registry for extensions like [UserDefinedLogicalNode].

0 commit comments

Comments
 (0)