Skip to content

Commit

Permalink
perf: remove fori loop for fits wcs
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jan 29, 2025
1 parent ba64136 commit 4764697
Showing 1 changed file with 6 additions and 29 deletions.
35 changes: 6 additions & 29 deletions jax_galsim/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1028,9 +1028,7 @@ def _invert_ab_noraise(u, v, ab, abp=None):
dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1]
dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:]

def _step(i, args):
x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args

for _ in range(10):
# Want Jac^-1 . du
# du
du = horner2d(x, y, ab[0], triangle=True) - u
Expand All @@ -1042,32 +1040,11 @@ def _step(i, args):
dvdy = horner2d(x, y, dvdycoef, triangle=True)
# J^-1 . du
det = dudx * dvdy - dudy * dvdx
duu = -(du * dvdy - dv * dudy) / det
dvv = -(-du * dvdx + dv * dudx) / det

x += duu
y += dvv

return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef

x, y, dx, dy = jax.lax.fori_loop(
0,
10,
_step,
(
x,
y,
jnp.zeros_like(x),
jnp.zeros_like(y),
u,
v,
ab,
dudxcoef,
dudycoef,
dvdxcoef,
dvdycoef,
),
)[0:4]
dx = -(du * dvdy - dv * dudy) / det
dy = -(-du * dvdx + dv * dudx) / det

x += dx
y += dy

x, y = jax.lax.cond(
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
Expand Down

0 comments on commit 4764697

Please sign in to comment.