Skip to content

Commit

Permalink
perf: try using fori_loop but unrolling everything
Browse files Browse the repository at this point in the history
  • Loading branch information
beckermr committed Jan 30, 2025
1 parent 3e45fb5 commit 0571c59
Show file tree
Hide file tree
Showing 2 changed files with 50 additions and 8 deletions.
45 changes: 41 additions & 4 deletions jax_galsim/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,10 +1052,47 @@ def _invert_ab_noraise(u, v, ab, abp=None):
dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1]
dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:]

for _ in range(10):
x, y, dx, dy = _invert_ab_noraise_loop_body(
x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef
)
def _step(i, args):
    # One Newton-Raphson style update of (x, y): evaluate the residual of the
    # forward map against the target (u, v), evaluate the 2x2 Jacobian, and
    # move (x, y) by -J^-1 . residual.
    #
    # `i` is the loop index supplied by jax.lax.fori_loop; it is not used.
    # `args` is the 11-element loop carry. Slots 2 and 3 hold the previous
    # step sizes and are ignored on input; everything after them is passed
    # through unchanged so the carry keeps the same structure.
    # NOTE(review): horner2d is defined elsewhere; presumably it evaluates a
    # 2D polynomial with the given coefficient matrix at (x, y) - confirm.
    x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args

    # Residual of the forward evaluation relative to the requested (u, v).
    resid_u = horner2d(x, y, ab[0], triangle=True) - u
    resid_v = horner2d(x, y, ab[1], triangle=True) - v

    # Entries of the Jacobian, evaluated with the derivative coefficients.
    jac_ux = horner2d(x, y, dudxcoef, triangle=True)
    jac_uy = horner2d(x, y, dudycoef, triangle=True)
    jac_vx = horner2d(x, y, dvdxcoef, triangle=True)
    jac_vy = horner2d(x, y, dvdycoef, triangle=True)

    # Closed-form 2x2 inverse applied to the residual, negated to give the
    # correction that is added to the current position.
    det = jac_ux * jac_vy - jac_uy * jac_vx
    step_x = -(resid_u * jac_vy - resid_v * jac_uy) / det
    step_y = -(-resid_u * jac_vx + resid_v * jac_ux) / det

    return (
        x + step_x,
        y + step_y,
        step_x,
        step_y,
        u,
        v,
        ab,
        dudxcoef,
        dudycoef,
        dvdxcoef,
        dvdycoef,
    )

x, y, dx, dy = jax.lax.fori_loop(
0,
10,
_step,
(
x,
y,
jnp.zeros_like(x),
jnp.zeros_like(y),
u,
v,
ab,
dudxcoef,
dudycoef,
dvdxcoef,
dvdycoef,
),
unroll=True,
)[0:4]

x, y = jax.lax.cond(
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
Expand Down
13 changes: 9 additions & 4 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100):
nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf.
BUT the case rm==0 is already done, so HERE rm != 0
"""
xcur = re
for _ in range(nitr):
xcur = _bodymi(xcur, rm, re, beta)
return xcur

# a fixed number of loop iterations is faster and reaches eps=1e-6 (single precision)
def body(i, xcur):
    # One update of the iteration: map the current estimate `xcur` to the
    # next one. `i` is the loop index supplied by jax.lax.fori_loop and is
    # not used. `rm`, `re` and `beta` come from the enclosing function.
    one_minus_beta = 1 - beta
    averaged = (1 + jnp.power(1 + (rm / xcur) ** 2, one_minus_beta)) / 2
    powered = jnp.power(averaged, 1 / one_minus_beta)
    denom = jnp.sqrt(powered - 1)
    return re / denom

return jax.lax.fori_loop(0, 100, body, re, unroll=True)


@implements(_galsim.Moffat)
Expand Down

0 comments on commit 0571c59

Please sign in to comment.