Skip to content

Commit 33f39a8

Browse files
authored
MAINT: simple refactor (#45)
1 parent fc02b56 commit 33f39a8

File tree

3 files changed

+71
-71
lines changed

3 files changed

+71
-71
lines changed

src/array_api_extra/_funcs.py

+1-21
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ def cov(m: Array, /, *, xp: ModuleType) -> Array:
133133
m = atleast_nd(m, ndim=2, xp=xp)
134134
m = xp.astype(m, dtype)
135135

136-
avg = _mean(m, axis=1, xp=xp)
136+
avg = _utils.mean(m, axis=1, xp=xp)
137137
fact = m.shape[1] - 1
138138

139139
if fact <= 0:
@@ -199,26 +199,6 @@ def create_diagonal(x: Array, /, *, offset: int = 0, xp: ModuleType) -> Array:
199199
return xp.reshape(diag, (n, n))
200200

201201

202-
def _mean(
203-
x: Array,
204-
/,
205-
*,
206-
axis: int | tuple[int, ...] | None = None,
207-
keepdims: bool = False,
208-
xp: ModuleType,
209-
) -> Array:
210-
"""
211-
Complex mean, https://github.com/data-apis/array-api/issues/846.
212-
"""
213-
if xp.isdtype(x.dtype, "complex floating"):
214-
x_real = xp.real(x)
215-
x_imag = xp.imag(x)
216-
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
217-
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
218-
return mean_real + (mean_imag * xp.asarray(1j))
219-
return xp.mean(x, axis=axis, keepdims=keepdims)
220-
221-
222202
def expand_dims(
223203
a: Array, /, *, axis: int | tuple[int, ...] = (0,), xp: ModuleType
224204
) -> Array:

src/array_api_extra/_lib/_utils.py

+21-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
from . import _compat
99

10-
__all__ = ["in1d"]
10+
__all__ = ["in1d", "mean"]
1111

1212

1313
def in1d(
@@ -63,3 +63,23 @@ def in1d(
6363
if assume_unique:
6464
return ret[: x1.shape[0]]
6565
return xp.take(ret, rev_idx, axis=0)
66+
67+
68+
def mean(
69+
x: Array,
70+
/,
71+
*,
72+
axis: int | tuple[int, ...] | None = None,
73+
keepdims: bool = False,
74+
xp: ModuleType,
75+
) -> Array:
76+
"""
77+
Complex mean, https://github.com/data-apis/array-api/issues/846.
78+
"""
79+
if xp.isdtype(x.dtype, "complex floating"):
80+
x_real = xp.real(x)
81+
x_imag = xp.imag(x)
82+
mean_real = xp.mean(x_real, axis=axis, keepdims=keepdims)
83+
mean_imag = xp.mean(x_imag, axis=axis, keepdims=keepdims)
84+
return mean_real + (mean_imag * xp.asarray(1j))
85+
return xp.mean(x, axis=axis, keepdims=keepdims)

tests/test_funcs.py

+49-49
Original file line numberDiff line numberDiff line change
@@ -157,6 +157,55 @@ def test_2d(self):
157157
create_diagonal(xp.asarray([[1]]), xp=xp)
158158

159159

160+
class TestExpandDims:
161+
def test_functionality(self):
162+
def _squeeze_all(b: Array) -> Array:
163+
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
164+
for axis in range(b.ndim):
165+
with contextlib.suppress(ValueError):
166+
b = xp.squeeze(b, axis=axis)
167+
return b
168+
169+
s = (2, 3, 4, 5)
170+
a = xp.empty(s)
171+
for axis in range(-5, 4):
172+
b = expand_dims(a, axis=axis, xp=xp)
173+
assert b.shape[axis] == 1
174+
assert _squeeze_all(b).shape == s
175+
176+
def test_axis_tuple(self):
177+
a = xp.empty((3, 3, 3))
178+
assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
179+
assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
180+
assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
181+
assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)
182+
183+
def test_axis_out_of_range(self):
184+
s = (2, 3, 4, 5)
185+
a = xp.empty(s)
186+
with pytest.raises(IndexError, match="out of bounds"):
187+
expand_dims(a, axis=-6, xp=xp)
188+
with pytest.raises(IndexError, match="out of bounds"):
189+
expand_dims(a, axis=5, xp=xp)
190+
191+
a = xp.empty((3, 3, 3))
192+
with pytest.raises(IndexError, match="out of bounds"):
193+
expand_dims(a, axis=(0, -6), xp=xp)
194+
with pytest.raises(IndexError, match="out of bounds"):
195+
expand_dims(a, axis=(0, 5), xp=xp)
196+
197+
def test_repeated_axis(self):
198+
a = xp.empty((3, 3, 3))
199+
with pytest.raises(ValueError, match="Duplicate dimensions"):
200+
expand_dims(a, axis=(1, 1), xp=xp)
201+
202+
def test_positive_negative_repeated(self):
203+
# https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
204+
a = xp.empty((2, 3, 4, 5))
205+
with pytest.raises(ValueError, match="Duplicate dimensions"):
206+
expand_dims(a, axis=(3, -3), xp=xp)
207+
208+
160209
class TestKron:
161210
def test_basic(self):
162211
# Using 0-dimensional array
@@ -222,55 +271,6 @@ def test_kron_shape(self, shape_a: tuple[int, ...], shape_b: tuple[int, ...]):
222271
assert_equal(k.shape, expected_shape, err_msg="Unexpected shape from kron")
223272

224273

225-
class TestExpandDims:
226-
def test_functionality(self):
227-
def _squeeze_all(b: Array) -> Array:
228-
"""Mimics `np.squeeze(b)`. `xpx.squeeze`?"""
229-
for axis in range(b.ndim):
230-
with contextlib.suppress(ValueError):
231-
b = xp.squeeze(b, axis=axis)
232-
return b
233-
234-
s = (2, 3, 4, 5)
235-
a = xp.empty(s)
236-
for axis in range(-5, 4):
237-
b = expand_dims(a, axis=axis, xp=xp)
238-
assert b.shape[axis] == 1
239-
assert _squeeze_all(b).shape == s
240-
241-
def test_axis_tuple(self):
242-
a = xp.empty((3, 3, 3))
243-
assert expand_dims(a, axis=(0, 1, 2), xp=xp).shape == (1, 1, 1, 3, 3, 3)
244-
assert expand_dims(a, axis=(0, -1, -2), xp=xp).shape == (1, 3, 3, 3, 1, 1)
245-
assert expand_dims(a, axis=(0, 3, 5), xp=xp).shape == (1, 3, 3, 1, 3, 1)
246-
assert expand_dims(a, axis=(0, -3, -5), xp=xp).shape == (1, 1, 3, 1, 3, 3)
247-
248-
def test_axis_out_of_range(self):
249-
s = (2, 3, 4, 5)
250-
a = xp.empty(s)
251-
with pytest.raises(IndexError, match="out of bounds"):
252-
expand_dims(a, axis=-6, xp=xp)
253-
with pytest.raises(IndexError, match="out of bounds"):
254-
expand_dims(a, axis=5, xp=xp)
255-
256-
a = xp.empty((3, 3, 3))
257-
with pytest.raises(IndexError, match="out of bounds"):
258-
expand_dims(a, axis=(0, -6), xp=xp)
259-
with pytest.raises(IndexError, match="out of bounds"):
260-
expand_dims(a, axis=(0, 5), xp=xp)
261-
262-
def test_repeated_axis(self):
263-
a = xp.empty((3, 3, 3))
264-
with pytest.raises(ValueError, match="Duplicate dimensions"):
265-
expand_dims(a, axis=(1, 1), xp=xp)
266-
267-
def test_positive_negative_repeated(self):
268-
# https://github.com/data-apis/array-api/issues/760#issuecomment-1989449817
269-
a = xp.empty((2, 3, 4, 5))
270-
with pytest.raises(ValueError, match="Duplicate dimensions"):
271-
expand_dims(a, axis=(3, -3), xp=xp)
272-
273-
274274
class TestSetDiff1D:
275275
def test_setdiff1d(self):
276276
x1 = xp.asarray([6, 5, 4, 7, 1, 2, 7, 4])

0 commit comments

Comments
 (0)