2
2
3
3
import re
4
4
from contextlib import contextmanager
5
- from functools import reduce , wraps
5
+ from functools import wraps
6
6
import math
7
- from operator import mul
8
7
import struct
9
8
from typing import Any , List , Mapping , NamedTuple , Optional , Sequence , Tuple , Union
10
9
@@ -217,9 +216,6 @@ def all_floating_dtypes() -> SearchStrategy[DataType]:
217
216
# Size to use for 2-dim arrays
218
217
SQRT_MAX_ARRAY_SIZE = int (math .sqrt (MAX_ARRAY_SIZE ))
219
218
220
- # np.prod and others have overflow and math.prod is Python 3.8+ only
221
- def prod (seq ):
222
- return reduce (mul , seq , 1 )
223
219
224
220
# hypotheses.strategies.tuples only generates tuples of a fixed size
225
221
def tuples (elements , * , min_size = 0 , max_size = None , unique_by = None , unique = False ):
@@ -233,7 +229,7 @@ def shapes(**kw):
233
229
kw .setdefault ('min_dims' , 0 )
234
230
kw .setdefault ('min_side' , 0 )
235
231
return xps .array_shapes (** kw ).filter (
236
- lambda shape : prod (i for i in shape if i ) < MAX_ARRAY_SIZE
232
+ lambda shape : math . prod (i for i in shape if i ) < MAX_ARRAY_SIZE
237
233
)
238
234
239
235
@@ -245,7 +241,7 @@ def matrix_shapes(draw, stack_shapes=shapes()):
245
241
stack_shape = draw (stack_shapes )
246
242
mat_shape = draw (xps .array_shapes (max_dims = 2 , min_dims = 2 ))
247
243
shape = stack_shape + mat_shape
248
- assume (prod (i for i in shape if i ) < MAX_ARRAY_SIZE )
244
+ assume (math . prod (i for i in shape if i ) < MAX_ARRAY_SIZE )
249
245
return shape
250
246
251
247
square_matrix_shapes = matrix_shapes ().filter (lambda shape : shape [- 1 ] == shape [- 2 ])
@@ -290,7 +286,7 @@ def mutually_broadcastable_shapes(
290
286
)
291
287
.map (lambda BS : BS .input_shapes )
292
288
.filter (lambda shapes : all (
293
- prod (i for i in s if i > 0 ) < MAX_ARRAY_SIZE for s in shapes
289
+ math . prod (i for i in s if i > 0 ) < MAX_ARRAY_SIZE for s in shapes
294
290
))
295
291
)
296
292
@@ -321,7 +317,7 @@ def positive_definite_matrices(draw, dtypes=floating_dtypes):
321
317
base_shape = draw (shapes ())
322
318
n = draw (integers (0 , 8 )) # 8 is an arbitrary small but interesting-enough value
323
319
shape = base_shape + (n , n )
324
- assume (prod (i for i in shape if i ) < MAX_ARRAY_SIZE )
320
+ assume (math . prod (i for i in shape if i ) < MAX_ARRAY_SIZE )
325
321
dtype = draw (dtypes )
326
322
return broadcast_to (eye (n , dtype = dtype ), shape )
327
323
0 commit comments