Skip to content

Commit ca1daa2

Browse files
committed
feat: implement invoke_with_args for struct and named_struct
By implementing `invoke_with_args`, the fields derived in `return_type` / `return_type_from_args` can be reused (which improves performance), and the duplicate field-derivation logic can be removed.
1 parent 774d3cb commit ca1daa2

File tree

4 files changed

+110
-118
lines changed

4 files changed

+110
-118
lines changed

datafusion/core/Cargo.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,10 @@ name = "math_query_sql"
182182
harness = false
183183
name = "filter_query_sql"
184184

185+
[[bench]]
186+
harness = false
187+
name = "struct_query_sql"
188+
185189
[[bench]]
186190
harness = false
187191
name = "window_query_sql"
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::{
19+
array::{Float32Array, Float64Array},
20+
datatypes::{DataType, Field, Schema},
21+
record_batch::RecordBatch,
22+
};
23+
use criterion::{criterion_group, criterion_main, Criterion};
24+
use datafusion::prelude::SessionContext;
25+
use datafusion::{datasource::MemTable, error::Result};
26+
use futures::executor::block_on;
27+
use std::sync::Arc;
28+
use tokio::runtime::Runtime;
29+
30+
/// Plans `sql` against `ctx`, executes it, and collects the resulting batches.
///
/// # Panics
///
/// Panics (through `unwrap`) if the Tokio runtime cannot be created, or if
/// planning or executing the query fails.
///
/// NOTE(review): although declared `async`, the body never awaits. Both
/// futures are driven synchronously via `rt.block_on`, and a fresh `Runtime`
/// is constructed on every call, so runtime start-up cost is part of whatever
/// the caller measures.
async fn query(ctx: &SessionContext, sql: &str) {
31+
let rt = Runtime::new().unwrap();
32+
33+
// plan the query, then execute it and collect the output
34+
let df = rt.block_on(ctx.sql(sql)).unwrap();
35+
// `black_box` keeps the collected result from being optimized away.
criterion::black_box(rt.block_on(df.collect()).unwrap());
36+
}
37+
38+
/// Creates a `SessionContext` with an in-memory table registered as `t`.
///
/// The table has two columns declared with `nullable = false`: `f32`
/// (`Float32`) and `f64` (`Float64`). It is filled with
/// `array_len / batch_size` record batches of `batch_size` rows each; every
/// row of batch `i` holds the value `i` in both columns.
///
/// Because the batch count uses integer division, any remainder of
/// `array_len` is dropped, and a `batch_size` of zero panics with a division
/// by zero.
///
/// # Errors
///
/// Returns the error from `MemTable::try_new` or `register_table`.
///
/// # Panics
///
/// Panics if a `RecordBatch` cannot be built from the generated arrays.
fn create_context(array_len: usize, batch_size: usize) -> Result<SessionContext> {
39+
// define a schema.
40+
let schema = Arc::new(Schema::new(vec![
41+
Field::new("f32", DataType::Float32, false),
42+
Field::new("f64", DataType::Float64, false),
43+
]));
44+
45+
// define data: one batch per index `i`, each column filled with `i`.
46+
let batches = (0..array_len / batch_size)
47+
.map(|i| {
48+
RecordBatch::try_new(
49+
schema.clone(),
50+
vec![
51+
Arc::new(Float32Array::from(vec![i as f32; batch_size])),
52+
Arc::new(Float64Array::from(vec![i as f64; batch_size])),
53+
],
54+
)
55+
.unwrap()
56+
})
57+
.collect::<Vec<_>>();
58+
59+
let ctx = SessionContext::new();
60+
61+
// declare a table in memory. In the Spark API, this corresponds to createDataFrame(...).
// All batches are passed as the only element of the outer vector.
62+
let provider = MemTable::try_new(schema, vec![batches])?;
63+
ctx.register_table("t", Arc::new(provider))?;
64+
65+
Ok(ctx)
66+
}
67+
68+
/// Registers the `struct` benchmark with Criterion.
///
/// The context (524,288 rows split into 4,096-row batches) is built once,
/// outside the timed closure. Each iteration runs
/// `select struct(f32, f64) from t` through `query`.
fn criterion_benchmark(c: &mut Criterion) {
69+
let array_len = 524_288; // 2^19
70+
let batch_size = 4096; // 2^12
71+
72+
c.bench_function("struct", |b| {
73+
// Table setup happens here so that it is not part of the timed loop.
let ctx = create_context(array_len, batch_size).unwrap();
74+
// `query` is an `async fn`, so its future is driven to completion with `block_on`.
b.iter(|| block_on(query(&ctx, "select struct(f32, f64) from t")))
75+
});
76+
}
77+
78+
// Group `criterion_benchmark` under the name `benches` and hand that group to
// Criterion's entry-point macro (the Cargo bench target sets `harness = false`).
criterion_group!(benches, criterion_benchmark);
79+
criterion_main!(benches);

