Skip to content

Commit 169f21d

Browse files
ev-brlucascolley
andauthored
ENH: add pad (#71)
* ENH: add pad * remove delegation for now * tweaks * add xp, device tests --------- Co-authored-by: Lucas Colley <[email protected]>
1 parent 6df1916 commit 169f21d

File tree

3 files changed

+93
-1
lines changed

3 files changed

+93
-1
lines changed

src/array_api_extra/__init__.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
"""Extra array functions built on top of the array API standard."""
22

3-
from ._funcs import atleast_nd, cov, create_diagonal, expand_dims, kron, setdiff1d, sinc
3+
from ._funcs import (
4+
atleast_nd,
5+
cov,
6+
create_diagonal,
7+
expand_dims,
8+
kron,
9+
pad,
10+
setdiff1d,
11+
sinc,
12+
)
413

514
__version__ = "0.4.1.dev0"
615

@@ -12,6 +21,7 @@
1221
"create_diagonal",
1322
"expand_dims",
1423
"kron",
24+
"pad",
1525
"setdiff1d",
1626
"sinc",
1727
]

src/array_api_extra/_funcs.py

+51
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
"create_diagonal",
1313
"expand_dims",
1414
"kron",
15+
"pad",
1516
"setdiff1d",
1617
"sinc",
1718
]
@@ -538,3 +539,53 @@ def sinc(x: Array, /, *, xp: ModuleType | None = None) -> Array:
538539
xp.asarray(xp.finfo(x.dtype).eps, dtype=x.dtype, device=_compat.device(x)),
539540
)
540541
return xp.sin(y) / y
542+
543+
544+
def pad(
545+
x: Array,
546+
pad_width: int,
547+
mode: str = "constant",
548+
*,
549+
xp: ModuleType | None = None,
550+
constant_values: bool | int | float | complex = 0,
551+
) -> Array:
552+
"""
553+
Pad the input array.
554+
555+
Parameters
556+
----------
557+
x : array
558+
Input array.
559+
pad_width : int
560+
Pad the input array with this many elements from each side.
561+
mode : str, optional
562+
Only "constant" mode is currently supported, which pads with
563+
the value passed to `constant_values`.
564+
xp : array_namespace, optional
565+
The standard-compatible namespace for `x`. Default: infer.
566+
constant_values : python scalar, optional
567+
Use this value to pad the input. Default is zero.
568+
569+
Returns
570+
-------
571+
array
572+
The input array,
573+
padded with ``pad_width`` elements equal to ``constant_values``.
574+
"""
575+
if mode != "constant":
576+
msg = "Only `'constant'` mode is currently supported"
577+
raise NotImplementedError(msg)
578+
579+
value = constant_values
580+
581+
if xp is None:
582+
xp = array_namespace(x)
583+
584+
padded = xp.full(
585+
tuple(x + 2 * pad_width for x in x.shape),
586+
fill_value=value,
587+
dtype=x.dtype,
588+
device=_compat.device(x),
589+
)
590+
padded[(slice(pad_width, -pad_width, None),) * x.ndim] = x
591+
return padded

tests/test_funcs.py

+31
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
create_diagonal,
1414
expand_dims,
1515
kron,
16+
pad,
1617
setdiff1d,
1718
sinc,
1819
)
@@ -385,3 +386,33 @@ def test_device(self):
385386

386387
def test_xp(self):
387388
assert_array_equal(sinc(xp.asarray(0.0), xp=xp), xp.asarray(1.0))
389+
390+
391+
class TestPad:
392+
def test_simple(self):
393+
a = xp.arange(1, 4)
394+
padded = pad(a, 2)
395+
assert xp.all(padded == xp.asarray([0, 0, 1, 2, 3, 0, 0]))
396+
397+
def test_fill_value(self):
398+
a = xp.arange(1, 4)
399+
padded = pad(a, 2, constant_values=42)
400+
assert xp.all(padded == xp.asarray([42, 42, 1, 2, 3, 42, 42]))
401+
402+
def test_ndim(self):
403+
a = xp.reshape(xp.arange(2 * 3 * 4), (2, 3, 4))
404+
padded = pad(a, 2)
405+
assert padded.shape == (6, 7, 8)
406+
407+
def test_mode_not_implemented(self):
408+
a = xp.arange(3)
409+
with pytest.raises(NotImplementedError, match="Only `'constant'`"):
410+
pad(a, 2, mode="edge")
411+
412+
def test_device(self):
413+
device = xp.Device("device1")
414+
a = xp.asarray(0.0, device=device)
415+
assert pad(a, 2).device == device
416+
417+
def test_xp(self):
418+
assert_array_equal(pad(xp.asarray(0), 1, xp=xp), xp.zeros(3))

0 commit comments

Comments
 (0)