Skip to content

Commit a4dd075

Browse files
committed
Fix scalar_objects() typing
1 parent b1dcf77 commit a4dd075

File tree

1 file changed

+4
-2
lines changed

1 file changed

+4
-2
lines changed

array_api_tests/test_array_object.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import math
22
from itertools import product
3-
from typing import List, get_args
3+
from typing import List, Union, get_args
44

55
import pytest
66
from hypothesis import assume, given, note
@@ -18,7 +18,9 @@
1818
pytestmark = pytest.mark.ci
1919

2020

21-
def scalar_objects(dtype: DataType, shape: Shape) -> st.SearchStrategy[List[Scalar]]:
21+
def scalar_objects(
22+
dtype: DataType, shape: Shape
23+
) -> st.SearchStrategy[Union[Scalar, List[Scalar]]]:
2224
"""Generates scalars or nested sequences which are valid for xp.asarray()"""
2325
size = math.prod(shape)
2426
return st.lists(xps.from_dtype(dtype), min_size=size, max_size=size).map(

0 commit comments

Comments
 (0)