From dbdde18272e037d7e0ce38844db9475787b29d0d Mon Sep 17 00:00:00 2001 From: beckermr Date: Wed, 29 Jan 2025 06:59:59 -0600 Subject: [PATCH] test: add benchmark for fits --- tests/jax/test_benchmarks.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index 3c773dc6..f5eb0802 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -272,3 +272,20 @@ def test_benchmark_rng_discard(benchmark, kind): benchmark, kind, lambda: _run_benchmark_rng_discard(rng).block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_invert_ab_noraise(u, v, ab): + return jgs.fitswcs._invert_ab_noraise(u, v, ab)[0] + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_invert_ab_noraise(benchmark, kind): + u = jnp.arange(1000).astype(jnp.float64) + v = jnp.arange(1000).astype(jnp.float64) + ab = jnp.array([[[-0.5, 0.3], [-0.1, 2.0]], [[-1.0, 0.3], [-0.1, 4.0]]]) + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_invert_ab_noraise(u, v, ab).block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ")