Skip to content

Commit 51978ee

Browse files
gstvgalamb
authored and
Nirnay Roy
committed
Add union_tag scalar function (apache#14687)
* feat: add union_tag scalar function * update for new api * Add test for second field type --------- Co-authored-by: Andrew Lamb <[email protected]>
1 parent 7c16284 commit 51978ee

File tree

5 files changed

+298
-4
lines changed

5 files changed

+298
-4
lines changed

datafusion/functions/src/core/mod.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ pub mod overlay;
3636
pub mod planner;
3737
pub mod r#struct;
3838
pub mod union_extract;
39+
pub mod union_tag;
3940
pub mod version;
4041

4142
// create UDFs
@@ -52,6 +53,7 @@ make_udf_function!(coalesce::CoalesceFunc, coalesce);
5253
make_udf_function!(greatest::GreatestFunc, greatest);
5354
make_udf_function!(least::LeastFunc, least);
5455
make_udf_function!(union_extract::UnionExtractFun, union_extract);
56+
make_udf_function!(union_tag::UnionTagFunc, union_tag);
5557
make_udf_function!(version::VersionFunc, version);
5658

5759
pub mod expr_fn {
@@ -101,6 +103,10 @@ pub mod expr_fn {
101103
least,
102104
"Returns `least(args...)`, which evaluates to the smallest value in the list of expressions or NULL if all the expressions are NULL",
103105
args,
106+
),(
107+
union_tag,
108+
"Returns the name of the currently selected field in the union",
109+
arg1
104110
));
105111

106112
#[doc = "Returns the value of the field with the given name from the struct"]
@@ -136,6 +142,7 @@ pub fn functions() -> Vec<Arc<ScalarUDF>> {
136142
greatest(),
137143
least(),
138144
union_extract(),
145+
union_tag(),
139146
version(),
140147
r#struct(),
141148
]
Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,225 @@
1+
// Licensed to the Apache Software Foundation (ASF) under one
2+
// or more contributor license agreements. See the NOTICE file
3+
// distributed with this work for additional information
4+
// regarding copyright ownership. The ASF licenses this file
5+
// to you under the Apache License, Version 2.0 (the
6+
// "License"); you may not use this file except in compliance
7+
// with the License. You may obtain a copy of the License at
8+
//
9+
// http://www.apache.org/licenses/LICENSE-2.0
10+
//
11+
// Unless required by applicable law or agreed to in writing,
12+
// software distributed under the License is distributed on an
13+
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14+
// KIND, either express or implied. See the License for the
15+
// specific language governing permissions and limitations
16+
// under the License.
17+
18+
use arrow::array::{Array, AsArray, DictionaryArray, Int8Array, StringArray};
19+
use arrow::datatypes::DataType;
20+
use datafusion_common::utils::take_function_args;
21+
use datafusion_common::{exec_datafusion_err, exec_err, Result, ScalarValue};
22+
use datafusion_doc::Documentation;
23+
use datafusion_expr::{ColumnarValue, ScalarFunctionArgs};
24+
use datafusion_expr::{ScalarUDFImpl, Signature, Volatility};
25+
use datafusion_macros::user_doc;
26+
use std::sync::Arc;
27+
28+
// `union_tag` scalar function: reports the name of the field currently
// selected in a union value. The `user_doc` attribute below carries the
// user-facing description, syntax line and SQL example for the function.
#[user_doc(
29+
doc_section(label = "Union Functions"),
30+
description = "Returns the name of the currently selected field in the union",
31+
syntax_example = "union_tag(union_expression)",
32+
sql_example = r#"```sql
33+
❯ select union_column, union_tag(union_column) from table_with_union;
34+
+--------------+-------------------------+
35+
| union_column | union_tag(union_column) |
36+
+--------------+-------------------------+
37+
| {a=1} | a |
38+
| {b=3.0} | b |
39+
| {a=4} | a |
40+
| {b=} | b |
41+
| {a=} | a |
42+
+--------------+-------------------------+
43+
```"#,
44+
standard_argument(name = "union", prefix = "Union")
45+
)]
46+
#[derive(Debug)]
47+
pub struct UnionTagFunc {
48+
// Argument signature returned by `ScalarUDFImpl::signature`; built in `new`.
signature: Signature,
49+
}
50+
51+
impl Default for UnionTagFunc {
52+
fn default() -> Self {
53+
Self::new()
54+
}
55+
}
56+
57+
impl UnionTagFunc {
58+
pub fn new() -> Self {
59+
Self {
60+
signature: Signature::any(1, Volatility::Immutable),
61+
}
62+
}
63+
}
64+
65+
impl ScalarUDFImpl for UnionTagFunc {
66+
fn as_any(&self) -> &dyn std::any::Any {
67+
self
68+
}
69+
70+
fn name(&self) -> &str {
71+
"union_tag"
72+
}
73+
74+
fn signature(&self) -> &Signature {
75+
&self.signature
76+
}
77+
78+
fn return_type(&self, _: &[DataType]) -> Result<DataType> {
79+
Ok(DataType::Dictionary(
80+
Box::new(DataType::Int8),
81+
Box::new(DataType::Utf8),
82+
))
83+
}
84+
85+
fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
86+
let [union_] = take_function_args("union_tag", args.args)?;
87+
88+
match union_ {
89+
ColumnarValue::Array(array)
90+
if matches!(array.data_type(), DataType::Union(_, _)) =>
91+
{
92+
let union_array = array.as_union();
93+
94+
let keys = Int8Array::try_new(union_array.type_ids().clone(), None)?;
95+
96+
let fields = match union_array.data_type() {
97+
DataType::Union(fields, _) => fields,
98+
_ => unreachable!(),
99+
};
100+
101+
// Union fields type IDs only constraints are being unique and in the 0..128 range:
102+
// They may not start at 0, be sequential, or even contiguous.
103+
// Therefore, we allocate a values vector with a length equal to the highest type ID plus one,
104+
// ensuring that each field's name can be placed at the index corresponding to its type ID.
105+
let values_len = fields
106+
.iter()
107+
.map(|(type_id, _)| type_id + 1)
108+
.max()
109+
.unwrap_or_default() as usize;
110+
111+
let mut values = vec![""; values_len];
112+
113+
for (type_id, field) in fields.iter() {
114+
values[type_id as usize] = field.name().as_str()
115+
}
116+
117+
let values = Arc::new(StringArray::from(values));
118+
119+
// SAFETY: union type_ids are validated to not be smaller than zero.
120+
// values len is the union biggest type id plus one.
121+
// keys is built from the union type_ids, which contains only valid type ids
122+
// therefore, `keys[i] >= values.len() || keys[i] < 0` never occurs
123+
let dict = unsafe { DictionaryArray::new_unchecked(keys, values) };
124+
125+
Ok(ColumnarValue::Array(Arc::new(dict)))
126+
}
127+
ColumnarValue::Scalar(ScalarValue::Union(value, fields, _)) => match value {
128+
Some((value_type_id, _)) => fields
129+
.iter()
130+
.find(|(type_id, _)| value_type_id == *type_id)
131+
.map(|(_, field)| {
132+
ColumnarValue::Scalar(ScalarValue::Dictionary(
133+
Box::new(DataType::Int8),
134+
Box::new(field.name().as_str().into()),
135+
))
136+
})
137+
.ok_or_else(|| {
138+
exec_datafusion_err!(
139+
"union_tag: union scalar with unknow type_id {value_type_id}"
140+
)
141+
}),
142+
None => Ok(ColumnarValue::Scalar(ScalarValue::try_new_null(
143+
args.return_field.data_type(),
144+
)?)),
145+
},
146+
v => exec_err!("union_tag only support unions, got {:?}", v.data_type()),
147+
}
148+
}
149+
150+
fn documentation(&self) -> Option<&Documentation> {
151+
self.doc()
152+
}
153+
}
154+
155+
#[cfg(test)]
mod tests {
    use super::UnionTagFunc;
    use arrow::datatypes::{DataType, Field, UnionFields, UnionMode};
    use datafusion_common::ScalarValue;
    use datafusion_expr::{ColumnarValue, ScalarFunctionArgs, ScalarUDFImpl};
    use std::sync::Arc;

    // NOTE: once union scalars can be constructed in SQL, these cases should
    // move to sqllogictests.

    /// Invokes `union_tag` on a single scalar argument, declaring the
    /// `Dictionary(Int8, Utf8)` return field, and unwraps the result.
    fn invoke_on_scalar(input: ScalarValue) -> ColumnarValue {
        let return_type =
            DataType::Dictionary(Box::new(DataType::Int8), Box::new(DataType::Utf8));
        let return_field = Field::new("res", return_type, true);

        UnionTagFunc::new()
            .invoke_with_args(ScalarFunctionArgs {
                args: vec![ColumnarValue::Scalar(input)],
                number_rows: 1,
                return_field: &return_field,
                arg_fields: vec![],
            })
            .unwrap()
    }

    /// Wraps `inner` in the `Int8`-keyed dictionary scalar `union_tag` returns.
    fn int8_dictionary(inner: ScalarValue) -> ScalarValue {
        ScalarValue::Dictionary(Box::new(DataType::Int8), Box::new(inner))
    }

    /// Asserts that `value` is a scalar equal to `expected`.
    fn assert_scalar(value: ColumnarValue, expected: ScalarValue) {
        match value {
            ColumnarValue::Scalar(scalar) => assert_eq!(scalar, expected),
            ColumnarValue::Array(array) => panic!("expected scalar got {array:?}"),
        }
    }

    /// A union scalar holding a value yields the name of the selected field.
    #[test]
    fn union_scalar() {
        let fields: UnionFields =
            [(0, Arc::new(Field::new("a", DataType::UInt32, false)))]
                .into_iter()
                .collect();

        let input = ScalarValue::Union(
            Some((0, Box::new(ScalarValue::UInt32(Some(0))))),
            fields,
            UnionMode::Dense,
        );

        assert_scalar(invoke_on_scalar(input), int8_dictionary("a".into()));
    }

    /// A union scalar holding no value yields a null tag.
    #[test]
    fn union_scalar_empty() {
        let input = ScalarValue::Union(None, UnionFields::empty(), UnionMode::Dense);

        assert_scalar(
            invoke_on_scalar(input),
            int8_dictionary(ScalarValue::Utf8(None)),
        );
    }
}

datafusion/sqllogictest/src/test_context.rs

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -410,10 +410,24 @@ fn create_example_udf() -> ScalarUDF {
410410

411411
fn register_union_table(ctx: &SessionContext) {
412412
let union = UnionArray::try_new(
413-
UnionFields::new(vec![3], vec![Field::new("int", DataType::Int32, false)]),
414-
ScalarBuffer::from(vec![3, 3]),
413+
UnionFields::new(
414+
// typeids: 3 for int, 1 for string
415+
vec![3, 1],
416+
vec![
417+
Field::new("int", DataType::Int32, false),
418+
Field::new("string", DataType::Utf8, false),
419+
],
420+
),
421+
ScalarBuffer::from(vec![3, 1, 3]),
415422
None,
416-
vec![Arc::new(Int32Array::from(vec![1, 2]))],
423+
vec![
424+
Arc::new(Int32Array::from(vec![1, 2, 3])),
425+
Arc::new(StringArray::from(vec![
426+
Some("foo"),
427+
Some("bar"),
428+
Some("baz"),
429+
])),
430+
],
417431
)
418432
.unwrap();
419433

datafusion/sqllogictest/test_files/union_function.slt

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,9 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717

18+
# Note: union_table is registered via Rust code in the sqllogictest test harness
19+
# because there is no way to create a union type in SQL today
20+
1821
##########
1922
## UNION DataType Tests
2023
##########
@@ -23,7 +26,8 @@ query ?I
2326
select union_column, union_extract(union_column, 'int') from union_table;
2427
----
2528
{int=1} 1
26-
{int=2} 2
29+
{string=bar} NULL
30+
{int=3} 3
2731

2832
query error DataFusion error: Execution error: field bool not found on union
2933
select union_extract(union_column, 'bool') from union_table;
@@ -45,3 +49,19 @@ select union_extract(union_column, 1) from union_table;
4549

4650
query error DataFusion error: Error during planning: The function 'union_extract' expected 2 arguments but received 3
4751
select union_extract(union_column, 'a', 'b') from union_table;
52+
53+
query ?T
54+
select union_column, union_tag(union_column) from union_table;
55+
----
56+
{int=1} int
57+
{string=bar} string
58+
{int=3} int
59+
60+
query error DataFusion error: Error during planning: 'union_tag' does not support zero arguments
61+
select union_tag() from union_table;
62+
63+
query error DataFusion error: Error during planning: The function 'union_tag' expected 1 arguments but received 2
64+
select union_tag(union_column, 'int') from union_table;
65+
66+
query error DataFusion error: Execution error: union_tag only support unions, got Utf8
67+
select union_tag('int') from union_table;

docs/source/user-guide/sql/scalar_functions.md

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4404,6 +4404,7 @@ sha512(expression)
44044404
Functions to work with the union data type, also known as tagged unions, variant types, enums or sum types. Note: Not related to the SQL UNION operator.
44054405

44064406
- [union_extract](#union_extract)
4407+
- [union_tag](#union_tag)
44074408

44084409
### `union_extract`
44094410

@@ -4433,6 +4434,33 @@ union_extract(union, field_name)
44334434
+--------------+----------------------------------+----------------------------------+
44344435
```
44354436

4437+
### `union_tag`
4438+
4439+
Returns the name of the currently selected field in the union
4440+
4441+
```sql
4442+
union_tag(union_expression)
4443+
```
4444+
4445+
#### Arguments
4446+
4447+
- **union**: Union expression to operate on. Can be a constant, column, or function, and any combination of operators.
4448+
4449+
#### Example
4450+
4451+
```sql
4452+
select union_column, union_tag(union_column) from table_with_union;
4453+
+--------------+-------------------------+
4454+
| union_column | union_tag(union_column) |
4455+
+--------------+-------------------------+
4456+
| {a=1} | a |
4457+
| {b=3.0} | b |
4458+
| {a=4} | a |
4459+
| {b=} | b |
4460+
| {a=} | a |
4461+
+--------------+-------------------------+
4462+
```
4463+
44364464
## Other Functions
44374465

44384466
- [arrow_cast](#arrow_cast)

0 commit comments

Comments
 (0)