Skip to content

Commit d6f3f73

Browse files
authored
Support try_from_array and eq_array for ScalarValue::Union (#12208)
1 parent 008c942 commit d6f3f73

File tree

1 file changed

+122
-2
lines changed
  • datafusion/common/src/scalar

1 file changed

+122
-2
lines changed

datafusion/common/src/scalar/mod.rs

Lines changed: 122 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2800,6 +2800,13 @@ impl ScalarValue {
28002800
let a = array.slice(index, 1);
28012801
Self::Map(Arc::new(a.as_map().to_owned()))
28022802
}
2803+
DataType::Union(fields, mode) => {
2804+
let array = as_union_array(array);
2805+
let ti = array.type_id(index);
2806+
let index = array.value_offset(index);
2807+
let value = ScalarValue::try_from_array(array.child(ti), index)?;
2808+
ScalarValue::Union(Some((ti, Box::new(value))), fields.clone(), *mode)
2809+
}
28032810
other => {
28042811
return _not_impl_err!(
28052812
"Can't create a scalar from array of type \"{other:?}\""
@@ -3035,8 +3042,15 @@ impl ScalarValue {
30353042
ScalarValue::DurationNanosecond(val) => {
30363043
eq_array_primitive!(array, index, DurationNanosecondArray, val)?
30373044
}
3038-
ScalarValue::Union(_, _, _) => {
3039-
return _not_impl_err!("Union is not supported yet")
3045+
ScalarValue::Union(value, _, _) => {
3046+
let array = as_union_array(array);
3047+
let ti = array.type_id(index);
3048+
let index = array.value_offset(index);
3049+
if let Some((ti_v, value)) = value {
3050+
ti_v == &ti && value.eq_array(array.child(ti), index)?
3051+
} else {
3052+
array.child(ti).is_null(index)
3053+
}
30403054
}
30413055
ScalarValue::Dictionary(key_type, v) => {
30423056
let (values_array, values_index) = match key_type.as_ref() {
@@ -5536,6 +5550,112 @@ mod tests {
55365550
assert_eq!(&array, &expected);
55375551
}
55385552

5553+
#[test]
fn test_scalar_union_sparse() {
    // Three nullable union members registered under type ids 0, 1 and 2.
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("A", DataType::Int32, true))),
        (1, Arc::new(Field::new("B", DataType::Boolean, true))),
        (2, Arc::new(Field::new("C", DataType::Utf8, true))),
    ]);

    // No offsets buffer is supplied below, so every child carries one slot per
    // union row (six in total). Each child holds exactly one non-null value, at
    // rows 0, 1 and 2 respectively; all remaining slots are null.
    let ints = Int32Array::from(vec![Some(42), None, None, None, None, None]);
    let bools = BooleanArray::from(vec![None, Some(true), None, None, None, None]);
    let strings = StringArray::from(vec![None, None, Some("foo"), None, None, None]);
    let children: Vec<ArrayRef> =
        vec![Arc::new(ints), Arc::new(bools), Arc::new(strings)];

    // Rows cycle through the three members twice: first the populated slots,
    // then the null ones.
    let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]);
    let union = UnionArray::try_new(fields.clone(), type_ids, None, children)
        .expect("UnionArray");
    let array: ArrayRef = Arc::new(union);

    // Per row: the type id and the inner scalar the extracted union should wrap.
    let cases = [
        (0, ScalarValue::from(42)),
        (1, ScalarValue::from(true)),
        (2, ScalarValue::from("foo")),
        (0, ScalarValue::Int32(None)),
        (1, ScalarValue::Boolean(None)),
        (2, ScalarValue::Utf8(None)),
    ];

    for (i, (type_id, inner)) in cases.into_iter().enumerate() {
        // Record nullness before `inner` is moved into the union scalar.
        let inner_is_null = inner.is_null();
        let expected = ScalarValue::Union(
            Some((type_id, Box::new(inner))),
            fields.clone(),
            UnionMode::Sparse,
        );
        let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array");

        // Extraction must reproduce the expected union scalar ...
        assert_eq!(
            actual, expected,
            "[{i}] {actual} was not equal to {expected}"
        );

        // ... and that scalar must compare equal to the row it came from.
        let matches_row = expected.eq_array(&array, i).expect("eq_array");
        assert!(matches_row, "[{i}] {expected}.eq_array was false");

        // A null inner value implies the extracted union reports null as well.
        assert!(
            !inner_is_null || actual.is_null(),
            "[{i}] {actual} was not null"
        );
    }
}
5608+
5609+
#[test]
fn test_scalar_union_dense() {
    // Three nullable union members registered under type ids 0, 1 and 2.
    let fields = UnionFields::from_iter([
        (0, Arc::new(Field::new("A", DataType::Int32, true))),
        (1, Arc::new(Field::new("B", DataType::Boolean, true))),
        (2, Arc::new(Field::new("C", DataType::Utf8, true))),
    ]);

    // Each child holds two slots: a value at slot 0 and a null at slot 1.
    let ints = Int32Array::from(vec![Some(42), None]);
    let bools = BooleanArray::from(vec![Some(true), None]);
    let strings = StringArray::from(vec![Some("foo"), None]);
    let children: Vec<ArrayRef> =
        vec![Arc::new(ints), Arc::new(bools), Arc::new(strings)];

    // Six union rows: rows 0..3 use offset 0 (the populated slot) of members
    // 0, 1, 2 and rows 3..6 use offset 1 (the null slot) of the same members.
    let type_ids = ScalarBuffer::from(vec![0, 1, 2, 0, 1, 2]);
    let offsets = ScalarBuffer::from(vec![0, 0, 0, 1, 1, 1]);
    let union = UnionArray::try_new(fields.clone(), type_ids, Some(offsets), children)
        .expect("UnionArray");
    let array: ArrayRef = Arc::new(union);

    // Per row: the type id and the inner scalar the extracted union should wrap.
    let cases = [
        (0, ScalarValue::from(42)),
        (1, ScalarValue::from(true)),
        (2, ScalarValue::from("foo")),
        (0, ScalarValue::Int32(None)),
        (1, ScalarValue::Boolean(None)),
        (2, ScalarValue::Utf8(None)),
    ];

    for (i, (type_id, inner)) in cases.into_iter().enumerate() {
        // Record nullness before `inner` is moved into the union scalar.
        let inner_is_null = inner.is_null();
        let expected = ScalarValue::Union(
            Some((type_id, Box::new(inner))),
            fields.clone(),
            UnionMode::Dense,
        );
        let actual = ScalarValue::try_from_array(&array, i).expect("try_from_array");

        // Extraction must reproduce the expected union scalar ...
        assert_eq!(
            actual, expected,
            "[{i}] {actual} was not equal to {expected}"
        );

        // ... and that scalar must compare equal to the row it came from.
        let matches_row = expected.eq_array(&array, i).expect("eq_array");
        assert!(matches_row, "[{i}] {expected}.eq_array was false");

        // A null inner value implies the extracted union reports null as well.
        assert!(
            !inner_is_null || actual.is_null(),
            "[{i}] {actual} was not null"
        );
    }
}
5658+
55395659
#[test]
55405660
fn test_lists_in_struct() {
55415661
let field_a = Arc::new(Field::new("A", DataType::Utf8, false));

0 commit comments

Comments
 (0)