datafusion/functions/src/core/named_struct.rs

Lines changed: 8 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -17,79 +17,13 @@
1717

1818
use arrow::array::StructArray;
1919
use arrow::datatypes::{DataType, Field, Fields};
20-
use datafusion_common::{exec_err, internal_err, HashSet, Result, ScalarValue};
21-
use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs};
20+
use datafusion_common::{exec_err, internal_err, Result};
21+
use datafusion_expr::{ColumnarValue, Documentation, ReturnInfo, ReturnTypeArgs, ScalarFunctionArgs};
2222
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2323
use datafusion_macros::user_doc;
2424
use std::any::Any;
2525
use std::sync::Arc;
2626

27-
/// Put values in a struct array.
///
/// `args` is read as consecutive `(name, value)` pairs: the argument at each
/// even position is a field name and the one after it is that field's value.
/// Every resulting field is nullable and takes its type from the value's
/// data type.
///
/// # Errors
///
/// Returns an execution error when:
/// - `args` is empty,
/// - `args` has an odd number of elements,
/// - a name argument is not a non-null `Utf8` scalar, or
/// - a field name appears more than once.
///
/// Errors from `ColumnarValue::values_to_arrays` are propagated.
28-
fn named_struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
29-
// Do not accept 0 arguments.
30-
if args.is_empty() {
31-
return exec_err!(
32-
"named_struct requires at least one pair of arguments, got 0 instead"
33-
);
34-
}
35-
36-
if args.len() % 2 != 0 {
37-
return exec_err!(
38-
"named_struct requires an even number of arguments, got {} instead",
39-
args.len()
40-
);
41-
}
42-
43-
// Split the flat argument list into parallel vectors of names and values.
let (names, values): (Vec<_>, Vec<_>) = args
44-
.chunks_exact(2)
45-
.enumerate()
46-
.map(|(i, chunk)| {
47-
let name_column = &chunk[0];
48-
let name = match name_column {
49-
ColumnarValue::Scalar(ScalarValue::Utf8(Some(name_scalar))) => {
50-
name_scalar
51-
}
52-
// TODO: Implement Display for ColumnarValue
53-
_ => {
54-
return exec_err!(
55-
"named_struct even arguments must be string literals at position {}",
56-
i * 2
57-
)
58-
}
59-
};
60-
61-
Ok((name, chunk[1].clone()))
62-
})
63-
.collect::<Result<Vec<_>>>()?
64-
.into_iter()
65-
.unzip();
66-
67-
{
68-
// Check to enforce the uniqueness of struct field names
69-
let mut unique_field_names = HashSet::new();
70-
for name in names.iter() {
71-
if unique_field_names.contains(name) {
72-
return exec_err!(
73-
"named_struct requires unique field names. Field {name} is used more than once."
74-
);
75-
}
76-
unique_field_names.insert(name);
77-
}
78-
}
79-
80-
// One nullable field per pair, typed after the value it describes.
let fields: Fields = names
81-
.into_iter()
82-
.zip(&values)
83-
.map(|(name, value)| Arc::new(Field::new(name, value.data_type().clone(), true)))
84-
.collect::<Vec<_>>()
85-
.into();
86-
87-
let arrays = ColumnarValue::values_to_arrays(&values)?;
88-
89-
let struct_array = StructArray::new(fields, arrays, None);
90-
Ok(ColumnarValue::Array(Arc::new(struct_array)))
91-
}
92-
9327
#[user_doc(
9428
doc_section(label = "Struct Functions"),
9529
description = "Returns an Arrow struct using the specified name and input expressions pairs.",
@@ -203,12 +137,12 @@ impl ScalarUDFImpl for NamedStructFunc {
203137
))))
204138
}
205139

206-
fn invoke_batch(
207-
&self,
208-
args: &[ColumnarValue],
209-
_number_rows: usize,
210-
) -> Result<ColumnarValue> {
211-
named_struct_expr(args)
140+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
141+
let DataType::Struct(fields) = args.return_type else {
142+
return internal_err!("incorrect named_struct return type");
143+
};
144+
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
145+
Ok(ColumnarValue::Array(Arc::new(StructArray::new(fields.clone(), arrays, None))))
212146
}
213147

