-
Notifications
You must be signed in to change notification settings - Fork 11
ENH: pad
: add delegation
#72
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 15 commits
d17fd2f
38690bb
d4d05b0
ea09206
9dcb9e5
5db1a93
486ebef
44ec95a
303c7fc
6f72daf
a54357b
71edc05
fd6b9d8
1e59fbd
e0046c5
2bd8205
1531841
64db422
ed7fd25
93f2591
47e9b5b
4c786a2
9daa33a
3ac8e45
82e5258
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,6 +13,7 @@ | |
expand_dims | ||
kron | ||
nunique | ||
pad | ||
setdiff1d | ||
sinc | ||
``` |
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,82 @@ | ||
"""Delegation to existing implementations for Public API Functions.""" | ||
|
||
from types import ModuleType | ||
|
||
from ._lib import Library, _funcs | ||
from ._lib._utils._compat import array_namespace | ||
from ._lib._utils._typing import Array | ||
|
||
__all__ = ["pad"] | ||
|
||
|
||
def _delegate(xp: ModuleType, *backends: Library) -> bool: | ||
""" | ||
Check whether `xp` is one of the `backends` to delegate to. | ||
|
||
Parameters | ||
---------- | ||
xp : array_namespace | ||
Array namespace to check. | ||
*backends : IsNamespace | ||
Arbitrarily many backends (from the ``IsNamespace`` enum) to check. | ||
|
||
Returns | ||
------- | ||
bool | ||
``True`` if `xp` matches one of the `backends`, ``False`` otherwise. | ||
""" | ||
return any(backend.is_namespace(xp) for backend in backends) | ||
|
||
|
||
def pad( | ||
x: Array, | ||
pad_width: int | tuple[int, int] | list[tuple[int, int]], | ||
mode: str = "constant", | ||
*, | ||
constant_values: bool | int | float | complex = 0, | ||
xp: ModuleType | None = None, | ||
) -> Array: | ||
""" | ||
Pad the input array. | ||
|
||
Parameters | ||
---------- | ||
x : array | ||
Input array. | ||
pad_width : int or tuple of ints or list of pairs of ints | ||
Pad the input array with this many elements from each side. | ||
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, | ||
each pair applies to the corresponding axis of ``x``. | ||
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim`` | ||
copies of this tuple. | ||
mode : str, optional | ||
Only "constant" mode is currently supported, which pads with | ||
the value passed to `constant_values`. | ||
constant_values : python scalar, optional | ||
Use this value to pad the input. Default is zero. | ||
xp : array_namespace, optional | ||
The standard-compatible namespace for `x`. Default: infer. | ||
|
||
Returns | ||
------- | ||
array | ||
The input array, | ||
padded with ``pad_width`` elements equal to ``constant_values``. | ||
""" | ||
xp = array_namespace(x) if xp is None else xp | ||
|
||
if mode != "constant": | ||
msg = "Only `'constant'` mode is currently supported" | ||
raise NotImplementedError(msg) | ||
|
||
# https://github.com/pytorch/pytorch/blob/cf76c05b4dc629ac989d1fb8e789d4fac04a095a/torch/_numpy/_funcs_impl.py#L2045-L2056 | ||
if _delegate(xp, Library.TORCH): | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
pad_width = xp.asarray(pad_width) | ||
pad_width = xp.broadcast_to(pad_width, (x.ndim, 2)) | ||
pad_width = xp.flip(pad_width, axis=(0,)).flatten() | ||
return xp.nn.functional.pad(x, tuple(pad_width), value=constant_values) # type: ignore[arg-type] # pyright: ignore[reportArgumentType] | ||
|
||
if _delegate(xp, Library.NUMPY, Library.JAX_NUMPY, Library.CUPY): | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
return xp.pad(x, pad_width, mode, constant_values=constant_values) | ||
|
||
return _funcs.pad(x, pad_width, constant_values=constant_values, xp=xp) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,5 @@ | ||
"""Modules housing private functions.""" | ||
"""Internals of array-api-extra.""" | ||
|
||
from ._libraries import Library | ||
|
||
__all__ = ["Library"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,4 @@ | ||
"""Public API Functions.""" | ||
"""Array-agnostic implementations for the public API.""" | ||
|
||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Except.... it isn't agnostic, see for example the special paths in There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes, I would like to split the file structure so that functions which make use of special paths are separate from array-agnostic implementations. I'll save that for a follow-up. |
||
# https://github.com/scikit-learn/scikit-learn/pull/27910#issuecomment-2568023972 | ||
from __future__ import annotations | ||
|
@@ -11,13 +11,9 @@ | |
from types import ModuleType | ||
from typing import ClassVar, cast | ||
|
||
from ._lib import _compat, _utils | ||
from ._lib._compat import ( | ||
array_namespace, | ||
is_jax_array, | ||
is_writeable_array, | ||
) | ||
from ._lib._typing import Array, Index | ||
from ._utils import _compat, _helpers | ||
from ._utils._compat import array_namespace, is_jax_array, is_writeable_array | ||
from ._utils._typing import Array, Index | ||
|
||
__all__ = [ | ||
"at", | ||
|
@@ -151,7 +147,7 @@ def cov(m: Array, /, *, xp: ModuleType | None = None) -> Array: | |
m = atleast_nd(m, ndim=2, xp=xp) | ||
m = xp.astype(m, dtype) | ||
|
||
avg = _utils.mean(m, axis=1, xp=xp) | ||
avg = _helpers.mean(m, axis=1, xp=xp) | ||
fact = m.shape[1] - 1 | ||
|
||
if fact <= 0: | ||
|
@@ -467,7 +463,7 @@ def setdiff1d( | |
else: | ||
x1 = xp.unique_values(x1) | ||
x2 = xp.unique_values(x2) | ||
return x1[_utils.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] | ||
return x1[_helpers.in1d(x1, x2, assume_unique=True, invert=True, xp=xp)] | ||
|
||
|
||
def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: | ||
|
@@ -562,54 +558,18 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array: | |
def pad( | ||
x: Array, | ||
pad_width: int | tuple[int, int] | list[tuple[int, int]], | ||
mode: str = "constant", | ||
*, | ||
xp: ModuleType | None = None, | ||
constant_values: bool | int | float | complex = 0, | ||
) -> Array: | ||
""" | ||
Pad the input array. | ||
|
||
Parameters | ||
---------- | ||
x : array | ||
Input array. | ||
pad_width : int or tuple of ints or list of pairs of ints | ||
Pad the input array with this many elements from each side. | ||
If a list of tuples, ``[(before_0, after_0), ... (before_N, after_N)]``, | ||
each pair applies to the corresponding axis of ``x``. | ||
A single tuple, ``(before, after)``, is equivalent to a list of ``x.ndim`` | ||
copies of this tuple. | ||
mode : str, optional | ||
Only "constant" mode is currently supported, which pads with | ||
the value passed to `constant_values`. | ||
xp : array_namespace, optional | ||
The standard-compatible namespace for `x`. Default: infer. | ||
constant_values : python scalar, optional | ||
Use this value to pad the input. Default is zero. | ||
|
||
Returns | ||
------- | ||
array | ||
The input array, | ||
padded with ``pad_width`` elements equal to ``constant_values``. | ||
""" | ||
if mode != "constant": | ||
msg = "Only `'constant'` mode is currently supported" | ||
raise NotImplementedError(msg) | ||
|
||
value = constant_values | ||
|
||
xp: ModuleType, | ||
) -> Array: # numpydoc ignore=PR01,RT01 | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
"""See docstring in `array_api_extra._delegation.py`.""" | ||
# make pad_width a list of length-2 tuples of ints | ||
x_ndim = cast(int, x.ndim) | ||
if isinstance(pad_width, int): | ||
pad_width = [(pad_width, pad_width)] * x_ndim | ||
if isinstance(pad_width, tuple): | ||
pad_width = [pad_width] * x_ndim | ||
|
||
if xp is None: | ||
xp = array_namespace(x) | ||
|
||
# https://github.com/python/typeshed/issues/13376 | ||
slices: list[slice] = [] # type: ignore[no-any-explicit] | ||
newshape: list[int] = [] | ||
|
@@ -633,7 +593,7 @@ def pad( | |
|
||
padded = xp.full( | ||
tuple(newshape), | ||
fill_value=value, | ||
fill_value=constant_values, | ||
dtype=x.dtype, | ||
device=_compat.device(x), | ||
) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,69 @@ | ||
"""Code specifying libraries array-api-extra interacts with.""" | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
||
from collections.abc import Callable | ||
from enum import Enum | ||
from types import ModuleType | ||
from typing import cast | ||
|
||
from ._utils import _compat | ||
|
||
__all__ = ["Library"] | ||
|
||
|
||
class Library(Enum): # numpydoc ignore=PR01,PR02 | ||
""" | ||
All array library backends explicitly tested by array-api-extra. | ||
|
||
Parameters | ||
---------- | ||
value : str | ||
String describing the backend. | ||
library_name : str | ||
Name of the array library of the backend. | ||
module_name : str | ||
Name of the backend's module. | ||
""" | ||
|
||
ARRAY_API_STRICT = "array_api_strict", "array_api_strict", "array_api_strict" | ||
lucascolley marked this conversation as resolved.
Show resolved
Hide resolved
|
||
NUMPY = "numpy", "numpy", "numpy" | ||
NUMPY_READONLY = "numpy_readonly", "numpy", "numpy" | ||
CUPY = "cupy", "cupy", "cupy" | ||
TORCH = "torch", "torch", "torch" | ||
DASK_ARRAY = "dask.array", "dask", "dask.array" | ||
SPARSE = "sparse", "pydata_sparse", "sparse" | ||
JAX_NUMPY = "jax.numpy", "jax", "jax.numpy" | ||
|
||
def __new__( | ||
cls, value: str, _library_name: str, _module_name: str | ||
): # numpydoc ignore=GL08 | ||
obj = object.__new__(cls) | ||
obj._value_ = value | ||
return obj | ||
|
||
def __init__( | ||
self, _value: str, library_name: str, module_name: str | ||
): # numpydoc ignore=GL08 | ||
self.library_name = library_name | ||
self.module_name = module_name | ||
|
||
def __str__(self) -> str: # type: ignore[explicit-override] # pyright: ignore[reportImplicitOverride] # numpydoc ignore=RT01 | ||
"""Pretty-print parameterized test names.""" | ||
return cast(str, self.value) | ||
|
||
def is_namespace(self, xp: ModuleType) -> bool: | ||
""" | ||
Call the corresponding is_namespace function. | ||
|
||
Parameters | ||
---------- | ||
xp : array_namespace | ||
Array namespace to check. | ||
|
||
Returns | ||
------- | ||
bool | ||
``True`` if xp matches the namespace, ``False`` otherwise. | ||
""" | ||
is_namespace_func = getattr(_compat, f"is_{self.library_name}_namespace") | ||
is_namespace_func = cast(Callable[[ModuleType], bool], is_namespace_func) | ||
return is_namespace_func(xp) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this whole class feels quite over-engineered. def is_namespace(xp: ModuleType, library: Library) -> bool: ...
def import(library: Library) -> ModuleType: ... We should experiment with making a fake namespace for numpy_readonly too (asarray returns a read-only array, everything else is redirected to array_api_compat.numpy) There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. nice suggestion with
sounds cool for a follow-up |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
"""Modules housing private utility functions.""" |
Uh oh!
There was an error while loading. Please reload this page.