Skip to content

Commit 4d141a3

Browse files
Allow 'zip' compute function to operate on Scalar arrays (#5086)
1 parent 6815bf1 commit 4d141a3

File tree

1 file changed

+148
-8
lines changed

1 file changed

+148
-8
lines changed

arrow-select/src/zip.rs

Lines changed: 148 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -29,19 +29,39 @@ use arrow_schema::ArrowError;
2929
/// * `falsy` - Values of this array are taken if mask evaluates `false`
3030
pub fn zip(
3131
mask: &BooleanArray,
32-
truthy: &dyn Array,
33-
falsy: &dyn Array,
32+
truthy: &dyn Datum,
33+
falsy: &dyn Datum,
3434
) -> Result<ArrayRef, ArrowError> {
35+
let (truthy, truthy_is_scalar) = truthy.get();
36+
let (falsy, falsy_is_scalar) = falsy.get();
37+
3538
if truthy.data_type() != falsy.data_type() {
3639
return Err(ArrowError::InvalidArgumentError(
3740
"arguments need to have the same data type".into(),
3841
));
3942
}
40-
if truthy.len() != falsy.len() || falsy.len() != mask.len() {
43+
44+
if truthy_is_scalar && truthy.len() != 1 {
45+
return Err(ArrowError::InvalidArgumentError(
46+
"scalar arrays must have 1 element".into(),
47+
));
48+
}
49+
if !truthy_is_scalar && truthy.len() != mask.len() {
50+
return Err(ArrowError::InvalidArgumentError(
51+
"all arrays should have the same length".into(),
52+
));
53+
}
54+
if truthy_is_scalar && truthy.len() != 1 {
55+
return Err(ArrowError::InvalidArgumentError(
56+
"scalar arrays must have 1 element".into(),
57+
));
58+
}
59+
if !falsy_is_scalar && falsy.len() != mask.len() {
4160
return Err(ArrowError::InvalidArgumentError(
4261
"all arrays should have the same length".into(),
4362
));
4463
}
64+
4565
let falsy = falsy.to_data();
4666
let truthy = truthy.to_data();
4767

@@ -56,15 +76,36 @@ pub fn zip(
5676
SlicesIterator::new(mask).for_each(|(start, end)| {
5777
// the gap needs to be filled with falsy values
5878
if start > filled {
59-
mutable.extend(1, filled, start);
79+
if falsy_is_scalar {
80+
for _ in filled..start {
81+
// Copy the first item from the 'falsy' array into the output buffer.
82+
mutable.extend(1, 0, 1);
83+
}
84+
} else {
85+
mutable.extend(1, filled, start);
86+
}
6087
}
6188
// fill with truthy values
62-
mutable.extend(0, start, end);
89+
if truthy_is_scalar {
90+
for _ in start..end {
91+
// Copy the first item from the 'truthy' array into the output buffer.
92+
mutable.extend(0, 0, 1);
93+
}
94+
} else {
95+
mutable.extend(0, start, end);
96+
}
6397
filled = end;
6498
});
6599
// the remaining part is falsy
66-
if filled < truthy.len() {
67-
mutable.extend(1, filled, truthy.len());
100+
if filled < mask.len() {
101+
if falsy_is_scalar {
102+
for _ in filled..mask.len() {
103+
// Copy the first item from the 'falsy' array into the output buffer.
104+
mutable.extend(1, 0, 1);
105+
}
106+
} else {
107+
mutable.extend(1, filled, mask.len());
108+
}
68109
}
69110

70111
let data = mutable.freeze();
@@ -76,7 +117,7 @@ mod test {
76117
use super::*;
77118

78119
#[test]
79-
fn test_zip_kernel() {
120+
fn test_zip_kernel_one() {
80121
let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
81122
let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
82123
let mask = BooleanArray::from(vec![true, true, false, false, true]);
@@ -85,4 +126,103 @@ mod test {
85126
let expected = Int32Array::from(vec![Some(5), None, Some(6), Some(7), Some(1)]);
86127
assert_eq!(actual, &expected);
87128
}
129+
130+
#[test]
131+
fn test_zip_kernel_two() {
132+
let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
133+
let b = Int32Array::from(vec![None, Some(3), Some(6), Some(7), Some(3)]);
134+
let mask = BooleanArray::from(vec![false, false, true, true, false]);
135+
let out = zip(&mask, &a, &b).unwrap();
136+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
137+
let expected = Int32Array::from(vec![None, Some(3), Some(7), None, Some(3)]);
138+
assert_eq!(actual, &expected);
139+
}
140+
141+
#[test]
142+
fn test_zip_kernel_scalar_falsy_1() {
143+
let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
144+
145+
let fallback = Scalar::new(Int32Array::from_value(42, 1));
146+
147+
let mask = BooleanArray::from(vec![true, true, false, false, true]);
148+
let out = zip(&mask, &a, &fallback).unwrap();
149+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
150+
let expected = Int32Array::from(vec![Some(5), None, Some(42), Some(42), Some(1)]);
151+
assert_eq!(actual, &expected);
152+
}
153+
154+
#[test]
155+
fn test_zip_kernel_scalar_falsy_2() {
156+
let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
157+
158+
let fallback = Scalar::new(Int32Array::from_value(42, 1));
159+
160+
let mask = BooleanArray::from(vec![false, false, true, true, false]);
161+
let out = zip(&mask, &a, &fallback).unwrap();
162+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
163+
let expected = Int32Array::from(vec![Some(42), Some(42), Some(7), None, Some(42)]);
164+
assert_eq!(actual, &expected);
165+
}
166+
167+
#[test]
168+
fn test_zip_kernel_scalar_truthy_1() {
169+
let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
170+
171+
let fallback = Scalar::new(Int32Array::from_value(42, 1));
172+
173+
let mask = BooleanArray::from(vec![true, true, false, false, true]);
174+
let out = zip(&mask, &fallback, &a).unwrap();
175+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
176+
let expected = Int32Array::from(vec![Some(42), Some(42), Some(7), None, Some(42)]);
177+
assert_eq!(actual, &expected);
178+
}
179+
180+
#[test]
181+
fn test_zip_kernel_scalar_truthy_2() {
182+
let a = Int32Array::from(vec![Some(5), None, Some(7), None, Some(1)]);
183+
184+
let fallback = Scalar::new(Int32Array::from_value(42, 1));
185+
186+
let mask = BooleanArray::from(vec![false, false, true, true, false]);
187+
let out = zip(&mask, &fallback, &a).unwrap();
188+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
189+
let expected = Int32Array::from(vec![Some(5), None, Some(42), Some(42), Some(1)]);
190+
assert_eq!(actual, &expected);
191+
}
192+
193+
#[test]
194+
fn test_zip_kernel_scalar_both() {
195+
let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
196+
let scalar_falsy = Scalar::new(Int32Array::from_value(123, 1));
197+
198+
let mask = BooleanArray::from(vec![true, true, false, false, true]);
199+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
200+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
201+
let expected = Int32Array::from(vec![Some(42), Some(42), Some(123), Some(123), Some(42)]);
202+
assert_eq!(actual, &expected);
203+
}
204+
205+
#[test]
206+
fn test_zip_kernel_scalar_none_1() {
207+
let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
208+
let scalar_falsy = Scalar::new(Int32Array::new_null(1));
209+
210+
let mask = BooleanArray::from(vec![true, true, false, false, true]);
211+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
212+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
213+
let expected = Int32Array::from(vec![Some(42), Some(42), None, None, Some(42)]);
214+
assert_eq!(actual, &expected);
215+
}
216+
217+
#[test]
218+
fn test_zip_kernel_scalar_none_2() {
219+
let scalar_truthy = Scalar::new(Int32Array::from_value(42, 1));
220+
let scalar_falsy = Scalar::new(Int32Array::new_null(1));
221+
222+
let mask = BooleanArray::from(vec![false, false, true, true, false]);
223+
let out = zip(&mask, &scalar_truthy, &scalar_falsy).unwrap();
224+
let actual = out.as_any().downcast_ref::<Int32Array>().unwrap();
225+
let expected = Int32Array::from(vec![None, None, Some(42), Some(42), None]);
226+
assert_eq!(actual, &expected);
227+
}
88228
}

0 commit comments

Comments
 (0)