Skip to content

Commit

Permalink
perf: try this for slow tests
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jan 29, 2025
1 parent d7cf0ce commit 9f5e17e
Showing 1 changed file with 27 additions and 16 deletions.
43 changes: 27 additions & 16 deletions jax_galsim/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1008,6 +1008,30 @@ def FitsWCS(
FitsWCS._opt_params = {"dir": str, "hdu": int, "compression": str, "text_file": bool}


@jax.jit
def _invert_ab_noraise_loop_body(
x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef
):
# Want Jac^-1 . du
# du
du = horner2d(x, y, ab[0], triangle=True) - u
dv = horner2d(x, y, ab[1], triangle=True) - v
# J
dudx = horner2d(x, y, dudxcoef, triangle=True)
dudy = horner2d(x, y, dudycoef, triangle=True)
dvdx = horner2d(x, y, dvdxcoef, triangle=True)
dvdy = horner2d(x, y, dvdycoef, triangle=True)
# J^-1 . du
det = dudx * dvdy - dudy * dvdx
dx = -(du * dvdy - dv * dudy) / det
dy = -(-du * dvdx + dv * dudx) / det

x += dx
y += dy

return x, y, dx, dy


@jax.jit
def _invert_ab_noraise(u, v, ab, abp=None):
# get guess from abp if we have it
Expand All @@ -1029,22 +1053,9 @@ def _invert_ab_noraise(u, v, ab, abp=None):
dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:]

for _ in range(10):
# Want Jac^-1 . du
# du
du = horner2d(x, y, ab[0], triangle=True) - u
dv = horner2d(x, y, ab[1], triangle=True) - v
# J
dudx = horner2d(x, y, dudxcoef, triangle=True)
dudy = horner2d(x, y, dudycoef, triangle=True)
dvdx = horner2d(x, y, dvdxcoef, triangle=True)
dvdy = horner2d(x, y, dvdycoef, triangle=True)
# J^-1 . du
det = dudx * dvdy - dudy * dvdx
dx = -(du * dvdy - dv * dudy) / det
dy = -(-du * dvdx + dv * dudx) / det

x += dx
y += dy
x, y, dx, dy = _invert_ab_noraise_loop_body(
x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef
)

x, y = jax.lax.cond(
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
Expand Down

0 comments on commit 9f5e17e

Please sign in to comment.