Skip to content

Commit

Permalink
perf: try using fori_loop but unrolling everything (#141)
Browse files Browse the repository at this point in the history
* perf: try using fori_loop but unrolling everything

* perf: unroll more

* fix: cannot unroll discard since n is not static

* test: try using -v to see which test fails

* test: try this

* test: reduce matrix

* Update utils.py

* perf: try partial loop unrolling

* test: add spergel init benchmark

* dev: try a newton fixed-point method

* style: pre the commit
  • Loading branch information
beckermr authored Jan 31, 2025
1 parent 3e45fb5 commit 84abf17
Show file tree
Hide file tree
Showing 8 changed files with 249 additions and 20 deletions.
10 changes: 3 additions & 7 deletions .github/workflows/python_package.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,7 @@ jobs:
strategy:
fail-fast: false
matrix:
python-version: ["3.10", "3.11", "3.12"]
group: [1, 2, 3, 4]
python-version: ["3.12"]

steps:
- uses: actions/checkout@v4
Expand All @@ -34,18 +33,15 @@ jobs:
- name: Install dependencies
run: |
python -m pip install --upgrade pip
python -m pip install pytest pytest-codspeed pytest-split pytest-randomly
python -m pip install pytest pytest-codspeed pytest-randomly
python -m pip install .
# temp pin until 0.5 is on conda
python -m pip install "jax<0.5.0"
- name: Test with pytest
run: |
git submodule update --init --recursive
pytest -v --durations=0 \
-k "not test_fpack" \
--randomly-seed=42 \
--splits=4 --group=${{ matrix.group }} --splitting-algorithm least_duration
pytest -vv --durations=100 --randomly-seed=42
build-status:
needs: build
Expand Down
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,6 @@ repos:
hooks:
- id: ruff
args: [ --fix ]
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/|dev/notebooks/
- id: ruff-format
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/
exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/|dev/notebooks/
174 changes: 174 additions & 0 deletions dev/notebooks/spergel_fixed_point.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,174 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "d7b8bc37-8799-433c-9399-de95a21a1727",
"metadata": {},
"outputs": [],
"source": [
"import jax\n",
"jax.config.update(\"jax_enable_x64\", True)\n",
"\n",
"import galsim\n",
"import numpy as np\n",
"\n",
"import jax.numpy as jnp"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "774101b1",
"metadata": {},
"outputs": [],
"source": [
"from jax_galsim.spergel import (\n",
" fz_nup1, _gammap1, _spergel_hlr_pade,\n",
" fluxfractionFunc, fz_nu, calculateFluxRadius,\n",
")\n",
"\n",
"@jax.jit\n",
"def _calculateFluxRadius_newtons_kernel(i, args):\n",
" \"\"\"Newton's method kernel for calculateFluxRadius\n",
"\n",
" Returns\n",
"\n",
" lnz - fluxfractionFunc(z, nu, alpha) / dfluxfractionFunc(z, nu, alpha)_dz / z\n",
"\n",
" which is Newton's kernel but in log space.\n",
" \"\"\"\n",
" lnz, alpha, nu = args\n",
" z = jnp.exp(lnz)\n",
" dn = (jnp.power(2.0, nu) * _gammap1(nu))\n",
" fz = 1.0 - fz_nup1(z, nu) / dn - alpha\n",
" dfzdz = z * fz_nu(z, nu) / dn\n",
"\n",
" # we clip the result to avoid numerical issues near bounds\n",
" lnz = jnp.clip(\n",
" lnz - fz / dfzdz / z,\n",
" min=-100,\n",
" max=100,\n",
" )\n",
"\n",
" return lnz, alpha, nu\n",
"\n",
"\n",
"@jax.jit\n",
"def calculateFluxRadiusNewton(alpha, nu):\n",
" \"\"\"Return radius R enclosing flux fraction alpha in unit of the scale radius r0\n",
"\n",
" Method: Solve F(R/r0=z)/Flux - alpha = 0 using Netwon's method\n",
"\n",
" We can integrate the profile to get\n",
"\n",
" F(R)/F = int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha\n",
"\n",
" So if we define z = R/r0 and f(z) = F(z * r0)/F - alpha, then Newton's method is\n",
"\n",
" z -> z - f(z) / f'(z)\n",
"\n",
" We actually run the method for ln(z) which is\n",
"\n",
" ln(z) -> ln(z) - f(z) / f'(z) / z\n",
"\n",
" Typical use cases include:\n",
"\n",
" - alpha = 1/2 => R = Half-Light-Radius,\n",
" - alpha = 1 - folding-thresold => R used for stepk computation\n",
" \"\"\"\n",
" # seed the iteration with the Pade approximation to the HLR\n",
" # scaled by the fraction of flux to some power\n",
" zalpha = _spergel_hlr_pade(nu) * jnp.sqrt(alpha / 0.5)\n",
" return jnp.exp(jax.lax.fori_loop(\n",
" 0, 100,\n",
" _calculateFluxRadius_newtons_kernel,\n",
" (jnp.log(zalpha), alpha, nu),\n",
" )[0])\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "1be23e1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"eps, nu, log10(alpha): 1e-12 -0.84 -12.0\n",
"3.5138887102897e-38 -2.2121720121483927e-17 1.0587911840678754e-21 1.8761616702453412e-07 1.000534100015216e-12\n",
"\n",
"eps, nu, log10(alpha): 1e-12 3.999 -12.0\n",
"3.9966817649384216e-06 -1.576433954596703e-15 3.999832106175669e-06 3.1094518726606304e-16 9.984622740022494e-13\n",
"\n",
"eps, nu, log10(alpha): 1e-12 -0.84 -4.3428487456249e-13\n",
"25.572845945758726 0.0 25.572509765625 -3.3306690738754696e-16 0.999999999999\n",
"\n",
"eps, nu, log10(alpha): 1e-12 3.999 -4.3428487456249e-13\n",
"38.6677767503012 0.0 38.6676025390625 -1.1102230246251565e-16 0.999999999999\n",
"\n",
"eps, nu, log10(alpha): 0.1 -0.84 -1.0\n",
"0.0008333666650951336 6.38378239159465e-16 0.0008333666650951221 3.0531133177191805e-16 0.10000000000000096\n",
"\n",
"eps, nu, log10(alpha): 0.1 3.999 -1.0\n",
"1.3092245672406861 6.38378239159465e-16 1.3092245672406833 8.326672684688674e-17 0.10000000000000037\n",
"\n",
"eps, nu, log10(alpha): 0.1 -0.84 -0.045757490560675115\n",
"1.2147258941802845 0.0 1.214725894180284 -1.1102230246251565e-16 0.9000000000000001\n",
"\n",
"eps, nu, log10(alpha): 0.1 3.999 -0.045757490560675115\n",
"6.899340112339111 -1.1102230246251565e-16 6.899340112339113 -1.1102230246251565e-16 0.8999999999999999\n"
]
}
],
"source": [
"for eps in [1e-12, 0.1]:\n",
" for alpha in [eps, 1.0 - eps]:\n",
" for nu in [-0.84, 3.999]:\n",
"\n",
" print(\"\\neps, nu, log10(alpha):\", eps, nu, np.log10(alpha))\n",
" zfp = calculateFluxRadiusNewton(alpha, nu)\n",
" zbs = calculateFluxRadius(alpha, nu)\n",
" print(\n",
" zfp,\n",
" fluxfractionFunc(zfp, nu, alpha),\n",
" zbs,\n",
" fluxfractionFunc(zbs, nu, alpha),\n",
" galsim.Spergel(nu, scale_radius=1.0).calculateIntegratedFlux(zfp),\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "29cd9aa2",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "jax-galsim",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
2 changes: 1 addition & 1 deletion jax_galsim/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,7 +254,7 @@ def _func(i, args):
flow = func(low)
fhigh = func(high)
args = (func, low, flow, high, fhigh)
return jax.lax.fori_loop(0, niter, _func, args)[-2]
return jax.lax.fori_loop(0, niter, _func, args, unroll=15)[-2]


# start of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py #
Expand Down
45 changes: 41 additions & 4 deletions jax_galsim/fitswcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1052,10 +1052,47 @@ def _invert_ab_noraise(u, v, ab, abp=None):
dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1]
dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:]

