Skip to content

Commit 12ff1ea

Browse files
authored
fix: Correctly handle take on dense union of a single selected type (#6209)
* fix: use filter instead of filter_primitive * fix: remove pub(crate) from filter_primitive * fix: run cargo fmt * fix: clippy
1 parent b90c799 commit 12ff1ea

File tree

2 files changed

+22
-10
lines changed

2 files changed

+22
-10
lines changed

arrow-select/src/filter.rs

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -552,10 +552,7 @@ fn filter_native<T: ArrowNativeType>(values: &[T], predicate: &FilterPredicate)
552552
}
553553

554554
/// `filter` implementation for primitive arrays
555-
pub(crate) fn filter_primitive<T>(
556-
array: &PrimitiveArray<T>,
557-
predicate: &FilterPredicate,
558-
) -> PrimitiveArray<T>
555+
fn filter_primitive<T>(array: &PrimitiveArray<T>, predicate: &FilterPredicate) -> PrimitiveArray<T>
559556
where
560557
T: ArrowPrimitiveType,
561558
{

arrow-select/src/take.rs

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,6 @@ use arrow_schema::{ArrowError, DataType, FieldRef, UnionMode};
3131

3232
use num::{One, Zero};
3333

34-
use crate::filter::{filter_primitive, FilterBuilder};
35-
3634
/// Take elements by index from [Array], creating a new [Array] from those indexes.
3735
///
3836
/// ```text
@@ -251,13 +249,12 @@ fn take_impl<IndexType: ArrowPrimitiveType>(
251249
let children = fields.iter()
252250
.map(|(field_type_id, _)| {
253251
let mask = BooleanArray::from_unary(&type_ids, |value_type_id| value_type_id == field_type_id);
254-
let predicate = FilterBuilder::new(&mask).build();
255252

256-
let indices = filter_primitive(&offsets, &predicate);
253+
let indices = crate::filter::filter(&offsets, &mask)?;
257254

258255
let values = values.child(field_type_id);
259256

260-
take_impl(values, &indices)
257+
take_impl(values, indices.as_primitive::<Int32Type>())
261258
})
262259
.collect::<Result<_, _>>()?;
263260

@@ -885,7 +882,7 @@ mod tests {
885882
use super::*;
886883
use arrow_array::builder::*;
887884
use arrow_buffer::{IntervalDayTime, IntervalMonthDayNano};
888-
use arrow_schema::{Field, Fields, TimeUnit};
885+
use arrow_schema::{Field, Fields, TimeUnit, UnionFields};
889886

890887
fn test_take_decimal_arrays(
891888
data: Vec<Option<i128>>,
@@ -2308,4 +2305,22 @@ mod tests {
23082305
take(&union, &indices, None).unwrap().to_data()
23092306
);
23102307
}
2308+
2309+
#[test]
2310+
fn test_take_union_dense_all_match_issue_6206() {
2311+
let fields = UnionFields::new(vec![0], vec![Field::new("a", DataType::Int64, false)]);
2312+
let ints = Arc::new(Int64Array::from(vec![1, 2, 3, 4, 5]));
2313+
2314+
let array = UnionArray::try_new(
2315+
fields,
2316+
ScalarBuffer::from(vec![0_i8, 0, 0, 0, 0]),
2317+
Some(ScalarBuffer::from_iter(0_i32..5)),
2318+
vec![ints],
2319+
)
2320+
.unwrap();
2321+
2322+
let indicies = Int64Array::from(vec![0, 2, 4]);
2323+
let array = take(&array, &indicies, None).unwrap();
2324+
assert_eq!(array.len(), 3);
2325+
}
23112326
}

0 commit comments

Comments
 (0)