Skip to content

Commit fe2da2b

Browse files
authored
feat(function): add greatest function (#12474)
* feat(function): add greatest function This matches the Spark implementation for greatest: https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.functions.greatest.html * remove unused * fix finding common supertype in greatest * allow single argument for greatest * assert that both arrays have the same length * use logical null count * remove unused import * add docs * add greatest slt tests * add greatest slt tests * fix merge conflicts * add docs * revert manual docs changes * Update based on cr * fix lint * run fmt * run clippy * Updated docs using `./dev/update_function_docs.sh`
1 parent eaf51ba commit fe2da2b

File tree

4 files changed

+502
-0
lines changed

4 files changed

+502
-0
lines changed
+272
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,272 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{make_comparator, Array, ArrayRef, BooleanArray};
19+
use arrow::compute::kernels::cmp;
20+
use arrow::compute::kernels::zip::zip;
21+
use arrow::compute::SortOptions;
22+
use arrow::datatypes::DataType;
23+
use arrow_buffer::BooleanBuffer;
24+
use datafusion_common::{exec_err, plan_err, Result, ScalarValue};
25+
use datafusion_expr::binary::type_union_resolution;
26+
use datafusion_expr::scalar_doc_sections::DOC_SECTION_CONDITIONAL;
27+
use datafusion_expr::{ColumnarValue, Documentation};
28+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
29+
use std::any::Any;
30+
use std::sync::{Arc, OnceLock};
31+
32+
/// Comparator configuration shared by every `greatest` comparison in this file.
///
/// Comparators built with these options use the natural (ascending) order and
/// place NULL before every non-null value, so checking the resulting ordering
/// with `is_ge()` answers "is the left value at least as large as the right
/// one, with NULL counting as the smallest value".
const SORT_OPTIONS: SortOptions = SortOptions {
    // NULL compares as less than any other value
    nulls_first: true,

    // Keep the natural order: the comparator yields `Greater` when the left
    // value is the larger one
    descending: false,
};
39+
40+
#[derive(Debug)]
41+
pub struct GreatestFunc {
42+
signature: Signature,
43+
}
44+
45+
impl Default for GreatestFunc {
46+
fn default() -> Self {
47+
GreatestFunc::new()
48+
}
49+
}
50+
51+
impl GreatestFunc {
52+
pub fn new() -> Self {
53+
Self {
54+
signature: Signature::user_defined(Volatility::Immutable),
55+
}
56+
}
57+
}
58+
59+
fn get_logical_null_count(arr: &dyn Array) -> usize {
60+
arr.logical_nulls()
61+
.map(|n| n.null_count())
62+
.unwrap_or_default()
63+
}
64+
65+
/// Return boolean array where `arr[i] = lhs[i] >= rhs[i]` for all i, where `arr` is the result array
66+
/// Nulls are always considered smaller than any other value
67+
fn get_larger(lhs: &dyn Array, rhs: &dyn Array) -> Result<BooleanArray> {
68+
// Fast path:
69+
// If both arrays are not nested, have the same length and no nulls, we can use the faster vectorised kernel
70+
// - If both arrays are not nested: Nested types, such as lists, are not supported as the null semantics are not well-defined.
71+
// - both array does not have any nulls: cmp::gt_eq will return null if any of the input is null while we want to return false in that case
72+
if !lhs.data_type().is_nested()
73+
&& get_logical_null_count(lhs) == 0
74+
&& get_logical_null_count(rhs) == 0
75+
{
76+
return cmp::gt_eq(&lhs, &rhs).map_err(|e| e.into());
77+
}
78+
79+
let cmp = make_comparator(lhs, rhs, SORT_OPTIONS)?;
80+
81+
if lhs.len() != rhs.len() {
82+
return exec_err!(
83+
"All arrays should have the same length for greatest comparison"
84+
);
85+
}
86+
87+
let values = BooleanBuffer::collect_bool(lhs.len(), |i| cmp(i, i).is_ge());
88+
89+
// No nulls as we only want to keep the values that are larger, its either true or false
90+
Ok(BooleanArray::new(values, None))
91+
}
92+
93+
/// Return array where the largest value at each index is kept
94+
fn keep_larger(lhs: ArrayRef, rhs: ArrayRef) -> Result<ArrayRef> {
95+
// True for values that we should keep from the left array
96+
let keep_lhs = get_larger(lhs.as_ref(), rhs.as_ref())?;
97+
98+
let larger = zip(&keep_lhs, &lhs, &rhs)?;
99+
100+
Ok(larger)
101+
}
102+
103+
fn keep_larger_scalar<'a>(
104+
lhs: &'a ScalarValue,
105+
rhs: &'a ScalarValue,
106+
) -> Result<&'a ScalarValue> {
107+
if !lhs.data_type().is_nested() {
108+
return if lhs >= rhs { Ok(lhs) } else { Ok(rhs) };
109+
}
110+
111+
// If complex type we can't compare directly as we want null values to be smaller
112+
let cmp = make_comparator(
113+
lhs.to_array()?.as_ref(),
114+
rhs.to_array()?.as_ref(),
115+
SORT_OPTIONS,
116+
)?;
117+
118+
if cmp(0, 0).is_ge() {
119+
Ok(lhs)
120+
} else {
121+
Ok(rhs)
122+
}
123+
}
124+
125+
fn find_coerced_type(data_types: &[DataType]) -> Result<DataType> {
126+
if data_types.is_empty() {
127+
plan_err!("greatest was called without any arguments. It requires at least 1.")
128+
} else if let Some(coerced_type) = type_union_resolution(data_types) {
129+
Ok(coerced_type)
130+
} else {
131+
plan_err!("Cannot find a common type for arguments")
132+
}
133+
}
134+
135+
impl ScalarUDFImpl for GreatestFunc {
136+
fn as_any(&self) -> &dyn Any {
137+
self
138+
}
139+
140+
fn name(&self) -> &str {
141+
"greatest"
142+
}
143+
144+
fn signature(&self) -> &Signature {
145+
&self.signature
146+
}
147+
148+
fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
149+
Ok(arg_types[0].clone())
150+
}
151+
152+
fn invoke(&self, args: &[ColumnarValue]) -> Result<ColumnarValue> {
153+
if args.is_empty() {
154+
return exec_err!(
155+
"greatest was called with no arguments. It requires at least 1."
156+
);
157+
}
158+
159+
// Some engines (e.g. SQL Server) allow greatest with single arg, it's a noop
160+
if args.len() == 1 {
161+
return Ok(args[0].clone());
162+
}
163+
164+
// Split to scalars and arrays for later optimization
165+
let (scalars, arrays): (Vec<_>, Vec<_>) = args.iter().partition(|x| match x {
166+
ColumnarValue::Scalar(_) => true,
167+
ColumnarValue::Array(_) => false,
168+
});
169+
170+
let mut arrays_iter = arrays.iter().map(|x| match x {
171+
ColumnarValue::Array(a) => a,
172+
_ => unreachable!(),
173+
});
174+
175+
let first_array = arrays_iter.next();
176+
177+
let mut largest: ArrayRef;
178+
179+
// Optimization: merge all scalars into one to avoid recomputing
180+
if !scalars.is_empty() {
181+
let mut scalars_iter = scalars.iter().map(|x| match x {
182+
ColumnarValue::Scalar(s) => s,
183+
_ => unreachable!(),
184+
});
185+
186+
// We have at least one scalar
187+
let mut largest_scalar = scalars_iter.next().unwrap();
188+
189+
for scalar in scalars_iter {
190+
largest_scalar = keep_larger_scalar(largest_scalar, scalar)?;
191+
}
192+
193+
// If we only have scalars, return the largest one
194+
if arrays.is_empty() {
195+
return Ok(ColumnarValue::Scalar(largest_scalar.clone()));
196+
}
197+
198+
// We have at least one array
199+
let first_array = first_array.unwrap();
200+
201+
// Start with the largest value
202+
largest = keep_larger(
203+
Arc::clone(first_array),
204+
largest_scalar.to_array_of_size(first_array.len())?,
205+
)?;
206+
} else {
207+
// If we only have arrays, start with the first array
208+
// (We must have at least one array)
209+
largest = Arc::clone(first_array.unwrap());
210+
}
211+
212+
for array in arrays_iter {
213+
largest = keep_larger(Arc::clone(array), largest)?;
214+
}
215+
216+
Ok(ColumnarValue::Array(largest))
217+
}
218+
219+
fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
220+
let coerced_type = find_coerced_type(arg_types)?;
221+
222+
Ok(vec![coerced_type; arg_types.len()])
223+
}
224+
225+
fn documentation(&self) -> Option<&Documentation> {
226+
Some(get_greatest_doc())
227+
}
228+
}
229+
static DOCUMENTATION: OnceLock<Documentation> = OnceLock::new();
230+
231+
fn get_greatest_doc() -> &'static Documentation {
232+
DOCUMENTATION.get_or_init(|| {
233+
Documentation::builder()
234+
.with_doc_section(DOC_SECTION_CONDITIONAL)
235+
.with_description("Returns the greatest value in a list of expressions. Returns _null_ if all expressions are _null_.")
236+
.with_syntax_example("greatest(expression1[, ..., expression_n])")
237+
.with_sql_example(r#"```sql
238+
> select greatest(4, 7, 5);
239+
+---------------------------+
240+
| greatest(4,7,5) |
241+
+---------------------------+
242+
| 7 |
243+
+---------------------------+
244+
```"#,
245+
)
246+
.with_argument(
247+
"expression1, expression_n",
248+
"Expressions to compare and return the greatest value.. Can be a constant, column, or function, and any combination of arithmetic operators. Pass as many expression arguments as necessary."
249+
)
250+
.build()
251+
.unwrap()
252+
})
253+
}
254+
255+
#[cfg(test)]
mod test {
    use crate::core;
    use arrow::datatypes::DataType;
    use datafusion_expr::ScalarUDFImpl;

    /// Two decimals with different scales have no common supertype among the
    /// argument types themselves, so coercion must widen both of them.
    #[test]
    fn test_greatest_return_types_without_common_supertype_in_arg_type() {
        let input_types = [DataType::Decimal128(10, 3), DataType::Decimal128(10, 4)];
        let expected = vec![DataType::Decimal128(11, 4), DataType::Decimal128(11, 4)];

        let coerced = core::greatest::GreatestFunc::new()
            .coerce_types(&input_types)
            .unwrap();

        assert_eq!(coerced, expected);
    }
}

datafusion/functions/src/core/mod.rs

+7
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ pub mod arrowtypeof;
2525
pub mod coalesce;
2626
pub mod expr_ext;
2727
pub mod getfield;
28+
pub mod greatest;
2829
pub mod named_struct;
2930
pub mod nullif;
3031
pub mod nvl;
@@ -43,6 +44,7 @@ make_udf_function!(r#struct::StructFunc, STRUCT, r#struct);
4344
make_udf_function!(named_struct::NamedStructFunc, NAMED_STRUCT, named_struct);
4445
make_udf_function!(getfield::GetFieldFunc, GET_FIELD, get_field);
4546
make_udf_function!(coalesce::CoalesceFunc, COALESCE, coalesce);
47+
make_udf_function!(greatest::GreatestFunc, GREATEST, greatest);
4648
make_udf_function!(version::VersionFunc, VERSION, version);
4749

4850
pub mod expr_fn {
@@ -80,6 +82,10 @@ pub mod expr_fn {
8082
coalesce,
8183
"Returns `coalesce(args...)`, which evaluates to the value of the first expr which is not NULL",
8284
args,
85+
),(
86+
greatest,
87+
"Returns `greatest(args...)`, which evaluates to the greatest value in the list of expressions or NULL if all the expressions are NULL",
88+
args,
8389
));
8490

8591
#[doc = "Returns the value of the field with the given name from the struct"]
@@ -106,6 +112,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
106112
// calls to `get_field`
107113
get_field(),
108114
coalesce(),
115+
greatest(),
109116
version(),
110117
r#struct(),
111118
]

0 commit comments

Comments
 (0)