Skip to content

Commit b849435

Browse files
AzeezIshAzeezIsh
and
AzeezIsh
authored
Arithmetic Operation Testing (#42)
* testing on arithmetic functions --------- Co-authored-by: AzeezIsh <[email protected]>
1 parent c414f5f commit b849435

File tree

1 file changed

+152
-0
lines changed

1 file changed

+152
-0
lines changed

tests/test_arithmetic.py

+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
import random
2+
3+
import pytest
4+
5+
import arrayfire_wrapper.dtypes as dtype
6+
import arrayfire_wrapper.lib as wrapper
7+
from tests.utility_functions import check_type_supported, get_all_types
8+
9+
10+
@pytest.mark.parametrize(
11+
"shape",
12+
[
13+
(),
14+
(random.randint(1, 10),),
15+
(random.randint(1, 10), random.randint(1, 10)),
16+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
17+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
18+
],
19+
)
20+
def test_add_shapes(shape: tuple) -> None:
21+
"""Test addition operation between two arrays of the same shape"""
22+
lhs = wrapper.randu(shape, dtype.f16)
23+
rhs = wrapper.randu(shape, dtype.f16)
24+
25+
result = wrapper.add(lhs, rhs)
26+
27+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203, W291
28+
29+
30+
def test_add_different_shapes() -> None:
31+
"""Test if addition handles arrays of different shapes"""
32+
with pytest.raises(RuntimeError):
33+
lhs_shape = (2, 3)
34+
rhs_shape = (3, 2)
35+
dtypes = dtype.f16
36+
lhs = wrapper.randu(lhs_shape, dtypes)
37+
rhs = wrapper.randu(rhs_shape, dtypes)
38+
39+
wrapper.add(lhs, rhs)
40+
41+
42+
@pytest.mark.parametrize("dtype_name", get_all_types())
43+
def test_add_supported_dtypes(dtype_name: dtype.Dtype) -> None:
44+
"""Test addition operation across all supported data types."""
45+
check_type_supported(dtype_name)
46+
shape = (5, 5) # Using a common shape for simplicity
47+
lhs = wrapper.randu(shape, dtype_name)
48+
rhs = wrapper.randu(shape, dtype_name)
49+
result = wrapper.add(lhs, rhs)
50+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
51+
52+
53+
@pytest.mark.parametrize(
54+
"invdtypes",
55+
[
56+
dtype.c64,
57+
dtype.f64,
58+
],
59+
)
60+
def test_add_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
61+
"""Test addition operation across all supported data types."""
62+
with pytest.raises(RuntimeError):
63+
shape = (5, 5)
64+
lhs = wrapper.randu(shape, invdtypes)
65+
rhs = wrapper.randu(shape, invdtypes)
66+
result = wrapper.add(lhs, rhs)
67+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == invdtypes, f"Didn't Fail for Dtype: {invdtypes}"
68+
69+
70+
def test_add_zero_sized_arrays() -> None:
71+
"""Test addition with arrays where at least one array has zero size."""
72+
with pytest.raises(RuntimeError):
73+
zero_shape = (0, 5)
74+
normal_shape = (5, 5)
75+
zero_array = wrapper.randu(zero_shape, dtype.f32)
76+
normal_array = wrapper.randu(normal_shape, dtype.f32)
77+
78+
# Test addition when lhs is zero-sized
79+
result_lhs_zero = wrapper.add(zero_array, normal_array)
80+
assert wrapper.get_dims(result_lhs_zero) == zero_shape
81+
82+
83+
@pytest.mark.parametrize(
84+
"shape",
85+
[
86+
(),
87+
(random.randint(1, 10),),
88+
(random.randint(1, 10), random.randint(1, 10)),
89+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
90+
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
91+
],
92+
)
93+
def test_subtract_shapes(shape: tuple) -> None:
94+
"""Test subtraction operation between two arrays of the same shape"""
95+
lhs = wrapper.randu(shape, dtype.f16)
96+
rhs = wrapper.randu(shape, dtype.f16)
97+
98+
result = wrapper.sub(lhs, rhs)
99+
100+
assert wrapper.get_dims(result)[0 : len(shape)] == shape # noqa: E203, W291
101+
102+
103+
def test_subtract_different_shapes() -> None:
104+
"""Test if subtraction handles arrays of different shapes"""
105+
with pytest.raises(RuntimeError):
106+
lhs_shape = (2, 3)
107+
rhs_shape = (3, 2)
108+
dtypes = dtype.f16
109+
lhs = wrapper.randu(lhs_shape, dtypes)
110+
rhs = wrapper.randu(rhs_shape, dtypes)
111+
112+
wrapper.sub(lhs, rhs)
113+
114+
115+
@pytest.mark.parametrize("dtype_name", get_all_types())
116+
def test_subtract_supported_dtypes(dtype_name: dtype.Dtype) -> None:
117+
"""Test subtraction operation across all supported data types."""
118+
check_type_supported(dtype_name)
119+
shape = (5, 5)
120+
lhs = wrapper.randu(shape, dtype_name)
121+
rhs = wrapper.randu(shape, dtype_name)
122+
result = wrapper.sub(lhs, rhs)
123+
assert dtype.c_api_value_to_dtype(wrapper.get_type(result)) == dtype_name, f"Failed for dtype: {dtype_name}"
124+
125+
126+
@pytest.mark.parametrize(
127+
"invdtypes",
128+
[
129+
dtype.c64,
130+
dtype.f64,
131+
],
132+
)
133+
def test_subtract_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
134+
"""Test subtraction operation for unsupported data types."""
135+
with pytest.raises(RuntimeError):
136+
shape = (5, 5)
137+
lhs = wrapper.randu(shape, invdtypes)
138+
rhs = wrapper.randu(shape, invdtypes)
139+
result = wrapper.sub(lhs, rhs)
140+
assert result == invdtypes, f"Didn't Fail for Dtype: {invdtypes}"
141+
142+
143+
def test_subtract_zero_sized_arrays() -> None:
144+
"""Test subtraction with arrays where at least one array has zero size."""
145+
with pytest.raises(RuntimeError):
146+
zero_shape = (0, 5)
147+
normal_shape = (5, 5)
148+
zero_array = wrapper.randu(zero_shape, dtype.f32)
149+
normal_array = wrapper.randu(normal_shape, dtype.f32)
150+
151+
result_lhs_zero = wrapper.sub(zero_array, normal_array)
152+
assert wrapper.get_dims(result_lhs_zero) == zero_shape

0 commit comments

Comments
 (0)