9
9
10
10
import array_api_compat
11
11
from array_api_compat import array_namespace
12
+ import array_api_compat .numpy
12
13
13
14
from ._helpers import import_ , all_libraries , wrapped_libraries
14
15
@@ -22,6 +23,7 @@ def test_array_namespace(library, api_version, use_compat):
22
23
if use_compat is True and library in {'array_api_strict' , 'jax.numpy' , 'sparse' }:
23
24
pytest .raises (ValueError , lambda : array_namespace (array , use_compat = use_compat ))
24
25
return
26
+ print (use_compat )
25
27
namespace = array_api_compat .array_namespace (array , api_version = api_version , use_compat = use_compat )
26
28
27
29
if use_compat is False or use_compat is None and library not in wrapped_libraries :
@@ -36,6 +38,17 @@ def test_array_namespace(library, api_version, use_compat):
36
38
assert namespace == jax .experimental .array_api
37
39
else :
38
40
assert namespace == xp
41
+ elif use_compat is None :
42
+ if library == "dask.array" :
43
+ # dask should always return wrapped version
44
+ # since dask.array is not array API compatible
45
+ assert namespace == array_api_compat .dask .array
46
+ elif library == "numpy" :
47
+ assert namespace == array_api_compat .numpy
48
+ elif library == "torch" :
49
+ assert namespace == array_api_compat .torch
50
+ else :
51
+ assert namespace == xp
39
52
else :
40
53
if library == "dask.array" :
41
54
assert namespace == array_api_compat .dask .array
0 commit comments