for _ in range(10):
x, y, dx, dy = _invert_ab_noraise_loop_body(
x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef
)
def _step(i, args):
x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args

# Want Jac^-1 . du
# du
du = horner2d(x, y, ab[0], triangle=True) - u
dv = horner2d(x, y, ab[1], triangle=True) - v
# J
dudx = horner2d(x, y, dudxcoef, triangle=True)
dudy = horner2d(x, y, dudycoef, triangle=True)
dvdx = horner2d(x, y, dvdxcoef, triangle=True)
dvdy = horner2d(x, y, dvdycoef, triangle=True)
# J^-1 . du
det = dudx * dvdy - dudy * dvdx
duu = -(du * dvdy - dv * dudy) / det
dvv = -(-du * dvdx + dv * dudx) / det

x += duu
y += dvv

return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef

x, y, dx, dy = jax.lax.fori_loop(
0,
10,
_step,
(
x,
y,
jnp.zeros_like(x),
jnp.zeros_like(y),
u,
v,
ab,
dudxcoef,
dudycoef,
dvdxcoef,
dvdycoef,
),
unroll=True,
)[0:4]

x, y = jax.lax.cond(
jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12,
Expand Down
13 changes: 9 additions & 4 deletions jax_galsim/moffat.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,10 +61,15 @@ def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100):
nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf.
BUT the case rm==0 is already done, so HERE rm != 0
"""
xcur = re
for _ in range(nitr):
xcur = _bodymi(xcur, rm, re, beta)
return xcur

# fix loop iteration is faster and reach eps=1e-6 (single precision)
def body(i, xcur):
x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2
x = jnp.power(x, 1 / (1 - beta))
x = jnp.sqrt(x - 1)
return re / x

return jax.lax.fori_loop(0, 100, body, re, unroll=True)


@implements(_galsim.Moffat)
Expand Down
7 changes: 5 additions & 2 deletions jax_galsim/spergel.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@ def reducedfluxfractionFunc(z, nu, norm):


@jax.jit
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=40.0):
"""Return radius R enclosing flux fraction alpha in unit of the scale radius r0
Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm
Expand All @@ -186,7 +186,10 @@ def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0):
nb. it is supposed that nu is in [-0.85, 4.0] checked in the Spergel class init
"""
return bisect_for_root(
partial(fluxfractionFunc, nu=nu, alpha=alpha), zmin, zmax, niter=75
partial(fluxfractionFunc, nu=nu, alpha=alpha),
zmin,
zmax,
niter=75,
)


Expand Down
14 changes: 14 additions & 0 deletions tests/jax/test_benchmarks.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,3 +301,17 @@ def test_benchmark_moffat_init(benchmark, kind):
benchmark, kind, lambda: _run_benchmark_moffat_init().block_until_ready()
)
print(f"time: {dt:0.4g} ms", end=" ")


def _run_benchmark_spergel_calcfluxrad():
return jgs.spergel.calculateFluxRadius(1e-10, 2.0)


@pytest.mark.parametrize("kind", ["compile", "run"])
def test_benchmark_spergel_calcfluxrad(benchmark, kind):
dt = _run_benchmarks(
benchmark,
kind,
lambda: _run_benchmark_spergel_calcfluxrad().block_until_ready(),
)
print(f"time: {dt:0.4g} ms", end=" ")

0 comments on commit 84abf17

Please sign in to comment.