7
7
from arrayfire_wrapper .lib .mathematical_functions .arithmetic_operations import sub
8
8
9
9
10
+ import arrayfire_wrapper .dtypes as dtype
11
+ import arrayfire_wrapper .lib as wrapper
12
+
13
+
10
14
def abs_ (arr : AFArray , / ) -> AFArray :
11
15
"""
12
16
source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f
@@ -44,12 +48,12 @@ def floor(arr: AFArray, /) -> AFArray:
44
48
return unary_op (floor .__name__ , arr )
45
49
46
50
47
- def hypot (lhs : AFArray , rhs : AFArray , / ) -> AFArray :
51
+ def hypot (lhs : AFArray , rhs : AFArray , batch : bool , / ) -> AFArray :
48
52
"""
49
53
source:
50
54
"""
51
55
out = AFArray .create_null_pointer ()
52
- call_from_clib (hypot .__name__ , lhs , rhs )
56
+ call_from_clib (hypot .__name__ , ctypes . pointer ( out ), lhs , rhs , ctypes . c_bool ( batch ) )
53
57
return out
54
58
55
59
@@ -75,7 +79,7 @@ def mod(lhs: AFArray, rhs: AFArray, /) -> AFArray:
75
79
76
80
77
81
def neg (arr : AFArray ) -> AFArray :
78
- return sub (create_constant_array (0 , (1 ,), float32 ), arr )
82
+ return sub (create_constant_array (0 , (1 ,), dtype . c_api_value_to_dtype ( wrapper . get_type ( arr )) ), arr )
79
83
80
84
81
85
def rem (lhs : AFArray , rhs : AFArray , / ) -> AFArray :
0 commit comments