Skip to content

Trigonometry Testing #35

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
157 changes: 157 additions & 0 deletions tests/test_trig.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
import random

import pytest

import arrayfire_wrapper.dtypes as dtype
import arrayfire_wrapper.lib as wrapper

from . import utility_functions as util


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_asin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse sine operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.asin(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_acos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse cosine operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.acos(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_atan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse tan operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.atan(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_float_types())
def test_atan2_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test inverse tan operation across all supported data types."""
util.check_type_supported(dtype_name)
if dtype_name == dtype.f16:
pytest.skip()
lhs = wrapper.randu(shape, dtype_name)
rhs = wrapper.randu(shape, dtype_name)
result = wrapper.atan2(lhs, rhs)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"invdtypes",
[
dtype.int16,
dtype.bool,
],
)
def test_atan2_unsupported_dtypes(invdtypes: dtype.Dtype) -> None:
"""Test inverse tan operation for unsupported data types."""
with pytest.raises(RuntimeError):
wrapper.atan2(wrapper.randu((10, 10), invdtypes), wrapper.randu((10, 10), invdtypes))


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_cos_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test cosine operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.cos(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_sin_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test sin operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.sin(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa


@pytest.mark.parametrize(
"shape",
[
(),
(random.randint(1, 10),),
(random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
(random.randint(1, 10), random.randint(1, 10), random.randint(1, 10), random.randint(1, 10)),
],
)
@pytest.mark.parametrize("dtype_name", util.get_all_types())
def test_tan_shape_dtypes(shape: tuple, dtype_name: dtype.Dtype) -> None:
"""Test tan operation across all supported data types."""
util.check_type_supported(dtype_name)
values = wrapper.randu(shape, dtype_name)
result = wrapper.tan(values)
assert wrapper.get_dims(result)[0 : len(shape)] == shape, f"failed for shape: {shape}" # noqa
10 changes: 7 additions & 3 deletions tests/utility_functions.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import pytest

import arrayfire_wrapper.lib as wrapper
from arrayfire_wrapper.dtypes import Dtype, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64
from arrayfire_wrapper.dtypes import Dtype, b8, c32, c64, f16, f32, f64, s16, s32, s64, u8, u16, u32, u64


def check_type_supported(dtype: Dtype) -> None:
"""Checks to see if the specified type is supported by the current system"""
if dtype in [f64, c64] and not wrapper.get_dbl_support():
pytest.skip("Device does not support double types")

if dtype == f16 and not wrapper.get_half_support():
pytest.skip("Device does not support half types.")

Expand All @@ -25,4 +24,9 @@ def get_real_types() -> list:

def get_all_types() -> list:
"""Returns all types"""
return [s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]
return [b8, s16, s32, s64, u8, u16, u32, u64, f16, f32, f64, c32, c64]


def get_float_types() -> list:
"""Returns all types"""
return [f16, f32, f64]
Loading