|
12 | 12 | from typing import NamedTuple
|
13 | 13 | import inspect
|
14 | 14 |
|
15 |
| -from ._helpers import array_namespace, _check_device, device, is_cupy_namespace |
| 15 | +from ._helpers import ( |
| 16 | + array_namespace, |
| 17 | + _check_device, |
| 18 | + device as _get_device, |
| 19 | + is_cupy_namespace as _is_cupy_namespace |
| 20 | +) |
16 | 21 |
|
17 | 22 | # These functions are modified from the NumPy versions.
|
18 | 23 |
|
@@ -287,7 +292,7 @@ def cumulative_sum(
|
287 | 292 | initial_shape = list(x.shape)
|
288 | 293 | initial_shape[axis] = 1
|
289 | 294 | res = xp.concatenate(
|
290 |
| - [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=device(res)), res], |
| 295 | + [wrapped_xp.zeros(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], |
291 | 296 | axis=axis,
|
292 | 297 | )
|
293 | 298 | return res
|
@@ -317,7 +322,7 @@ def cumulative_prod(
|
317 | 322 | initial_shape = list(x.shape)
|
318 | 323 | initial_shape[axis] = 1
|
319 | 324 | res = xp.concatenate(
|
320 |
| - [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=device(res)), res], |
| 325 | + [wrapped_xp.ones(shape=initial_shape, dtype=res.dtype, device=_get_device(res)), res], |
321 | 326 | axis=axis,
|
322 | 327 | )
|
323 | 328 | return res
|
@@ -369,7 +374,7 @@ def _isscalar(a):
|
369 | 374 | if type(max) is int and max >= wrapped_xp.iinfo(x.dtype).max:
|
370 | 375 | max = None
|
371 | 376 |
|
372 |
| - dev = device(x) |
| 377 | + dev = _get_device(x) |
373 | 378 | if out is None:
|
374 | 379 | out = wrapped_xp.empty(result_shape, dtype=x.dtype, device=dev)
|
375 | 380 | out[()] = x
|
@@ -579,3 +584,5 @@ def sign(x: ndarray, /, xp, **kwargs) -> ndarray:
|
579 | 584 | 'reshape', 'argsort', 'sort', 'nonzero', 'ceil', 'floor', 'trunc',
|
580 | 585 | 'matmul', 'matrix_transpose', 'tensordot', 'vecdot', 'isdtype',
|
581 | 586 | 'unstack', 'sign']
|
| 587 | + |
| 588 | +_all_ignore = ['inspect', 'array_namespace', 'NamedTuple'] |
0 commit comments