22
22
//! `RUSTFLAGS="-C target-feature=+avx2"` for example. See the documentation
23
23
//! [here](https://doc.rust-lang.org/stable/core/arch/) for more information.
24
24
25
- use std:: ops:: { Add , Div , Mul , Neg , Rem , Sub } ;
25
+ use std:: ops:: { Div , Neg , Rem } ;
26
26
27
27
use num:: { One , Zero } ;
28
28
@@ -32,7 +32,9 @@ use crate::buffer::Buffer;
32
32
use crate :: buffer:: MutableBuffer ;
33
33
use crate :: compute:: kernels:: arity:: unary;
34
34
use crate :: compute:: util:: combine_option_bitmap;
35
- use crate :: compute:: { binary, binary_opt, try_binary, try_unary, unary_dyn} ;
35
+ use crate :: compute:: {
36
+ binary, binary_opt, try_binary, try_unary, try_unary_dyn, unary_dyn,
37
+ } ;
36
38
use crate :: datatypes:: {
37
39
native_op:: ArrowNativeTypeOp , ArrowNumericType , DataType , Date32Type , Date64Type ,
38
40
IntervalDayTimeType , IntervalMonthDayNanoType , IntervalUnit , IntervalYearMonthType ,
@@ -834,12 +836,39 @@ where
834
836
/// Add every value in an array by a scalar. If any value in the array is null then the
835
837
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
836
838
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
839
+ ///
840
+ /// This doesn't detect overflow. Once overflowing, the result will wrap around.
841
+ /// For an overflow-checking variant, use `add_scalar_checked_dyn` instead.
842
+ ///
843
+ /// This returns an `Err` when the input array is not supported for adding operation.
837
844
pub fn add_scalar_dyn < T > ( array : & dyn Array , scalar : T :: Native ) -> Result < ArrayRef >
838
845
where
839
846
T : ArrowNumericType ,
840
- T :: Native : Add < Output = T :: Native > ,
847
+ T :: Native : ArrowNativeTypeOp ,
848
+ {
849
+ unary_dyn :: < _ , T > ( array, |value| value. add_wrapping ( scalar) )
850
+ }
851
+
852
+ /// Add every value in an array by a scalar. If any value in the array is null then the
853
+ /// result is also null. The given array must be a `PrimitiveArray` of the type same as
854
+ /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
855
+ ///
856
+ /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
857
+ /// use `add_scalar_dyn` instead.
858
+ ///
859
+ /// As this kernel has the branching costs and also prevents LLVM from vectorising it correctly,
860
+ /// it is usually much slower than non-checking variant.
861
+ pub fn add_scalar_checked_dyn < T > ( array : & dyn Array , scalar : T :: Native ) -> Result < ArrayRef >
862
+ where
863
+ T : ArrowNumericType ,
864
+ T :: Native : ArrowNativeTypeOp ,
841
865
{
842
- unary_dyn :: < _ , T > ( array, |value| value + scalar)
866
+ try_unary_dyn :: < _ , T > ( array, |value| {
867
+ value. add_checked ( scalar) . ok_or_else ( || {
868
+ ArrowError :: CastError ( format ! ( "Overflow: adding {:?} to {:?}" , scalar, value) )
869
+ } )
870
+ } )
871
+ . map ( |a| Arc :: new ( a) as ArrayRef )
843
872
}
844
873
845
874
/// Perform `left - right` operation on two arrays. If either left or right value is null
@@ -937,16 +966,40 @@ where
937
966
/// Subtract every value in an array by a scalar. If any value in the array is null then the
938
967
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
939
968
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
969
+ ///
970
+ /// This doesn't detect overflow. Once overflowing, the result will wrap around.
971
+ /// For an overflow-checking variant, use `subtract_scalar_checked_dyn` instead.
940
972
pub fn subtract_scalar_dyn < T > ( array : & dyn Array , scalar : T :: Native ) -> Result < ArrayRef >
941
973
where
942
- T : datatypes:: ArrowNumericType ,
943
- T :: Native : Add < Output = T :: Native >
944
- + Sub < Output = T :: Native >
945
- + Mul < Output = T :: Native >
946
- + Div < Output = T :: Native >
947
- + Zero ,
974
+ T : ArrowNumericType ,
975
+ T :: Native : ArrowNativeTypeOp ,
976
+ {
977
+ unary_dyn :: < _ , T > ( array, |value| value. sub_wrapping ( scalar) )
978
+ }
979
+
980
+ /// Subtract every value in an array by a scalar. If any value in the array is null then the
981
+ /// result is also null. The given array must be a `PrimitiveArray` of the type same as
982
+ /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
983
+ ///
984
+ /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
985
+ /// use `subtract_scalar_dyn` instead.
986
+ pub fn subtract_scalar_checked_dyn < T > (
987
+ array : & dyn Array ,
988
+ scalar : T :: Native ,
989
+ ) -> Result < ArrayRef >
990
+ where
991
+ T : ArrowNumericType ,
992
+ T :: Native : ArrowNativeTypeOp ,
948
993
{
949
- unary_dyn :: < _ , T > ( array, |value| value - scalar)
994
+ try_unary_dyn :: < _ , T > ( array, |value| {
995
+ value. sub_checked ( scalar) . ok_or_else ( || {
996
+ ArrowError :: CastError ( format ! (
997
+ "Overflow: subtracting {:?} from {:?}" ,
998
+ scalar, value
999
+ ) )
1000
+ } )
1001
+ } )
1002
+ . map ( |a| Arc :: new ( a) as ArrayRef )
950
1003
}
951
1004
952
1005
/// Perform `-` operation on an array. If value is null then the result is also null.
@@ -1065,18 +1118,40 @@ where
1065
1118
/// Multiply every value in an array by a scalar. If any value in the array is null then the
1066
1119
/// result is also null. The given array must be a `PrimitiveArray` of the type same as
1067
1120
/// the scalar, or a `DictionaryArray` of the value type same as the scalar.
1121
+ ///
1122
+ /// This doesn't detect overflow. Once overflowing, the result will wrap around.
1123
+ /// For an overflow-checking variant, use `multiply_scalar_checked_dyn` instead.
1068
1124
pub fn multiply_scalar_dyn < T > ( array : & dyn Array , scalar : T :: Native ) -> Result < ArrayRef >
1069
1125
where
1070
1126
T : ArrowNumericType ,
1071
- T :: Native : Add < Output = T :: Native >
1072
- + Sub < Output = T :: Native >
1073
- + Mul < Output = T :: Native >
1074
- + Div < Output = T :: Native >
1075
- + Rem < Output = T :: Native >
1076
- + Zero
1077
- + One ,
1127
+ T :: Native : ArrowNativeTypeOp ,
1128
+ {
1129
+ unary_dyn :: < _ , T > ( array, |value| value. mul_wrapping ( scalar) )
1130
+ }
1131
+
1132
+ /// Subtract every value in an array by a scalar. If any value in the array is null then the
1133
+ /// result is also null. The given array must be a `PrimitiveArray` of the type same as
1134
+ /// the scalar, or a `DictionaryArray` of the value type same as the scalar.
1135
+ ///
1136
+ /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
1137
+ /// use `multiply_scalar_dyn` instead.
1138
+ pub fn multiply_scalar_checked_dyn < T > (
1139
+ array : & dyn Array ,
1140
+ scalar : T :: Native ,
1141
+ ) -> Result < ArrayRef >
1142
+ where
1143
+ T : ArrowNumericType ,
1144
+ T :: Native : ArrowNativeTypeOp ,
1078
1145
{
1079
- unary_dyn :: < _ , T > ( array, |value| value * scalar)
1146
+ try_unary_dyn :: < _ , T > ( array, |value| {
1147
+ value. mul_checked ( scalar) . ok_or_else ( || {
1148
+ ArrowError :: CastError ( format ! (
1149
+ "Overflow: multiplying {:?} by {:?}" ,
1150
+ value, scalar
1151
+ ) )
1152
+ } )
1153
+ } )
1154
+ . map ( |a| Arc :: new ( a) as ArrayRef )
1080
1155
}
1081
1156
1082
1157
/// Perform `left % right` operation on two arrays. If either left or right value is null
@@ -1223,15 +1298,48 @@ where
1223
1298
/// result is also null. If the scalar is zero then the result of this operation will be
1224
1299
/// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
1225
1300
/// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
1301
+ ///
1302
+ /// This doesn't detect overflow. Once overflowing, the result will wrap around.
1303
+ /// For an overflow-checking variant, use `divide_scalar_checked_dyn` instead.
1226
1304
pub fn divide_scalar_dyn < T > ( array : & dyn Array , divisor : T :: Native ) -> Result < ArrayRef >
1227
1305
where
1228
1306
T : ArrowNumericType ,
1229
- T :: Native : Div < Output = T :: Native > + Zero ,
1307
+ T :: Native : ArrowNativeTypeOp + Zero ,
1308
+ {
1309
+ if divisor. is_zero ( ) {
1310
+ return Err ( ArrowError :: DivideByZero ) ;
1311
+ }
1312
+ unary_dyn :: < _ , T > ( array, |value| value. div_wrapping ( divisor) )
1313
+ }
1314
+
1315
+ /// Divide every value in an array by a scalar. If any value in the array is null then the
1316
+ /// result is also null. If the scalar is zero then the result of this operation will be
1317
+ /// `Err(ArrowError::DivideByZero)`. The given array must be a `PrimitiveArray` of the type
1318
+ /// same as the scalar, or a `DictionaryArray` of the value type same as the scalar.
1319
+ ///
1320
+ /// This detects overflow and returns an `Err` for that. For an non-overflow-checking variant,
1321
+ /// use `divide_scalar_dyn` instead.
1322
+ pub fn divide_scalar_checked_dyn < T > (
1323
+ array : & dyn Array ,
1324
+ divisor : T :: Native ,
1325
+ ) -> Result < ArrayRef >
1326
+ where
1327
+ T : ArrowNumericType ,
1328
+ T :: Native : ArrowNativeTypeOp + Zero ,
1230
1329
{
1231
1330
if divisor. is_zero ( ) {
1232
1331
return Err ( ArrowError :: DivideByZero ) ;
1233
1332
}
1234
- unary_dyn :: < _ , T > ( array, |value| value / divisor)
1333
+
1334
+ try_unary_dyn :: < _ , T > ( array, |value| {
1335
+ value. div_checked ( divisor) . ok_or_else ( || {
1336
+ ArrowError :: CastError ( format ! (
1337
+ "Overflow: dividing {:?} by {:?}" ,
1338
+ value, divisor
1339
+ ) )
1340
+ } )
1341
+ } )
1342
+ . map ( |a| Arc :: new ( a) as ArrayRef )
1235
1343
}
1236
1344
1237
1345
#[ cfg( test) ]
@@ -2222,6 +2330,55 @@ mod tests {
2222
2330
overflow. expect_err ( "overflow should be detected" ) ;
2223
2331
}
2224
2332
2333
+ #[ test]
2334
+ fn test_primitive_add_scalar_dyn_wrapping_overflow ( ) {
2335
+ let a = Int32Array :: from ( vec ! [ i32 :: MAX , i32 :: MIN ] ) ;
2336
+
2337
+ let wrapped = add_scalar_dyn :: < Int32Type > ( & a, 1 ) . unwrap ( ) ;
2338
+ let expected =
2339
+ Arc :: new ( Int32Array :: from ( vec ! [ -2147483648 , -2147483647 ] ) ) as ArrayRef ;
2340
+ assert_eq ! ( & expected, & wrapped) ;
2341
+
2342
+ let overflow = add_scalar_checked_dyn :: < Int32Type > ( & a, 1 ) ;
2343
+ overflow. expect_err ( "overflow should be detected" ) ;
2344
+ }
2345
+
2346
+ #[ test]
2347
+ fn test_primitive_subtract_scalar_dyn_wrapping_overflow ( ) {
2348
+ let a = Int32Array :: from ( vec ! [ -2 ] ) ;
2349
+
2350
+ let wrapped = subtract_scalar_dyn :: < Int32Type > ( & a, i32:: MAX ) . unwrap ( ) ;
2351
+ let expected = Arc :: new ( Int32Array :: from ( vec ! [ i32 :: MAX ] ) ) as ArrayRef ;
2352
+ assert_eq ! ( & expected, & wrapped) ;
2353
+
2354
+ let overflow = subtract_scalar_checked_dyn :: < Int32Type > ( & a, i32:: MAX ) ;
2355
+ overflow. expect_err ( "overflow should be detected" ) ;
2356
+ }
2357
+
2358
+ #[ test]
2359
+ fn test_primitive_mul_scalar_dyn_wrapping_overflow ( ) {
2360
+ let a = Int32Array :: from ( vec ! [ 10 ] ) ;
2361
+
2362
+ let wrapped = multiply_scalar_dyn :: < Int32Type > ( & a, i32:: MAX ) . unwrap ( ) ;
2363
+ let expected = Arc :: new ( Int32Array :: from ( vec ! [ -10 ] ) ) as ArrayRef ;
2364
+ assert_eq ! ( & expected, & wrapped) ;
2365
+
2366
+ let overflow = multiply_scalar_checked_dyn :: < Int32Type > ( & a, i32:: MAX ) ;
2367
+ overflow. expect_err ( "overflow should be detected" ) ;
2368
+ }
2369
+
2370
+ #[ test]
2371
+ fn test_primitive_div_scalar_dyn_wrapping_overflow ( ) {
2372
+ let a = Int32Array :: from ( vec ! [ i32 :: MIN ] ) ;
2373
+
2374
+ let wrapped = divide_scalar_dyn :: < Int32Type > ( & a, -1 ) . unwrap ( ) ;
2375
+ let expected = Arc :: new ( Int32Array :: from ( vec ! [ -2147483648 ] ) ) as ArrayRef ;
2376
+ assert_eq ! ( & expected, & wrapped) ;
2377
+
2378
+ let overflow = divide_scalar_checked_dyn :: < Int32Type > ( & a, -1 ) ;
2379
+ overflow. expect_err ( "overflow should be detected" ) ;
2380
+ }
2381
+
2225
2382
#[ test]
2226
2383
fn test_primitive_div_opt_overflow_division_by_zero ( ) {
2227
2384
let a = Int32Array :: from ( vec ! [ i32 :: MIN ] ) ;
0 commit comments