Skip to content

Commit

Permalink
test: add benchmark for fits
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jan 29, 2025
1 parent ba64136 commit dbdde18
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions tests/jax/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,20 @@ def test_benchmark_rng_discard(benchmark, kind):
benchmark, kind, lambda: _run_benchmark_rng_discard(rng).block_until_ready()
)
print(f"time: {dt:0.4g} ms", end=" ")


def _run_benchmark_invert_ab_noraise(u, v, ab):
return jgs.fitswcs._invert_ab_noraise(u, v, ab)[0]


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_benchmark_invert_ab_noraise(benchmark, kind):
u = jnp.arange(1000).astype(jnp.float64)
v = jnp.arange(1000).astype(jnp.float64)
ab = jnp.array([[[-0.5, 0.3], [-0.1, 2.0]], [[-1.0, 0.3], [-0.1, 4.0]]])
dt = _run_benchmarks(
benchmark,
kind,
lambda: _run_benchmark_invert_ab_noraise(u, v, ab).block_until_ready(),
)
print(f"time: {dt:0.4g} ms", end=" ")

0 comments on commit dbdde18

Please sign in to comment.