Skip to content

Commit edf0afc

Browse files
committed
Move window_frame_state and partition_evaluator to datafusion_expr
1 parent 53064e1 commit edf0afc

File tree

6 files changed

+110
-92
lines changed

6 files changed

+110
-92
lines changed

datafusion/expr/src/function.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
2020
use crate::function_err::generate_signature_error_msg;
2121
use crate::nullif::SUPPORTED_NULLIF_TYPES;
22+
use crate::partition_evaluator::PartitionEvaluator;
2223
use crate::type_coercion::functions::data_types;
2324
use crate::ColumnarValue;
2425
use crate::{
@@ -54,6 +55,12 @@ pub type AccumulatorFunctionImplementation =
5455
pub type StateTypeFunction =
5556
Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
5657

58+
/// Factory that creates a [`PartitionEvaluator`] for a window function, given
59+
/// its return datatype.
60+
pub type PartitionEvaluatorFunctionFactory =
61+
Arc<dyn Fn(&DataType) -> Result<Box<dyn PartitionEvaluator>> + Send + Sync>;
62+
63+
5764
macro_rules! make_utf8_to_return_type {
5865
($FUNC:ident, $largeUtf8Type:expr, $utf8Type:expr) => {
5966
fn $FUNC(arg_type: &DataType, name: &str) -> Result<DataType> {

datafusion/expr/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ mod udwf;
5353
pub mod utils;
5454
pub mod window_frame;
5555
pub mod window_function;
56+
pub mod partition_evaluator;
57+
pub mod window_frame_state;
5658

5759
pub use accumulator::Accumulator;
5860
pub use aggregate_function::AggregateFunction;

datafusion/physical-expr/src/window/partition_evaluator.rs renamed to datafusion/expr/src/partition_evaluator.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,25 @@
1717

1818
//! Partition evaluation module
1919
20-
use crate::window::window_expr::BuiltinWindowState;
21-
use crate::window::WindowAggState;
20+
use crate::window_frame_state::WindowAggState;
2221
use arrow::array::ArrayRef;
2322
use datafusion_common::Result;
2423
use datafusion_common::{DataFusionError, ScalarValue};
24+
use std::any::Any;
2525
use std::fmt::Debug;
2626
use std::ops::Range;
2727

28+
29+
/// Trait for the state managed by this partition evaluator
30+
///
31+
/// This follows the existing `as_any` downcasting pattern; a more strongly typed design may be possible (TODO).
32+
33+
pub trait PartitionState {
34+
/// Returns this partition state as [`Any`](std::any::Any) so that it can be
35+
/// downcast to a specific implementation.
36+
fn as_any(&self) -> &dyn Any;
37+
}
38+
2839
/// Partition evaluator for Window Functions
2940
///
3041
/// # Background
@@ -100,12 +111,9 @@ pub trait PartitionEvaluator: Debug + Send {
100111
false
101112
}
102113

103-
/// Returns the internal state of the window function
104-
///
105-
/// Only used for stateful evaluation
106-
fn state(&self) -> Result<BuiltinWindowState> {
107-
// If we do not use state we just return Default
108-
Ok(BuiltinWindowState::Default)
114+
/// Returns the internal state of the window function, if any
115+
fn state(&self) -> Result<Option<Box<dyn PartitionState>>> {
116+
Ok(None)
109117
}
110118

111119
/// Updates the internal state for window function
@@ -130,7 +138,7 @@ pub trait PartitionEvaluator: Debug + Send {
130138
/// Sets the internal state for window function
131139
///
132140
/// Only used for stateful evaluation
133-
fn set_state(&mut self, _state: &BuiltinWindowState) -> Result<()> {
141+
fn set_state(&mut self, state: Box<dyn PartitionState>) -> Result<()> {
134142
Err(DataFusionError::NotImplemented(
135143
"set_state is not implemented for this window function".to_string(),
136144
))

datafusion/physical-expr/src/window/window_frame_state.rs renamed to datafusion/expr/src/window_frame_state.rs

Lines changed: 84 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,16 +19,97 @@
1919
//! depending on the window frame mode: RANGE, ROWS, GROUPS.
2020
2121
use arrow::array::ArrayRef;
22+
use arrow::compute::{concat};
2223
use arrow::compute::kernels::sort::SortOptions;
24+
use arrow::record_batch::RecordBatch;
2325
use datafusion_common::utils::{compare_rows, get_row_at_idx, search_in_slice};
2426
use datafusion_common::{DataFusionError, Result, ScalarValue};
25-
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
27+
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
2628
use std::cmp::min;
2729
use std::collections::VecDeque;
2830
use std::fmt::Debug;
2931
use std::ops::Range;
3032
use std::sync::Arc;
3133

34+
35+
/// State for each unique partition determined according to PARTITION BY column(s)
36+
#[derive(Debug)]
37+
pub struct PartitionBatchState {
38+
/// The record_batch belonging to current partition
39+
pub record_batch: RecordBatch,
40+
/// Flag indicating whether we have received all data for this partition
41+
pub is_end: bool,
42+
/// Number of rows emitted for each partition
43+
pub n_out_row: usize,
44+
}
45+
46+
47+
#[derive(Debug)]
48+
pub struct WindowAggState {
49+
/// The range of rows over which the window function is calculated
50+
pub window_frame_range: Range<usize>,
51+
pub window_frame_ctx: Option<WindowFrameContext>,
52+
/// The index of the last row whose result has been calculated, within the partition record batch buffer.
53+
pub last_calculated_index: usize,
54+
/// The total number of rows pruned (deleted) so far, kept as an offset
55+
pub offset_pruned_rows: usize,
56+
/// Stores the results calculated by window frame
57+
pub out_col: ArrayRef,
58+
/// Keeps track of how many rows should be generated to be in sync with input record_batch.
59+
/// (For each row in the input record batch we need to generate a window result).
60+
pub n_row_result_missing: usize,
61+
/// Flag indicating whether we have received all data for this partition
62+
pub is_end: bool,
63+
}
64+
65+
impl WindowAggState {
66+
pub fn prune_state(&mut self, n_prune: usize) {
67+
self.window_frame_range = Range {
68+
start: self.window_frame_range.start - n_prune,
69+
end: self.window_frame_range.end - n_prune,
70+
};
71+
self.last_calculated_index -= n_prune;
72+
self.offset_pruned_rows += n_prune;
73+
74+
match self.window_frame_ctx.as_mut() {
75+
// Rows have no state, so there is nothing to do
76+
Some(WindowFrameContext::Rows(_)) => {}
77+
Some(WindowFrameContext::Range { .. }) => {}
78+
Some(WindowFrameContext::Groups { state, .. }) => {
79+
let mut n_group_to_del = 0;
80+
for (_, end_idx) in &state.group_end_indices {
81+
if n_prune < *end_idx {
82+
break;
83+
}
84+
n_group_to_del += 1;
85+
}
86+
state.group_end_indices.drain(0..n_group_to_del);
87+
state
88+
.group_end_indices
89+
.iter_mut()
90+
.for_each(|(_, start_idx)| *start_idx -= n_prune);
91+
state.current_group_idx -= n_group_to_del;
92+
}
93+
None => {}
94+
};
95+
}
96+
}
97+
98+
impl WindowAggState {
99+
pub fn update(
100+
&mut self,
101+
out_col: &ArrayRef,
102+
partition_batch_state: &PartitionBatchState,
103+
) -> Result<()> {
104+
self.last_calculated_index += out_col.len();
105+
self.out_col = concat(&[&self.out_col, &out_col])?;
106+
self.n_row_result_missing =
107+
partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
108+
self.is_end = partition_batch_state.is_end;
109+
Ok(())
110+
}
111+
}
112+
32113
/// This object stores the window frame state for use in incremental calculations.
33114
#[derive(Debug)]
34115
pub enum WindowFrameContext {
@@ -547,11 +628,10 @@ fn check_equality(current: &[ScalarValue], target: &[ScalarValue]) -> Result<boo
547628

548629
#[cfg(test)]
549630
mod tests {
550-
use crate::window::window_frame_state::WindowFrameStateGroups;
631+
use super::*;
551632
use arrow::array::{ArrayRef, Float64Array};
552-
use arrow_schema::SortOptions;
553633
use datafusion_common::{Result, ScalarValue};
554-
use datafusion_expr::{WindowFrame, WindowFrameBound, WindowFrameUnits};
634+
use crate::{WindowFrame, WindowFrameBound, WindowFrameUnits};
555635
use std::ops::Range;
556636
use std::sync::Arc;
557637

datafusion/physical-expr/src/window/mod.rs

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@ pub(crate) mod cume_dist;
2222
pub(crate) mod lead_lag;
2323
pub(crate) mod nth_value;
2424
pub(crate) mod ntile;
25-
pub(crate) mod partition_evaluator;
2625
pub(crate) mod rank;
2726
pub(crate) mod row_number;
2827
mod sliding_aggregate;
2928
mod window_expr;
30-
mod window_frame_state;
3129

3230
pub use aggregate::PlainAggregateWindowExpr;
3331
pub use built_in::BuiltInWindowExpr;

datafusion/physical-expr/src/window/window_expr.rs

Lines changed: 0 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -337,83 +337,6 @@ pub enum BuiltinWindowState {
337337
Default,
338338
}
339339

340-
#[derive(Debug)]
341-
pub struct WindowAggState {
342-
/// The range that we calculate the window function
343-
pub window_frame_range: Range<usize>,
344-
pub window_frame_ctx: Option<WindowFrameContext>,
345-
/// The index of the last row that its result is calculated inside the partition record batch buffer.
346-
pub last_calculated_index: usize,
347-
/// The offset of the deleted row number
348-
pub offset_pruned_rows: usize,
349-
/// Stores the results calculated by window frame
350-
pub out_col: ArrayRef,
351-
/// Keeps track of how many rows should be generated to be in sync with input record_batch.
352-
// (For each row in the input record batch we need to generate a window result).
353-
pub n_row_result_missing: usize,
354-
/// flag indicating whether we have received all data for this partition
355-
pub is_end: bool,
356-
}
357-
358-
impl WindowAggState {
359-
pub fn prune_state(&mut self, n_prune: usize) {
360-
self.window_frame_range = Range {
361-
start: self.window_frame_range.start - n_prune,
362-
end: self.window_frame_range.end - n_prune,
363-
};
364-
self.last_calculated_index -= n_prune;
365-
self.offset_pruned_rows += n_prune;
366-
367-
match self.window_frame_ctx.as_mut() {
368-
// Rows have no state do nothing
369-
Some(WindowFrameContext::Rows(_)) => {}
370-
Some(WindowFrameContext::Range { .. }) => {}
371-
Some(WindowFrameContext::Groups { state, .. }) => {
372-
let mut n_group_to_del = 0;
373-
for (_, end_idx) in &state.group_end_indices {
374-
if n_prune < *end_idx {
375-
break;
376-
}
377-
n_group_to_del += 1;
378-
}
379-
state.group_end_indices.drain(0..n_group_to_del);
380-
state
381-
.group_end_indices
382-
.iter_mut()
383-
.for_each(|(_, start_idx)| *start_idx -= n_prune);
384-
state.current_group_idx -= n_group_to_del;
385-
}
386-
None => {}
387-
};
388-
}
389-
}
390-
391-
impl WindowAggState {
392-
pub fn update(
393-
&mut self,
394-
out_col: &ArrayRef,
395-
partition_batch_state: &PartitionBatchState,
396-
) -> Result<()> {
397-
self.last_calculated_index += out_col.len();
398-
self.out_col = concat(&[&self.out_col, &out_col])?;
399-
self.n_row_result_missing =
400-
partition_batch_state.record_batch.num_rows() - self.last_calculated_index;
401-
self.is_end = partition_batch_state.is_end;
402-
Ok(())
403-
}
404-
}
405-
406-
/// State for each unique partition determined according to PARTITION BY column(s)
407-
#[derive(Debug)]
408-
pub struct PartitionBatchState {
409-
/// The record_batch belonging to current partition
410-
pub record_batch: RecordBatch,
411-
/// Flag indicating whether we have received all data for this partition
412-
pub is_end: bool,
413-
/// Number of rows emitted for each partition
414-
pub n_out_row: usize,
415-
}
416-
417340
/// Key for IndexMap for each unique partition
418341
///
419342
/// For instance, if window frame is `OVER(PARTITION BY a,b)`,

0 commit comments

Comments
 (0)