214148
fn documentation(&self) -> Option<&Documentation> {

datafusion/functions/src/core/struct.rs

Lines changed: 19 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -15,46 +15,15 @@
1515
// specific language governing permissions and limitations
1616
// under the License.
1717

18-
use arrow::array::{ArrayRef, StructArray};
19-
use arrow::datatypes::{DataType, Field, Fields};
20-
use datafusion_common::{exec_err, Result};
21-
use datafusion_expr::{ColumnarValue, Documentation};
18+
use arrow::array::StructArray;
19+
use arrow::datatypes::{DataType, Field};
20+
use datafusion_common::{exec_err, internal_err, Result};
21+
use datafusion_expr::{ColumnarValue, Documentation, ScalarFunctionArgs};
2222
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
2323
use datafusion_macros::user_doc;
2424
use std::any::Any;
2525
use std::sync::Arc;
2626

27-
/// Combines `args` into a single struct array.
///
/// The i-th input array becomes a nullable field named `c{i}` whose type is
/// that array's data type.
///
/// # Errors
///
/// Returns an execution error if `args` is empty.
fn array_struct(args: &[ArrayRef]) -> Result<ArrayRef> {
28-
// do not accept 0 arguments.
29-
if args.is_empty() {
30-
return exec_err!("struct requires at least one argument");
31-
}
32-
33-
// Name the fields positionally: c0, c1, ...
let fields = args
34-
.iter()
35-
.enumerate()
36-
.map(|(i, arg)| {
37-
let field_name = format!("c{i}");
38-
Ok(Arc::new(Field::new(
39-
field_name.as_str(),
40-
arg.data_type().clone(),
41-
true,
42-
)))
43-
})
44-
.collect::<Result<Vec<_>>>()?
45-
.into();
46-
47-
let arrays = args.to_vec();
48-
49-
Ok(Arc::new(StructArray::new(fields, arrays, None)))
50-
}
51-
52-
/// Put values in a struct array.
///
/// Converts `args` to arrays with `ColumnarValue::values_to_arrays`, then
/// delegates to `array_struct`; errors from either step are propagated.
53-
fn struct_expr(args: &[ColumnarValue]) -> Result<ColumnarValue> {
54-
let arrays = ColumnarValue::values_to_arrays(args)?;
55-
Ok(ColumnarValue::Array(array_struct(arrays.as_slice())?))
56-
}
57-
5827
#[user_doc(
5928
doc_section(label = "Struct Functions"),
6029
description = "Returns an Arrow struct using the specified input expressions optionally named.
@@ -133,20 +102,26 @@ impl ScalarUDFImpl for StructFunc {
133102
}
134103

135104
/// Derives the struct return type from the argument types.
///
/// The argument at position `pos` becomes a nullable field named `c{pos}`
/// with that argument's data type. The lines added in this hunk also return
/// an execution error when `arg_types` is empty.
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
136-
let return_fields = arg_types
105+
if arg_types.is_empty() {
106+
return exec_err!("struct requires at least one argument, got 0 instead");
107+
}
108+
109+
let fields = arg_types
137110
.iter()
138111
.enumerate()
139112
.map(|(pos, dt)| Field::new(format!("c{pos}"), dt.clone(), true))
140-
.collect::<Vec<Field>>();
141-
Ok(DataType::Struct(Fields::from(return_fields)))
113+
.collect::<Vec<Field>>()
114+
.into();
115+
116+
Ok(DataType::Struct(fields))
142117
}
143118

144-
fn invoke_batch(
145-
&self,
146-
args: &[ColumnarValue],
147-
_number_rows: usize,
148-
) -> Result<ColumnarValue> {
149-
struct_expr(args)
119+
/// Builds the struct array for `struct(...)` from the evaluated arguments.
///
/// The struct fields are taken from `args.return_type` instead of being
/// derived again from the arguments; every argument becomes one child array.
///
/// # Errors
///
/// Returns an internal error if `args.return_type` is not a
/// `DataType::Struct`, and propagates errors from
/// `ColumnarValue::values_to_arrays`.
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
120+
let DataType::Struct(fields) = args.return_type else {
121+
return internal_err!("incorrect struct return type");
122+
};
123+
let arrays = ColumnarValue::values_to_arrays(&args.args)?;
124+
// No null buffer is supplied (`None`) for the struct array itself.
Ok(ColumnarValue::Array(Arc::new(StructArray::new(fields.clone(), arrays, None))))
150125
}
151126

152127
fn documentation(&self) -> Option<&Documentation> {

0 commit comments

Comments
 (0)