-
Notifications
You must be signed in to change notification settings - Fork 4
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
perf: try using fori_loop but unrolling everything (#141)
* perf: try using fori_loop but unrolling everything * perf: unroll more * fix: cannot unroll discard since n is not static * test: try using -v to see which test fails * test: try this * test: reduce matrix * Update utils.py * perf: try partial loop unrolling * test: add spergel init benchmark * dev: try a newton fixed-point method * style: pre the commit
- Loading branch information
Showing
8 changed files
with
249 additions
and
20 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,174 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "d7b8bc37-8799-433c-9399-de95a21a1727", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import jax\n", | ||
"jax.config.update(\"jax_enable_x64\", True)\n", | ||
"\n", | ||
"import galsim\n", | ||
"import numpy as np\n", | ||
"\n", | ||
"import jax.numpy as jnp" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "774101b1", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from jax_galsim.spergel import (\n", | ||
" fz_nup1, _gammap1, _spergel_hlr_pade,\n", | ||
" fluxfractionFunc, fz_nu, calculateFluxRadius,\n", | ||
")\n", | ||
"\n", | ||
"@jax.jit\n", | ||
"def _calculateFluxRadius_newtons_kernel(i, args):\n", | ||
" \"\"\"Newton's method kernel for calculateFluxRadius\n", | ||
"\n", | ||
" Returns\n", | ||
"\n", | ||
" lnz - fluxfractionFunc(z, nu, alpha) / dfluxfractionFunc(z, nu, alpha)_dz / z\n", | ||
"\n", | ||
" which is Newton's kernel but in log space.\n", | ||
" \"\"\"\n", | ||
" lnz, alpha, nu = args\n", | ||
" z = jnp.exp(lnz)\n", | ||
" dn = (jnp.power(2.0, nu) * _gammap1(nu))\n", | ||
" fz = 1.0 - fz_nup1(z, nu) / dn - alpha\n", | ||
" dfzdz = z * fz_nu(z, nu) / dn\n", | ||
"\n", | ||
" # we clip the result to avoid numerical issues near bounds\n", | ||
" lnz = jnp.clip(\n", | ||
" lnz - fz / dfzdz / z,\n", | ||
" min=-100,\n", | ||
" max=100,\n", | ||
" )\n", | ||
"\n", | ||
" return lnz, alpha, nu\n", | ||
"\n", | ||
"\n", | ||
"@jax.jit\n", | ||
"def calculateFluxRadiusNewton(alpha, nu):\n", | ||
"    \"\"\"Return radius R enclosing flux fraction alpha in units of the scale radius r0\n", | ||
"\n", | ||
"    Method: Solve F(R/r0=z)/Flux - alpha = 0 using Newton's method\n", | ||
"\n", | ||
" We can integrate the profile to get\n", | ||
"\n", | ||
" F(R)/F = int( 1/(2^nu Gamma(nu+1)) (r/r0)^(nu+1) K_nu(r/r0) dr/r0; r=0..R) = alpha\n", | ||
"\n", | ||
" So if we define z = R/r0 and f(z) = F(z * r0)/F - alpha, then Newton's method is\n", | ||
"\n", | ||
" z -> z - f(z) / f'(z)\n", | ||
"\n", | ||
" We actually run the method for ln(z) which is\n", | ||
"\n", | ||
" ln(z) -> ln(z) - f(z) / f'(z) / z\n", | ||
"\n", | ||
" Typical use cases include:\n", | ||
"\n", | ||
" - alpha = 1/2 => R = Half-Light-Radius,\n", | ||
"     - alpha = 1 - folding-threshold => R used for stepk computation\n", | ||
" \"\"\"\n", | ||
" # seed the iteration with the Pade approximation to the HLR\n", | ||
"    # scaled by sqrt(alpha / 0.5), i.e. the square root of the flux fraction relative to one half\n", | ||
" zalpha = _spergel_hlr_pade(nu) * jnp.sqrt(alpha / 0.5)\n", | ||
" return jnp.exp(jax.lax.fori_loop(\n", | ||
" 0, 100,\n", | ||
" _calculateFluxRadius_newtons_kernel,\n", | ||
" (jnp.log(zalpha), alpha, nu),\n", | ||
" )[0])\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "1be23e1b", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"\n", | ||
"eps, nu, log10(alpha): 1e-12 -0.84 -12.0\n", | ||
"3.5138887102897e-38 -2.2121720121483927e-17 1.0587911840678754e-21 1.8761616702453412e-07 1.000534100015216e-12\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 1e-12 3.999 -12.0\n", | ||
"3.9966817649384216e-06 -1.576433954596703e-15 3.999832106175669e-06 3.1094518726606304e-16 9.984622740022494e-13\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 1e-12 -0.84 -4.3428487456249e-13\n", | ||
"25.572845945758726 0.0 25.572509765625 -3.3306690738754696e-16 0.999999999999\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 1e-12 3.999 -4.3428487456249e-13\n", | ||
"38.6677767503012 0.0 38.6676025390625 -1.1102230246251565e-16 0.999999999999\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 0.1 -0.84 -1.0\n", | ||
"0.0008333666650951336 6.38378239159465e-16 0.0008333666650951221 3.0531133177191805e-16 0.10000000000000096\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 0.1 3.999 -1.0\n", | ||
"1.3092245672406861 6.38378239159465e-16 1.3092245672406833 8.326672684688674e-17 0.10000000000000037\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 0.1 -0.84 -0.045757490560675115\n", | ||
"1.2147258941802845 0.0 1.214725894180284 -1.1102230246251565e-16 0.9000000000000001\n", | ||
"\n", | ||
"eps, nu, log10(alpha): 0.1 3.999 -0.045757490560675115\n", | ||
"6.899340112339111 -1.1102230246251565e-16 6.899340112339113 -1.1102230246251565e-16 0.8999999999999999\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"for eps in [1e-12, 0.1]:\n", | ||
" for alpha in [eps, 1.0 - eps]:\n", | ||
" for nu in [-0.84, 3.999]:\n", | ||
"\n", | ||
" print(\"\\neps, nu, log10(alpha):\", eps, nu, np.log10(alpha))\n", | ||
" zfp = calculateFluxRadiusNewton(alpha, nu)\n", | ||
" zbs = calculateFluxRadius(alpha, nu)\n", | ||
" print(\n", | ||
" zfp,\n", | ||
" fluxfractionFunc(zfp, nu, alpha),\n", | ||
" zbs,\n", | ||
" fluxfractionFunc(zbs, nu, alpha),\n", | ||
" galsim.Spergel(nu, scale_radius=1.0).calculateIntegratedFlux(zfp),\n", | ||
" )" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "29cd9aa2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "jax-galsim", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.10.15" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters