Skip to content

Commit 7594db6

Browse files
authored
Add overflow-checking variants of arithmetic scalar dyn kernels (#2713)
* Add overflow-checking variants of arithmetic scalar dyn kernels * Update doc * For review
1 parent 2a0fc77 commit 7594db6

File tree

2 files changed

+226
-23
lines changed

2 files changed

+226
-23
lines changed

arrow/src/compute/kernels/arithmetic.rs

Lines changed: 178 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
2323
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
2424
25-
use std::ops::{Add, Div, Mul, Neg, Rem, Sub};
25+
use std::ops::{Div, Neg, Rem};
2626

2727
use num::{One, Zero};
2828

@@ -32,7 +32,9 @@ use crate::buffer::Buffer;
3232
use crate::buffer::MutableBuffer;
3333
use crate::compute::kernels::arity::unary;
3434
use crate::compute::util::combine_option_bitmap;
35-
use crate::compute::{binary, binary_opt, try_binary, try_unary, unary_dyn};
35+
use crate::compute::{
36+
binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
37+
};
3638
use crate::datatypes::{
3739
native_op::ArrowNativeTypeOp, ArrowNumericType, DataType, Date32Type, Date64Type,
3840
IntervalDayTimeType, IntervalMonthDayNanoType, IntervalUnit, IntervalYearMonthType,
@@ -834,12 +836,39 @@ where
834836
/// Add every value in an array by a scalar. If any value in the array is null then the
835837
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
836838
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
839+
///
840+
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
841+
/// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
842+
///
843+
/// This returns an `Err` when the input array is not supported for adding operation.
837844
pub fn add_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
838845
where
839846
T: ArrowNumericType,
840-
T::Native: Add<Output = T::Native>,
847+
T::Native: ArrowNativeTypeOp,
848+
{
849+
unary_dyn::<_, T>(array, |value| value.add_wrapping(scalar))
850+
}
851+
852+
/// Add every value in an array by a scalar. If any value in the array is null then the
853+
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
854+
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
855+
///
856+
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
857+
/// use `add_scalar_dyn` instead.
858+
///
859+
/// As this kernel has the branching costs and also prevents LLVM from vectorising it correctly,
860+
/// it is usually much slower than non-checking variant.
861+
pub fn add_scalar_checked_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
862+
where
863+
T: ArrowNumericType,
864+
T::Native: ArrowNativeTypeOp,
841865
{
842-
unary_dyn::<_, T>(array, |value| value + scalar)
866+
try_unary_dyn::<_, T>(array, |value| {
867+
value.add_checked(scalar).ok_or_else(|| {
868+
ArrowError::CastError(format!("Overflow: adding {:?} to {:?}", scalar, value))
869+
})
870+
})
871+
.map(|a| Arc::new(a) as ArrayRef)
843872
}
844873

845874
/// Perform `left - right` operation on two arrays. If either left or right value is null
@@ -937,16 +966,40 @@ where
937966
/// Subtract every value in an array by a scalar. If any value in the array is null then the
938967
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
939968
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
969+
///
970+
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
971+
/// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead.
940972
pub fn subtract_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
941973
where
942-
T: datatypes::ArrowNumericType,
943-
T::Native: Add<Output = T::Native>
944-
+ Sub<Output = T::Native>
945-
+ Mul<Output = T::Native>
946-
+ Div<Output = T::Native>
947-
+ Zero,
974+
T: ArrowNumericType,
975+
T::Native: ArrowNativeTypeOp,
976+
{
977+
unary_dyn::<_, T>(array, |value| value.sub_wrapping(scalar))
978+
}
979+
980+
/// Subtract every value in an array by a scalar. If any value in the array is null then the
981+
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
982+
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
983+
///
984+
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
985+
/// use `subtract_scalar_dyn` instead.
986+
pub fn subtract_scalar_checked_dyn<T>(
987+
array: &dyn Array,
988+
scalar: T::Native,
989+
) -> Result<ArrayRef>
990+
where
991+
T: ArrowNumericType,
992+
T::Native: ArrowNativeTypeOp,
948993
{
949-
unary_dyn::<_, T>(array, |value| value - scalar)
994+
try_unary_dyn::<_, T>(array, |value| {
995+
value.sub_checked(scalar).ok_or_else(|| {
996+
ArrowError::CastError(format!(
997+
"Overflow: subtracting {:?} from {:?}",
998+
scalar, value
999+
))
1000+
})
1001+
})
1002+
.map(|a| Arc::new(a) as ArrayRef)
9501003
}
9511004

9521005
/// Perform `-` operation on an array. If value is null then the result is also null.
@@ -1065,18 +1118,40 @@ where
10651118
/// Multiply every value in an array by a scalar. If any value in the array is null then the
10661119
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
10671120
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
1121+
///
1122+
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
1123+
/// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead.
10681124
pub fn multiply_scalar_dyn<T>(array: &dyn Array, scalar: T::Native) -> Result<ArrayRef>
10691125
where
10701126
T: ArrowNumericType,
1071-
T::Native: Add<Output = T::Native>
1072-
+ Sub<Output = T::Native>
1073-
+ Mul<Output = T::Native>
1074-
+ Div<Output = T::Native>
1075-
+ Rem<Output = T::Native>
1076-
+ Zero
1077-
+ One,
1127+
T::Native: ArrowNativeTypeOp,
1128+
{
1129+
unary_dyn::<_, T>(array, |value| value.mul_wrapping(scalar))
1130+
}
1131+
1132+
/// Subtract every value in an array by a scalar. If any value in the array is null then the
1133+
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
1134+
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
1135+
///
1136+
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
1137+
/// use `multiply_scalar_dyn` instead.
1138+
pub fn multiply_scalar_checked_dyn<T>(
1139+
array: &dyn Array,
1140+
scalar: T::Native,
1141+
) -> Result<ArrayRef>
1142+
where
1143+
T: ArrowNumericType,
1144+
T::Native: ArrowNativeTypeOp,
10781145
{
1079-
unary_dyn::<_, T>(array, |value| value * scalar)
1146+
try_unary_dyn::<_, T>(array, |value| {
1147+
value.mul_checked(scalar).ok_or_else(|| {
1148+
ArrowError::CastError(format!(
1149+
"Overflow: multiplying {:?} by {:?}",
1150+
value, scalar
1151+
))
1152+
})
1153+
})
1154+
.map(|a| Arc::new(a) as ArrayRef)
10801155
}
10811156

10821157
/// Perform `left % right` operation on two arrays. If either left or right value is null
@@ -1223,15 +1298,48 @@ where
12231298
/// result is also null. If the scalar is zero then the result of this operation will be
12241299
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
12251300
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
1301+
///
1302+
/// This doesn't detect overflow. Once overflowing, the result will wrap around.
1303+
/// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
12261304
pub fn divide_scalar_dyn<T>(array: &dyn Array, divisor: T::Native) -> Result<ArrayRef>
12271305
where
12281306
T: ArrowNumericType,
1229-
T::Native: Div<Output = T::Native> + Zero,
1307+
T::Native: ArrowNativeTypeOp + Zero,
1308+
{
1309+
if divisor.is_zero() {
1310+
return Err(ArrowError::DivideByZero);
1311+
}
1312+
unary_dyn::<_, T>(array, |value| value.div_wrapping(divisor))
1313+
}
1314+
1315+
/// Divide every value in an array by a scalar. If any value in the array is null then the
1316+
/// result is also null. If the scalar is zero then the result of this operation will be
1317+
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
1318+
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
1319+
///
1320+
/// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
1321+
/// use `divide_scalar_dyn` instead.
1322+
pub fn divide_scalar_checked_dyn<T>(
1323+
array: &dyn Array,
1324+
divisor: T::Native,
1325+
) -> Result<ArrayRef>
1326+
where
1327+
T: ArrowNumericType,
1328+
T::Native: ArrowNativeTypeOp + Zero,
12301329
{
12311330
if divisor.is_zero() {
12321331
return Err(ArrowError::DivideByZero);
12331332
}
1234-
unary_dyn::<_, T>(array, |value| value / divisor)
1333+
1334+
try_unary_dyn::<_, T>(array, |value| {
1335+
value.div_checked(divisor).ok_or_else(|| {
1336+
ArrowError::CastError(format!(
1337+
"Overflow: dividing {:?} by {:?}",
1338+
value, divisor
1339+
))
1340+
})
1341+
})
1342+
.map(|a| Arc::new(a) as ArrayRef)
12351343
}
12361344

12371345
#[cfg(test)]
@@ -2222,6 +2330,55 @@ mod tests {
22222330
overflow.expect_err("overflow should be detected");
22232331
}
22242332

2333+
#[test]
2334+
fn test_primitive_add_scalar_dyn_wrapping_overflow() {
2335+
let a = Int32Array::from(vec![i32::MAX, i32::MIN]);
2336+
2337+
let wrapped = add_scalar_dyn::<Int32Type>(&a, 1).unwrap();
2338+
let expected =
2339+
Arc::new(Int32Array::from(vec![-2147483648, -2147483647])) as ArrayRef;
2340+
assert_eq!(&expected, &wrapped);
2341+
2342+
let overflow = add_scalar_checked_dyn::<Int32Type>(&a, 1);
2343+
overflow.expect_err("overflow should be detected");
2344+
}
2345+
2346+
#[test]
2347+
fn test_primitive_subtract_scalar_dyn_wrapping_overflow() {
2348+
let a = Int32Array::from(vec![-2]);
2349+
2350+
let wrapped = subtract_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
2351+
let expected = Arc::new(Int32Array::from(vec![i32::MAX])) as ArrayRef;
2352+
assert_eq!(&expected, &wrapped);
2353+
2354+
let overflow = subtract_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
2355+
overflow.expect_err("overflow should be detected");
2356+
}
2357+
2358+
#[test]
2359+
fn test_primitive_mul_scalar_dyn_wrapping_overflow() {
2360+
let a = Int32Array::from(vec![10]);
2361+
2362+
let wrapped = multiply_scalar_dyn::<Int32Type>(&a, i32::MAX).unwrap();
2363+
let expected = Arc::new(Int32Array::from(vec![-10])) as ArrayRef;
2364+
assert_eq!(&expected, &wrapped);
2365+
2366+
let overflow = multiply_scalar_checked_dyn::<Int32Type>(&a, i32::MAX);
2367+
overflow.expect_err("overflow should be detected");
2368+
}
2369+
2370+
#[test]
2371+
fn test_primitive_div_scalar_dyn_wrapping_overflow() {
2372+
let a = Int32Array::from(vec![i32::MIN]);
2373+
2374+
let wrapped = divide_scalar_dyn::<Int32Type>(&a, -1).unwrap();
2375+
let expected = Arc::new(Int32Array::from(vec![-2147483648])) as ArrayRef;
2376+
assert_eq!(&expected, &wrapped);
2377+
2378+
let overflow = divide_scalar_checked_dyn::<Int32Type>(&a, -1);
2379+
overflow.expect_err("overflow should be detected");
2380+
}
2381+
22252382
#[test]
22262383
fn test_primitive_div_opt_overflow_division_by_zero() {
22272384
let a = Int32Array::from(vec![i32::MIN]);

arrow/src/compute/kernels/arity.rs

Lines changed: 48 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ where
123123
Ok(unsafe { build_primitive_array(len, buffer.finish(), null_count, null_buffer) })
124124
}
125125

126-
/// A helper function that applies an unary function to a dictionary array with primitive value type.
126+
/// A helper function that applies an infallible unary function to a dictionary array with primitive value type.
127127
fn unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
128128
where
129129
K: ArrowNumericType,
@@ -138,7 +138,22 @@ where
138138
Ok(Arc::new(new_dict))
139139
}
140140

141-
/// Applies an unary function to an array with primitive values.
141+
/// A helper function that applies a fallible unary function to a dictionary array with primitive value type.
142+
fn try_unary_dict<K, F, T>(array: &DictionaryArray<K>, op: F) -> Result<ArrayRef>
143+
where
144+
K: ArrowNumericType,
145+
T: ArrowPrimitiveType,
146+
F: Fn(T::Native) -> Result<T::Native>,
147+
{
148+
let dict_values = array.values().as_any().downcast_ref().unwrap();
149+
let values = try_unary::<T, F, T>(dict_values, op)?.into_data();
150+
let data = array.data().clone().into_builder().child_data(vec![values]);
151+
152+
let new_dict: DictionaryArray<K> = unsafe { data.build_unchecked() }.into();
153+
Ok(Arc::new(new_dict))
154+
}
155+
156+
/// Applies an infallible unary function to an array with primitive values.
142157
pub fn unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
143158
where
144159
T: ArrowPrimitiveType,
@@ -162,6 +177,37 @@ where
162177
}
163178
}
164179

