Skip to content

Commit afb263e

Browse files
committed
2 parents a3d0511 + 28e1c1b commit afb263e

File tree

5 files changed

+1214
-4
lines changed

5 files changed

+1214
-4
lines changed

Diff for: arrayfire_wrapper/lib/mathematical_functions/numeric_functions.py

+7-3
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77
from arrayfire_wrapper.lib.mathematical_functions.arithmetic_operations import sub
88

99

10+
import arrayfire_wrapper.dtypes as dtype
11+
import arrayfire_wrapper.lib as wrapper
12+
13+
1014
def abs_(arr: AFArray, /) -> AFArray:
1115
"""
1216
source: https://arrayfire.org/docs/group__arith__func__abs.htm#ga7e8b3c848e6cda3d1f3b0c8b2b4c3f8f
@@ -44,12 +48,12 @@ def floor(arr: AFArray, /) -> AFArray:
4448
return unary_op(floor.__name__, arr)
4549

4650

47-
def hypot(lhs: AFArray, rhs: AFArray, /) -> AFArray:
51+
def hypot(lhs: AFArray, rhs: AFArray, batch: bool, /) -> AFArray:
4852
"""
4953
source:
5054
"""
5155
out = AFArray.create_null_pointer()
52-
call_from_clib(hypot.__name__, lhs, rhs)
56+
call_from_clib(hypot.__name__, ctypes.pointer(out), lhs, rhs, ctypes.c_bool(batch))
5357
return out
5458

5559

@@ -75,7 +79,7 @@ def mod(lhs: AFArray, rhs: AFArray, /) -> AFArray:
7579

7680

7781
def neg(arr: AFArray) -> AFArray:
78-
return sub(create_constant_array(0, (1,), float32), arr)
82+
return sub(create_constant_array(0, (1,), dtype.c_api_value_to_dtype(wrapper.get_type(arr))), arr)
7983

8084

8185
def rem(lhs: AFArray, rhs: AFArray, /) -> AFArray:

Diff for: arrayfire_wrapper/lib/vector_algorithms/inclusive_scan_operations.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def scan_by_key(key: AFArray, arr: AFArray, dim: int, op: BinaryOperator, inclus
2828
source: https://arrayfire.org/docs/group__scan__func__scanbykey.htm#gaaae150e0f197782782f45340d137b027
2929
"""
3030
out = AFArray.create_null_pointer()
31-
call_from_clib(scan.__name__, ctypes.pointer(out), key, arr, dim, op.value, inclusive_scan)
31+
call_from_clib(scan_by_key.__name__, ctypes.pointer(out), key, arr, dim, op.value, inclusive_scan)
3232
return out
3333

3434

0 commit comments

Comments
 (0)