diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 6b85517b..5870d755 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1052,10 +1052,47 @@ def _invert_ab_noraise(u, v, ab, abp=None): dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1] dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:] - for _ in range(10): - x, y, dx, dy = _invert_ab_noraise_loop_body( - x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef - ) + def _step(i, args): + x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args + + # Want Jac^-1 . du + # du + du = horner2d(x, y, ab[0], triangle=True) - u + dv = horner2d(x, y, ab[1], triangle=True) - v + # J + dudx = horner2d(x, y, dudxcoef, triangle=True) + dudy = horner2d(x, y, dudycoef, triangle=True) + dvdx = horner2d(x, y, dvdxcoef, triangle=True) + dvdy = horner2d(x, y, dvdycoef, triangle=True) + # J^-1 . du + det = dudx * dvdy - dudy * dvdx + duu = -(du * dvdy - dv * dudy) / det + dvv = -(-du * dvdx + dv * dudx) / det + + x += duu + y += dvv + + return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef + + x, y, dx, dy = jax.lax.fori_loop( + 0, + 10, + _step, + ( + x, + y, + jnp.zeros_like(x), + jnp.zeros_like(y), + u, + v, + ab, + dudxcoef, + dudycoef, + dvdxcoef, + dvdycoef, + ), + unroll=True, + )[0:4] x, y = jax.lax.cond( jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12, diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 0563ba78..fe596398 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -61,10 +61,15 @@ def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100): nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf. 
BUT the case rm==0 is already done, so HERE rm != 0 """ - xcur = re - for _ in range(nitr): - xcur = _bodymi(xcur, rm, re, beta) - return xcur + + # A fixed number of iterations is faster and reaches eps=1e-6 (single precision). NOTE(review): the count is hard-coded to 100 below, so the nitr argument is no longer used; confirm this is intended. + def body(i, xcur): + x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 + x = jnp.power(x, 1 / (1 - beta)) + x = jnp.sqrt(x - 1) + return re / x + + return jax.lax.fori_loop(0, 100, body, re, unroll=True) @implements(_galsim.Moffat)