180+
/// Applies a fallible unary function to an array with primitive values.
181+
pub fn try_unary_dyn<F, T>(array: &dyn Array, op: F) -> Result<ArrayRef>
182+
where
183+
T: ArrowPrimitiveType,
184+
F: Fn(T::Native) -> Result<T::Native>,
185+
{
186+
downcast_dictionary_array! {
187+
array => if array.values().data_type() == &T::DATA_TYPE {
188+
try_unary_dict::<_, F, T>(array, op)
189+
} else {
190+
Err(ArrowError::NotYetImplemented(format!(
191+
"Cannot perform unary operation on dictionary array of type {}",
192+
array.data_type()
193+
)))
194+
},
195+
t => {
196+
if t == &T::DATA_TYPE {
197+
Ok(Arc::new(try_unary::<T, F, T>(
198+
array.as_any().downcast_ref::<PrimitiveArray<T>>().unwrap(),
199+
op,
200+
)?))
201+
} else {
202+
Err(ArrowError::NotYetImplemented(format!(
203+
"Cannot perform unary operation on array of type {}",
204+
t
205+
)))
206+
}
207+
}
208+
}
209+
}
210+
165211
/// Given two arrays of length `len`, calls `op(a[i], b[i])` for `i` in `0..len`, collecting
166212
/// the results in a [`PrimitiveArray`]. If any index is null in either `a` or `b`, the
167213
/// corresponding index in the result will also be null

0 commit comments

Comments
 (0)