Skip to content

Commit 8ca8408

Browse files
trax-robotcopybara-github
authored andcommitted
disable certain tests in lax_numpy_test when numpy 2.0 is used
PiperOrigin-RevId: 673089320
1 parent 6002f18 commit 8ca8408

File tree

1 file changed

+93
-44
lines changed

1 file changed

+93
-44
lines changed

trax/tf_numpy/jax_tests/lax_numpy_test.py

Lines changed: 93 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -971,16 +971,33 @@ def onp_fun(lhs, rhs):
971971
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=False, atol=tol,
972972
rtol=tol, check_incomplete_shape=True)
973973

974-
@named_parameters(jtu.cases_from_list(
975-
{"testcase_name": "_{}_amin={}_amax={}".format(
976-
jtu.format_shape_dtype_string(shape, dtype), a_min, a_max),
977-
"shape": shape, "dtype": dtype, "a_min": a_min, "a_max": a_max,
978-
"rng_factory": jtu.rand_default}
979-
for shape in all_shapes for dtype in minus(number_dtypes, complex_dtypes)
980-
for a_min, a_max in [(-1, None), (None, 1), (-1, 1),
981-
(-onp.ones(1), None),
982-
(None, onp.ones(1)),
983-
(-onp.ones(1), onp.ones(1))]))
974+
@named_parameters(
975+
jtu.cases_from_list(
976+
{
977+
"testcase_name": "_{}_amin={}_amax={}".format(
978+
jtu.format_shape_dtype_string(shape, dtype), a_min, a_max
979+
),
980+
"shape": shape,
981+
"dtype": dtype,
982+
"a_min": a_min,
983+
"a_max": a_max,
984+
"rng_factory": jtu.rand_default,
985+
}
986+
for shape in all_shapes
987+
for dtype in minus(number_dtypes, complex_dtypes)
988+
for a_min, a_max in [
989+
(-1, None),
990+
(None, 1),
991+
(-onp.ones(1), None),
992+
(None, onp.ones(1)),
993+
]
994+
+ (
995+
[]
996+
if onp.__version__ >= onp.lib.NumpyVersion("2.0.0")
997+
else [(-1, 1), (-onp.ones(1), onp.ones(1))]
998+
)
999+
)
1000+
)
9841001
def testClipStaticBounds(self, shape, dtype, a_min, a_max, rng_factory):
9851002
rng = rng_factory()
9861003
onp_fun = lambda x: onp.clip(x, a_min=a_min, a_max=a_max)
@@ -1357,7 +1374,6 @@ def testDiagIndices(self, ndim, n):
13571374
onp.testing.assert_equal(onp.diag_indices(n, ndim),
13581375
lnp.diag_indices(n, ndim))
13591376

