Skip to content

Commit 11bb686

Browse files
authored
Merge pull request #308 from ev-br/math_prod
Use `math.prod`
2 parents 393b69b + 1f8676f commit 11bb686

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

array_api_tests/hypothesis_helpers.py

+5-9
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,8 @@
22

33
import re
44
from contextlib import contextmanager
5-
from functools import reduce, wraps
5+
from functools import wraps
66
import math
7-
from operator import mul
87
import struct
98
from typing import Any, List, Mapping, NamedTuple, Optional, Sequence, Tuple, Union
109

@@ -217,9 +216,6 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
217216
# Size to use for 2-dim arrays
218217
SQRT_MAX_ARRAY_SIZE = int(math.sqrt(MAX_ARRAY_SIZE))
219218

220-
# np.prod and others have overflow and math.prod is Python 3.8+ only
221-
def prod(seq):
222-
return reduce(mul, seq, 1)
223219

224220
# hypotheses.strategies.tuples only generates tuples of a fixed size
225221
def tuples(elements, *, min_size=0, max_size=None, unique_by=None, unique=False):
@@ -233,7 +229,7 @@ def shapes(**kw):
233229
kw.setdefault('min_dims', 0)
234230
kw.setdefault('min_side', 0)
235231
return xps.array_shapes(**kw).filter(
236-
lambda shape: prod(i for i in shape if i) < MAX_ARRAY_SIZE
232+
lambda shape: math.prod(i for i in shape if i) < MAX_ARRAY_SIZE
237233
)
238234

239235

@@ -245,7 +241,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
245241
stack_shape = draw(stack_shapes)
246242
mat_shape = draw(xps.array_shapes(max_dims=2, min_dims=2))
247243
shape = stack_shape + mat_shape
248-
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
244+
assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE)
249245
return shape
250246

251247
square_matrix_shapes = matrix_shapes().filter(lambda shape: shape[-1] == shape[-2])
@@ -290,7 +286,7 @@ def mutually_broadcastable_shapes(
290286
)
291287
.map(lambda BS: BS.input_shapes)
292288
.filter(lambda shapes: all(
293-
prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
289+
math.prod(i for i in s if i > 0) < MAX_ARRAY_SIZE for s in shapes
294290
))
295291
)
296292

@@ -321,7 +317,7 @@ def positive_definite_matrices(draw, dtypes=floating_dtypes):
321317
base_shape = draw(shapes())
322318
n = draw(integers(0, 8)) # 8 is an arbitrary small but interesting-enough value
323319
shape = base_shape + (n, n)
324-
assume(prod(i for i in shape if i) < MAX_ARRAY_SIZE)
320+
assume(math.prod(i for i in shape if i) < MAX_ARRAY_SIZE)
325321
dtype = draw(dtypes)
326322
return broadcast_to(eye(n, dtype=dtype), shape)
327323

array_api_tests/test_statistical_functions.py

+1
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def test_std(x, data):
286286
)
287287
# We can't easily test the result(s) as standard deviation methods vary a lot
288288

289+
289290
def _sum_condition_number(elements):
290291
sum_abs = sum([abs(i) for i in elements])
291292
abs_sum = abs(sum(elements))

0 commit comments

Comments
 (0)