Commit 3703efd

Merge pull request #191 from Joshuaalbert/develop ("Develop")
2 parents: a714cdb + 24332e8

91 files changed: +4603 −3572 lines


.github/workflows/unittests.yml

−3

@@ -39,7 +39,4 @@ jobs:
         flake8 . --count --exit-zero --max-complexity=10 --max-line-length=127 --statistics
     - name: Test with pytest
       run: |
-        set -o allexport
-        source deployment/local.env
-        set +o allexport
         pytest

README.md

+14 −25
@@ -144,9 +144,9 @@ Given a probabilistic model, JAXNS can perform nested sampling on it. This allow
 posterior samples.
 
 ```python
-from jaxns import DefaultNestedSampler
+from jaxns import NestedSampler
 
-ns = DefaultNestedSampler(model=model, max_samples=1e5)
+ns = NestedSampler(model=model, max_samples=1e5)
 
 # Run the sampler
 termination_reason, state = ns(jax.random.PRNGKey(42))
@@ -290,7 +290,7 @@ Sampling paper](https://arxiv.org/abs/2312.11330).
 
 ```bash
 # To create a new env, if necessary
-conda create -n jaxns_py python=3.11
+conda create -n jaxns_py python=3.12
 conda activate jaxns_py
 ```
 
@@ -328,9 +328,8 @@ Checkout the examples [here](https://jaxns.readthedocs.io/en/latest/#).
 
 ## Caveats
 
-The caveat is that you need to be able to define your likelihood function with JAX. This is usually no big deal because
-JAX is just a replacement for NumPy and many likelihoods can be expressed such.
-If you're unfamiliar, take a quick tour of JAX (https://jax.readthedocs.io/en/latest/notebooks/quickstart.html).
+The caveat is that you need to be able to define your likelihood function with JAX. UPDATE: now you can just
+use the `@jaxify_likelihood` decorator to run with arbitrary pythonic likelihoods.
 
 # Speed test comparison with other nested sampling packages

@@ -339,30 +338,20 @@ JAXNS is much faster than PolyChord, MultiNEST, and dynesty, typically achieving
 improvement in run time, for models with cheap likelihood evaluations.
 This is shown in (https://arxiv.org/abs/2012.15286).
 
-Recently JAXNS has implemented Phantom-Powered Nested Sampling, which significantly reduces the number of required
-likelihood evaluations for inferring the posterior. This is shown in (https://arxiv.org/abs/2312.11330).
+Recently JAXNS has implemented Phantom-Powered Nested Sampling, which helps for parameter inference. This is shown
+in (https://arxiv.org/abs/2312.11330).
 
-# Note on performance with parallelisation
+# Note on performance with parallelisation and GPUs
 
-__Note that this is an experimental feature.__
-
-If you set `num_parallel_workers > 1` you will use `jax.pmap` under the hood for parallelisation.
-This is a very powerful feature, but it is important to understand how it works.
-It runs identical copies of the nested sampling algorithm on multiple devices.
-There is a two-part stopping condition.
-First, each copy goes until the user-defined stopping condition is met __per device__.
-Then, it performs an all-gather, finds the highest likelihood contour among all copies, and continues all copies until they
-hit this likelihood contour.
-This ensures consistency of depth across all copies.
-We then merge the copies and compute the final results.
-
-The algorithm is fairly memory bound, so running parallelisation over multiple CPUs on the same machine may not yield
-the expected speed up, and depends on how expensive the likelihood evaluations are. Running over separate physical
-devices is the best way to achieve speed up.
+To use parallel computing, you can simply pass `devices` to the `NestedSampler` constructor. This will distribute
+sampling over the devices. To use GPUs you can pass `jax.devices('gpu')` to the `devices` argument. You can also use all
+your CPUs by placing `os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"`
+before importing JAXNS.
 
 # Change Log
 
+24 Sep, 2024 -- JAXNS 2.6.1 released. Sharded parallel JAXNS. Rewrite of internals to support sharded parallelisation.
+
 20 Aug, 2024 -- JAXNS 2.6.0 released. Removed haiku dependency. Implemented our own
 context. `jaxns.framework.context.convert_external_params` enables interfacing with any external NN library.
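The CPU-parallelism tip in the added note can be sketched as below. The key constraint is ordering: XLA reads `XLA_FLAGS` when JAX is first imported, so the flag must be set before any `import jax` (and hence before `import jaxns`). The jax/jaxns lines are left commented as an assumed usage, since the `devices` argument comes from the README text, not from running code in this diff.

```python
import os

# Tell XLA to expose one virtual CPU device per core. This must run
# before the first `import jax` (and so before `import jaxns`),
# otherwise the flag has no effect.
os.environ["XLA_FLAGS"] = f"--xla_force_host_platform_device_count={os.cpu_count()}"

# import jax
# import jaxns
# # Distribute sampling over all virtual CPU devices (per the README note):
# ns = jaxns.NestedSampler(model=model, max_samples=1e5, devices=jax.devices())
```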

benchmarks/difficult_problems/main.py

+2 −2

@@ -5,7 +5,7 @@
 import jax.numpy as jnp
 import tensorflow_probability.substrates.jax as tfp
 
-from jaxns import Model, Prior, DefaultNestedSampler
+from jaxns import Model, Prior, NestedSampler
 
 tfpd = tfp.distributions
 
@@ -152,7 +152,7 @@ def main():
     for model_name, model in all_models().items():
         print(f"Testing model {model_name}")
         model.sanity_check(jax.random.PRNGKey(0), 1000)
-        ns = DefaultNestedSampler(model=model,
+        ns = NestedSampler(model=model,
                            max_samples=1000000,
                            verbose=True,
                            difficult_model=True,

benchmarks/gh117/main.py

+2 −2

@@ -4,7 +4,7 @@
 import tensorflow_probability.substrates.jax as tfp
 from jax import random
 
-from jaxns import Model, Prior, DefaultNestedSampler
+from jaxns import Model, Prior, NestedSampler
 
 
 tfpd = tfp.distributions
@@ -23,7 +23,7 @@ def prior_model():
 
 
     # Create the nested sampler class. In this case without any tuning.
-    exact_ns = DefaultNestedSampler(model=model, max_samples=max_samples)
+    exact_ns = NestedSampler(model=model, max_samples=max_samples)
 
     termination_reason, state = exact_ns(random.PRNGKey(42))
     return termination_reason

benchmarks/gh168/main.py

+2 −2

@@ -7,7 +7,7 @@
 import tensorflow_probability.substrates.jax as tfp
 from jax._src.scipy.linalg import solve_triangular
 
-from jaxns import Model, Prior, DefaultNestedSampler
+from jaxns import Model, Prior, NestedSampler
 
 tfpd = tfp.distributions
 
@@ -47,7 +47,7 @@ def log_likelihood(x):
 
     model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
 
-    ns = DefaultNestedSampler(model=model, max_samples=100000, verbose=False)
+    ns = NestedSampler(model=model, max_samples=100000, verbose=False)
 
     termination_reason, state = ns(key)
     results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)

benchmarks/parallel_problems/main.py

+5 −6

@@ -11,7 +11,7 @@
 import tensorflow_probability.substrates.jax as tfp
 from jax._src.scipy.linalg import solve_triangular
 
-from jaxns import Model, Prior, DefaultNestedSampler
+from jaxns import Model, Prior, NestedSampler, jaxify_likelihood
 
 tfpd = tfp.distributions
 
@@ -51,7 +51,7 @@ def log_likelihood(x):
 
     model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
 
-    ns = DefaultNestedSampler(model=model, max_samples=100000, verbose=False, num_parallel_workers=len(jax.devices()))
+    ns = NestedSampler(model=model, max_samples=100000, verbose=False)
 
     termination_reason, state = ns(key)
     results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)
@@ -61,7 +61,7 @@ def log_likelihood(x):
 def main():
     num_devices = len(jax.devices())
     jaxns_version = pkg_resources.get_distribution("jaxns").version
-    m = 1
+    m = 3
     run_model_aot = jax.jit(run_model).lower(jax.random.PRNGKey(0)).compile()
     dt = []
 
@@ -70,14 +70,13 @@ def main():
 
     for i in range(m):
         t0 = time.time()
-        log_Z_error, log_Z_uncert = run_model_aot(jax.random.PRNGKey(i))
-        log_Z_error.block_until_ready()
+        log_Z_error, log_Z_uncert = jax.block_until_ready(run_model_aot(jax.random.PRNGKey(i)))
         t1 = time.time()
         dt.append(t1 - t0)
         errors.append(log_Z_error)
         uncerts.append(log_Z_uncert)
     total_time = sum(dt)
-    best_3 = sum(sorted(dt)[:min(3, m)]) / 3.
+    best_3 = sum(sorted(dt)[:3]) / 3.
     # print(f"Errors: {errors}")
     # print(f"Uncerts: {uncerts}")
     print(f"JAXNS {jaxns_version}\n"
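The timing loop changed above follows a common benchmarking pattern: block on the result before stopping the clock (JAX dispatches work asynchronously, which is why the new code wraps the call in `jax.block_until_ready`), then report both total time and the average of the fastest runs. A pure-Python sketch of the aggregation, with `bench` as a hypothetical helper name:

```python
import time

def bench(fn, m=3):
    """Time m runs of fn; return (total_time, mean of the 3 fastest runs).

    For JAX code, fn should internally call jax.block_until_ready(...) so
    asynchronous dispatch does not stop the clock early. Assumes m >= 3,
    matching the benchmark's new default of m = 3.
    """
    dt = []
    for _ in range(m):
        t0 = time.time()
        fn()
        dt.append(time.time() - t0)
    best_3 = sum(sorted(dt)[:3]) / 3.0  # mirrors the fixed best_3 line above
    return sum(dt), best_3
```

Averaging only the fastest runs damps one-off noise such as compilation or cache warm-up, which is why the benchmark reports it alongside the total.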

benchmarks/parallel_problems/results

+3

@@ -1,2 +1,5 @@
 2.5.1,12,-0.8741127252578735,0.0003452669479884207,95.40295028686523,77.15287351608276
 2.5.1,12,-1.30291748046875,0.32637646794319153,102.9092493057251,34.30308310190836
+2.3.4,12,0.09217196161439745,0.3709631743726239,2.116373300552368,0.7054577668507894
+2.6.0,6,-0.06817162536483465,0.3668122147028353,1.1894174337387085,1.0906640688578289
+2.6.0,12,-0.06817162536483465,0.3668122147028353,1.9961725234985352,1.7415425777435303

deployment/Dockerfile

-18
This file was deleted.

deployment/docker-compose.yaml

-12
This file was deleted.

deployment/launch-tests.sh

-10
This file was deleted.

deployment/local.env

-2
This file was deleted.

docs/conf.py

+1 −1

@@ -12,7 +12,7 @@
 project = "jaxns"
 copyright = "2022, Joshua G. Albert"
 author = "Joshua G. Albert"
-release = "2.6.0"
+release = "2.6.1"
 
 # -- General configuration ---------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration

docs/examples/Jones_scalar_modelling.ipynb

+65 −52
Large diffs are not rendered by default.
