We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
scalar_objects()
1 parent b1dcf77 commit a4dd075Copy full SHA for a4dd075
array_api_tests/test_array_object.py
@@ -1,6 +1,6 @@
1
import math
2
from itertools import product
3
-from typing import List, get_args
+from typing import List, Union, get_args
4
5
import pytest
6
from hypothesis import assume, given, note
@@ -18,7 +18,9 @@
18
pytestmark = pytest.mark.ci
19
20
21
-def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
+def scalar_objects(
22
+ dtype: DataType, shape: Shape
23
+) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
24
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
25
size = math.prod(shape)
26
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(
0 commit comments