1360-
13611377
@named_parameters(jtu.cases_from_list(
13621378
{"testcase_name": "_shape={}_k={}".format(
13631379
jtu.format_shape_dtype_string(shape, dtype), k),
@@ -1951,7 +1967,6 @@ def testFlipud(self, shape, dtype, rng_factory):
19511967
self._CompileAndCheck(
19521968
lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True)
19531969

1954-
19551970
@named_parameters(jtu.cases_from_list(
19561971
{"testcase_name": "_{}".format(
19571972
jtu.format_shape_dtype_string(shape, dtype)),
@@ -1968,7 +1983,6 @@ def testFliplr(self, shape, dtype, rng_factory):
19681983
self._CompileAndCheck(
19691984
lnp_op, args_maker, check_dtypes=True, check_incomplete_shape=True)
19701985

1971-
19721986
@named_parameters(jtu.cases_from_list(
19731987
{"testcase_name": "_{}_k={}_axes={}".format(
19741988
jtu.format_shape_dtype_string(shape, dtype), k, axes),
@@ -2295,7 +2309,6 @@ def onp_fun(*args):
22952309
tol=tol)
22962310
self._CompileAndCheck(lnp_fun, args_maker, check_dtypes=True, rtol=tol)
22972311

2298-
22992312
@named_parameters(jtu.cases_from_list(
23002313
{"testcase_name": "_shape={}".format(
23012314
jtu.format_shape_dtype_string(shape, dtype)),
@@ -2318,7 +2331,6 @@ def testWhereOneArgument(self, shape, dtype):
23182331
check_unknown_rank=False,
23192332
check_experimental_compile=False, check_xla_forced_compile=False)
23202333

2321-
23222334
@named_parameters(jtu.cases_from_list(
23232335
{"testcase_name": "_{}".format("_".join(
23242336
jtu.format_shape_dtype_string(shape, dtype)
@@ -2373,7 +2385,6 @@ def onp_fun(condlist, choicelist, default):
23732385
check_incomplete_shape=True,
23742386
rtol={onp.float64: 1e-7, onp.complex128: 1e-7})
23752387

2376-
23772388
@jtu.disable
23782389
def testIssue330(self):
23792390
x = lnp.full((1, 1), lnp.array([1])[0]) # doesn't crash
@@ -2429,7 +2440,6 @@ def testAtLeastNdLiterals(self, pytype, dtype, op):
24292440
self._CompileAndCheck(
24302441
lnp_fun, args_maker, check_dtypes=True, check_incomplete_shape=True)
24312442

2432-
24332443
def testLongLong(self):
24342444
self.assertAllClose(
24352445
onp.int64(7), npe.jit(lambda x: x)(onp.longlong(7)), check_dtypes=True)
@@ -2676,19 +2686,38 @@ def testMeshGrid(self, shapes, dtype, indexing, sparse, rng_factory):
26762686

26772687
@named_parameters(
26782688
jtu.cases_from_list(
2679-
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
2680-
"_retstep={}_dtype={}").format(
2681-
start_shape, stop_shape, num, endpoint, retstep, dtype),
2682-
"start_shape": start_shape, "stop_shape": stop_shape,
2683-
"num": num, "endpoint": endpoint, "retstep": retstep,
2684-
"dtype": dtype, "rng_factory": rng_factory}
2685-
for start_shape in [(), (2,), (2, 2)]
2686-
for stop_shape in [(), (2,), (2, 2)]
2687-
for num in [0, 1, 2, 5, 20]
2688-
for endpoint in [True, False]
2689-
for retstep in [True, False]
2690-
for dtype in number_dtypes + [None,]
2691-
for rng_factory in [jtu.rand_default]))
2689+
{
2690+
"testcase_name": (
2691+
"_start_shape={}_stop_shape={}_num={}_endpoint={}"
2692+
"_retstep={}_dtype={}"
2693+
).format(start_shape, stop_shape, num, endpoint, retstep, dtype),
2694+
"start_shape": start_shape,
2695+
"stop_shape": stop_shape,
2696+
"num": num,
2697+
"endpoint": endpoint,
2698+
"retstep": retstep,
2699+
"dtype": dtype,
2700+
"rng_factory": rng_factory,
2701+
}
2702+
for start_shape in [(), (2,), (2, 2)]
2703+
for stop_shape in [(), (2,), (2, 2)]
2704+
for num in [0, 1, 2, 5, 20]
2705+
for endpoint in [True, False]
2706+
for retstep in [True, False]
2707+
for dtype in (
2708+
(
2709+
float_dtypes
2710+
+ complex_dtypes
2711+
+ [
2712+
None,
2713+
]
2714+
)
2715+
if onp.__version__ >= onp.lib.NumpyVersion("2.0.0")
2716+
else (number_dtypes + [None])
2717+
)
2718+
for rng_factory in [jtu.rand_default]
2719+
)
2720+
)
26922721
def testLinspace(self, start_shape, stop_shape, num, endpoint,
26932722
retstep, dtype, rng_factory):
26942723
if not endpoint and onp.issubdtype(dtype, onp.integer):
@@ -2770,20 +2799,40 @@ def testLogspace(self, start_shape, stop_shape, num,
27702799

27712800
@named_parameters(
27722801
jtu.cases_from_list(
2773-
{"testcase_name": ("_start_shape={}_stop_shape={}_num={}_endpoint={}"
2774-
"_dtype={}").format(
2775-
start_shape, stop_shape, num, endpoint, dtype),
2776-
"start_shape": start_shape,
2777-
"stop_shape": stop_shape,
2778-
"num": num, "endpoint": endpoint,
2779-
"dtype": dtype, "rng_factory": rng_factory}
2780-
for start_shape in [(), (2,), (2, 2)]
2781-
for stop_shape in [(), (2,), (2, 2)]
2782-
for num in [0, 1, 2, 5, 20]
2783-
for endpoint in [True, False]
2784-
# NB: numpy's geomspace gives nonsense results on integer types
2785-
for dtype in inexact_dtypes + [None,]
2786-
for rng_factory in [jtu.rand_default]))
2802+
{
2803+
"testcase_name": (
2804+
"_start_shape={}_stop_shape={}_num={}_endpoint={}_dtype={}"
2805+
).format(start_shape, stop_shape, num, endpoint, dtype),
2806+
"start_shape": start_shape,
2807+
"stop_shape": stop_shape,
2808+
"num": num,
2809+
"endpoint": endpoint,
2810+
"dtype": dtype,
2811+
"rng_factory": rng_factory,
2812+
}
2813+
for start_shape in [(), (2,), (2, 2)]
2814+
for stop_shape in [(), (2,), (2, 2)]
2815+
for num in [0, 1, 2, 5, 20]
2816+
for endpoint in [True, False]
2817+
# NB: numpy's geomspace gives nonsense results on integer types
2818+
for dtype in (
2819+
(
2820+
float_dtypes
2821+
+ [
2822+
None,
2823+
]
2824+
)
2825+
if onp.__version__ >= onp.lib.NumpyVersion("2.0.0")
2826+
else (
2827+
inexact_dtypes
2828+
+ [
2829+
None,
2830+
]
2831+
)
2832+
)
2833+
for rng_factory in [jtu.rand_default]
2834+
)
2835+
)
27872836
def testGeomspace(self, start_shape, stop_shape, num,
27882837
endpoint, dtype, rng_factory):
27892838
rng = rng_factory()

0 commit comments

Comments
 (0)