|
6 | 6 |
|
7 | 7 | from typing import TYPE_CHECKING
|
8 | 8 | if TYPE_CHECKING:
|
9 |
| - import numpy as np |
10 | 9 | from typing import Optional, Sequence, Tuple, Union
|
11 |
| - from ._typing import ndarray, Device, Dtype, NestedSequence, SupportsBufferProtocol |
| 10 | + from ._typing import ndarray, Device, Dtype |
12 | 11 |
|
13 | 12 | from typing import NamedTuple
|
14 |
| -from types import ModuleType |
15 | 13 | import inspect
|
16 | 14 |
|
17 |
| -from ._helpers import _check_device, is_numpy_array, array_namespace |
| 15 | +from ._helpers import _check_device |
18 | 16 |
|
19 | 17 | # These functions are modified from the NumPy versions.
|
20 | 18 |
|
| 19 | +# Creation functions add the device keyword (which does nothing for NumPy) |
| 20 | + |
21 | 21 | def arange(
|
22 | 22 | start: Union[int, float],
|
23 | 23 | /,
|
@@ -268,90 +268,6 @@ def var(
|
268 | 268 | def permute_dims(x: ndarray, /, axes: Tuple[int, ...], xp) -> ndarray:
|
269 | 269 | return xp.transpose(x, axes)
|
270 | 270 |
|
271 |
| -# Creation functions add the device keyword (which does nothing for NumPy) |
272 |
| - |
273 |
| -# asarray also adds the copy keyword |
274 |
| -def _asarray( |
275 |
| - obj: Union[ |
276 |
| - ndarray, |
277 |
| - bool, |
278 |
| - int, |
279 |
| - float, |
280 |
| - NestedSequence[bool | int | float], |
281 |
| - SupportsBufferProtocol, |
282 |
| - ], |
283 |
| - /, |
284 |
| - *, |
285 |
| - dtype: Optional[Dtype] = None, |
286 |
| - device: Optional[Device] = None, |
287 |
| - copy: "Optional[Union[bool, np._CopyMode]]" = None, |
288 |
| - namespace = None, |
289 |
| - **kwargs, |
290 |
| -) -> ndarray: |
291 |
| - """ |
292 |
| - Array API compatibility wrapper for asarray(). |
293 |
| -
|
294 |
| - See the corresponding documentation in NumPy/CuPy and/or the array API |
295 |
| - specification for more details. |
296 |
| -
|
297 |
| - """ |
298 |
| - if namespace is None: |
299 |
| - try: |
300 |
| - xp = array_namespace(obj, _use_compat=False) |
301 |
| - except ValueError: |
302 |
| - # TODO: What about lists of arrays? |
303 |
| - raise ValueError("A namespace must be specified for asarray() with non-array input") |
304 |
| - elif isinstance(namespace, ModuleType): |
305 |
| - xp = namespace |
306 |
| - elif namespace == 'numpy': |
307 |
| - import numpy as xp |
308 |
| - elif namespace == 'cupy': |
309 |
| - import cupy as xp |
310 |
| - elif namespace == 'dask.array': |
311 |
| - import dask.array as xp |
312 |
| - else: |
313 |
| - raise ValueError("Unrecognized namespace argument to asarray()") |
314 |
| - |
315 |
| - _check_device(xp, device) |
316 |
| - if is_numpy_array(obj): |
317 |
| - import numpy as np |
318 |
| - if hasattr(np, '_CopyMode'): |
319 |
| - # Not present in older NumPys |
320 |
| - COPY_FALSE = (False, np._CopyMode.IF_NEEDED) |
321 |
| - COPY_TRUE = (True, np._CopyMode.ALWAYS) |
322 |
| - else: |
323 |
| - COPY_FALSE = (False,) |
324 |
| - COPY_TRUE = (True,) |
325 |
| - else: |
326 |
| - COPY_FALSE = (False,) |
327 |
| - COPY_TRUE = (True,) |
328 |
| - if copy in COPY_FALSE and namespace != "dask.array": |
329 |
| - # copy=False is not yet implemented in xp.asarray |
330 |
| - raise NotImplementedError("copy=False is not yet implemented") |
331 |
| - if (hasattr(xp, "ndarray") and isinstance(obj, xp.ndarray)): |
332 |
| - if dtype is not None and obj.dtype != dtype: |
333 |
| - copy = True |
334 |
| - if copy in COPY_TRUE: |
335 |
| - return xp.array(obj, copy=True, dtype=dtype) |
336 |
| - return obj |
337 |
| - elif namespace == "dask.array": |
338 |
| - if copy in COPY_TRUE: |
339 |
| - if dtype is None: |
340 |
| - return obj.copy() |
341 |
| - # Go through numpy, since dask copy is no-op by default |
342 |
| - import numpy as np |
343 |
| - obj = np.array(obj, dtype=dtype, copy=True) |
344 |
| - return xp.array(obj, dtype=dtype) |
345 |
| - else: |
346 |
| - import dask.array as da |
347 |
| - import numpy as np |
348 |
| - if not isinstance(obj, da.Array): |
349 |
| - obj = np.asarray(obj, dtype=dtype) |
350 |
| - return da.from_array(obj) |
351 |
| - return obj |
352 |
| - |
353 |
| - return xp.asarray(obj, dtype=dtype, **kwargs) |
354 |
| - |
355 | 271 | # np.reshape calls the keyword argument 'newshape' instead of 'shape'
|
356 | 272 | def reshape(x: ndarray,
|
357 | 273 | /,
|
|
0 commit comments