diff --git a/.github/workflows/python_package.yaml b/.github/workflows/python_package.yaml index a7e21ae5..2ae777ce 100644 --- a/.github/workflows/python_package.yaml +++ b/.github/workflows/python_package.yaml @@ -20,8 +20,7 @@ jobs: strategy: fail-fast: false matrix: - python-version: ["3.10", "3.11", "3.12"] - group: [1, 2, 3, 4] + python-version: ["3.12"] steps: - uses: actions/checkout@v4 @@ -34,7 +33,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest pytest-codspeed pytest-split pytest-randomly + python -m pip install pytest pytest-codspeed pytest-randomly python -m pip install . # temp pin until 0.5 is on conda python -m pip install "jax<0.5.0" @@ -42,10 +41,7 @@ jobs: - name: Test with pytest run: | git submodule update --init --recursive - pytest -v --durations=0 \ - -k "not test_fpack" \ - --randomly-seed=42 \ - --splits=4 --group=${{ matrix.group }} --splitting-algorithm least_duration + pytest -vv --durations=100 --randomly-seed=42 build-status: needs: build diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 549ec299..bd34319c 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -15,6 +15,6 @@ repos: hooks: - id: ruff args: [ --fix ] - exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/ + exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/|dev/notebooks/ - id: ruff-format - exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/ + exclude: tests/GalSim/|tests/Coord/|tests/jax/galsim/|dev/notebooks/ diff --git a/dev/notebooks/spergel_fixed_point.ipynb b/dev/notebooks/spergel_fixed_point.ipynb new file mode 100644 index 00000000..c9dbc625 --- /dev/null +++ b/dev/notebooks/spergel_fixed_point.ipynb @@ -0,0 +1,174 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "d7b8bc37-8799-433c-9399-de95a21a1727", + "metadata": {}, + "outputs": [], + "source": [ + "import jax\n", + "jax.config.update(\"jax_enable_x64\", True)\n", + "\n", + 
"import galsim\n", + "import numpy as np\n", + "\n", + "import jax.numpy as jnp" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "774101b1", + "metadata": {}, + "outputs": [], + "source": [ + "from jax_galsim.spergel import (\n", + " fz_nup1, _gammap1, _spergel_hlr_pade,\n", + " fluxfractionFunc, fz_nu, calculateFluxRadius,\n", + ")\n", + "\n", + "@jax.jit\n", + "def _calculateFluxRadius_newtons_kernel(i, args):\n", + " \"\"\"Newton's method kernel for calculateFluxRadius\n", + "\n", + " Returns\n", + "\n", + " lnz - fluxfractionFunc(z, nu, alpha) / dfluxfractionFunc(z, nu, alpha)_dz / z\n", + "\n", + " which is Newton's kernel but in log space.\n", + " \"\"\"\n", + " lnz, alpha, nu = args\n", + " z = jnp.exp(lnz)\n", + " dn = (jnp.power(2.0, nu) * _gammap1(nu))\n", + " fz = 1.0 - fz_nup1(z, nu) / dn - alpha\n", + " dfzdz = z * fz_nu(z, nu) / dn\n", + "\n", + " # we clip the result to avoid numerical issues near bounds\n", + " lnz = jnp.clip(\n", + " lnz - fz / dfzdz / z,\n", + " min=-100,\n", + " max=100,\n", + " )\n", + "\n", + " return lnz, alpha, nu\n", + "\n", + "\n", + "@jax.jit\n", + "def calculateFluxRadiusNewton(alpha, nu):\n", + " \"\"\"Return radius R enclosing flux fraction alpha in unit of the scale radius r0\n", + "\n", + " Method: Solve F(R/r0=z)/Flux - alpha = 0 using Netwon's method\n", + "\n", + " We can integrate the profile to get\n", + "\n", + " F(R)/F = int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha\n", + "\n", + " So if we define z = R/r0 and f(z) = F(z * r0)/F - alpha, then Newton's method is\n", + "\n", + " z -> z - f(z) / f'(z)\n", + "\n", + " We actually run the method for ln(z) which is\n", + "\n", + " ln(z) -> ln(z) - f(z) / f'(z) / z\n", + "\n", + " Typical use cases include:\n", + "\n", + " - alpha = 1/2 => R = Half-Light-Radius,\n", + " - alpha = 1 - folding-thresold => R used for stepk computation\n", + " \"\"\"\n", + " # seed the iteration with the Pade approximation to the HLR\n", + " 
# scaled by the fraction of flux to some power\n", + " zalpha = _spergel_hlr_pade(nu) * jnp.sqrt(alpha / 0.5)\n", + " return jnp.exp(jax.lax.fori_loop(\n", + " 0, 100,\n", + " _calculateFluxRadius_newtons_kernel,\n", + " (jnp.log(zalpha), alpha, nu),\n", + " )[0])\n" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "1be23e1b", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "eps, nu, log10(alpha): 1e-12 -0.84 -12.0\n", + "3.5138887102897e-38 -2.2121720121483927e-17 1.0587911840678754e-21 1.8761616702453412e-07 1.000534100015216e-12\n", + "\n", + "eps, nu, log10(alpha): 1e-12 3.999 -12.0\n", + "3.9966817649384216e-06 -1.576433954596703e-15 3.999832106175669e-06 3.1094518726606304e-16 9.984622740022494e-13\n", + "\n", + "eps, nu, log10(alpha): 1e-12 -0.84 -4.3428487456249e-13\n", + "25.572845945758726 0.0 25.572509765625 -3.3306690738754696e-16 0.999999999999\n", + "\n", + "eps, nu, log10(alpha): 1e-12 3.999 -4.3428487456249e-13\n", + "38.6677767503012 0.0 38.6676025390625 -1.1102230246251565e-16 0.999999999999\n", + "\n", + "eps, nu, log10(alpha): 0.1 -0.84 -1.0\n", + "0.0008333666650951336 6.38378239159465e-16 0.0008333666650951221 3.0531133177191805e-16 0.10000000000000096\n", + "\n", + "eps, nu, log10(alpha): 0.1 3.999 -1.0\n", + "1.3092245672406861 6.38378239159465e-16 1.3092245672406833 8.326672684688674e-17 0.10000000000000037\n", + "\n", + "eps, nu, log10(alpha): 0.1 -0.84 -0.045757490560675115\n", + "1.2147258941802845 0.0 1.214725894180284 -1.1102230246251565e-16 0.9000000000000001\n", + "\n", + "eps, nu, log10(alpha): 0.1 3.999 -0.045757490560675115\n", + "6.899340112339111 -1.1102230246251565e-16 6.899340112339113 -1.1102230246251565e-16 0.8999999999999999\n" + ] + } + ], + "source": [ + "for eps in [1e-12, 0.1]:\n", + " for alpha in [eps, 1.0 - eps]:\n", + " for nu in [-0.84, 3.999]:\n", + "\n", + " print(\"\\neps, nu, log10(alpha):\", eps, nu, np.log10(alpha))\n", + " zfp = 
calculateFluxRadiusNewton(alpha, nu)\n", + " zbs = calculateFluxRadius(alpha, nu)\n", + " print(\n", + " zfp,\n", + " fluxfractionFunc(zfp, nu, alpha),\n", + " zbs,\n", + " fluxfractionFunc(zbs, nu, alpha),\n", + " galsim.Spergel(nu, scale_radius=1.0).calculateIntegratedFlux(zfp),\n", + " )" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "29cd9aa2", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "jax-galsim", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.15" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/jax_galsim/core/utils.py b/jax_galsim/core/utils.py index ce3222d4..c4cc5413 100644 --- a/jax_galsim/core/utils.py +++ b/jax_galsim/core/utils.py @@ -254,7 +254,7 @@ def _func(i, args): flow = func(low) fhigh = func(high) args = (func, low, flow, high, fhigh) - return jax.lax.fori_loop(0, niter, _func, args)[-2] + return jax.lax.fori_loop(0, niter, _func, args, unroll=15)[-2] # start of code from https://github.com/google/jax/blob/main/jax/_src/numpy/util.py # diff --git a/jax_galsim/fitswcs.py b/jax_galsim/fitswcs.py index 6b85517b..5870d755 100644 --- a/jax_galsim/fitswcs.py +++ b/jax_galsim/fitswcs.py @@ -1052,10 +1052,47 @@ def _invert_ab_noraise(u, v, ab, abp=None): dvdxcoef = (jnp.arange(nab)[:, None] * ab[1])[1:, :-1] dvdycoef = (jnp.arange(nab) * ab[1])[:-1, 1:] - for _ in range(10): - x, y, dx, dy = _invert_ab_noraise_loop_body( - x, y, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef - ) + def _step(i, args): + x, y, _, _, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef = args + + # Want Jac^-1 . 
du + # du + du = horner2d(x, y, ab[0], triangle=True) - u + dv = horner2d(x, y, ab[1], triangle=True) - v + # J + dudx = horner2d(x, y, dudxcoef, triangle=True) + dudy = horner2d(x, y, dudycoef, triangle=True) + dvdx = horner2d(x, y, dvdxcoef, triangle=True) + dvdy = horner2d(x, y, dvdycoef, triangle=True) + # J^-1 . du + det = dudx * dvdy - dudy * dvdx + duu = -(du * dvdy - dv * dudy) / det + dvv = -(-du * dvdx + dv * dudx) / det + + x += duu + y += dvv + + return x, y, duu, dvv, u, v, ab, dudxcoef, dudycoef, dvdxcoef, dvdycoef + + x, y, dx, dy = jax.lax.fori_loop( + 0, + 10, + _step, + ( + x, + y, + jnp.zeros_like(x), + jnp.zeros_like(y), + u, + v, + ab, + dudxcoef, + dudycoef, + dvdxcoef, + dvdycoef, + ), + unroll=True, + )[0:4] x, y = jax.lax.cond( jnp.maximum(jnp.max(jnp.abs(dx)), jnp.max(jnp.abs(dy))) > 2e-12, diff --git a/jax_galsim/moffat.py b/jax_galsim/moffat.py index 0563ba78..fe596398 100644 --- a/jax_galsim/moffat.py +++ b/jax_galsim/moffat.py @@ -61,10 +61,15 @@ def _MoffatCalculateSRFromHLR(re, rm, beta, nitr=100): nb2. In GalSim definition rm = 0 (ex. no truncated Moffat) means in reality rm=+Inf. 
BUT the case rm==0 is already done, so HERE rm != 0 """ - xcur = re - for _ in range(nitr): - xcur = _bodymi(xcur, rm, re, beta) - return xcur + + # a fixed iteration count (hard-coded to 100 below; nitr is not used) is faster and reaches eps=1e-6 (single precision) + def body(i, xcur): + x = (1 + jnp.power(1 + (rm / xcur) ** 2, 1 - beta)) / 2 + x = jnp.power(x, 1 / (1 - beta)) + x = jnp.sqrt(x - 1) + return re / x + + return jax.lax.fori_loop(0, 100, body, re, unroll=True) @implements(_galsim.Moffat) diff --git a/jax_galsim/spergel.py b/jax_galsim/spergel.py index ba6493f5..b0bf22a8 100644 --- a/jax_galsim/spergel.py +++ b/jax_galsim/spergel.py @@ -167,7 +167,7 @@ def reducedfluxfractionFunc(z, nu, norm): @jax.jit -def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0): +def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=40.0): """Return radius R enclosing flux fraction alpha in unit of the scale radius r0 Method: Solve F(R/r0=z)/Flux - alpha = 0 using bisection algorithm @@ -186,7 +186,10 @@ def calculateFluxRadius(alpha, nu, zmin=0.0, zmax=30.0): nb.
it is supposed that nu is in [-0.85, 4.0] checked in the Spergel class init """ return bisect_for_root( - partial(fluxfractionFunc, nu=nu, alpha=alpha), zmin, zmax, niter=75 + partial(fluxfractionFunc, nu=nu, alpha=alpha), + zmin, + zmax, + niter=75, ) diff --git a/tests/jax/test_benchmarks.py b/tests/jax/test_benchmarks.py index f73d569a..4d35e8a7 100644 --- a/tests/jax/test_benchmarks.py +++ b/tests/jax/test_benchmarks.py @@ -301,3 +301,17 @@ def test_benchmark_moffat_init(benchmark, kind): benchmark, kind, lambda: _run_benchmark_moffat_init().block_until_ready() ) print(f"time: {dt:0.4g} ms", end=" ") + + +def _run_benchmark_spergel_calcfluxrad(): + return jgs.spergel.calculateFluxRadius(1e-10, 2.0) + + +@pytest.mark.parametrize("kind", ["compile", "run"]) +def test_benchmark_spergel_calcfluxrad(benchmark, kind): + dt = _run_benchmarks( + benchmark, + kind, + lambda: _run_benchmark_spergel_calcfluxrad().block_until_ready(), + ) + print(f"time: {dt:0.4g} ms", end=" ")