Commit
Merge pull request #113 from Joshuaalbert/develop
Develop
Joshuaalbert authored Dec 20, 2023
2 parents 573b3bc + 7c3c818 commit 6bc23a5
Showing 12 changed files with 477 additions and 372 deletions.
2 changes: 2 additions & 0 deletions README.md
@@ -108,6 +108,8 @@ is the best way to achieve speed up.

# Change Log

20 Dec, 2023 -- JAXNS 2.3.2/2.3.3 released. Improved default parameters. `difficult_model` mode. Improved plotting.

18 Dec, 2023 -- JAXNS 2.3.1 released. Open-science release accompanying the paper. Default parameters taken from the paper.

11 Dec, 2023 -- JAXNS 2.3.0 released. Release of the Phantom-Powered Nested Sampling algorithm.
190 changes: 72 additions & 118 deletions docs/examples/Jones_scalar_modelling.ipynb

Large diffs are not rendered by default.

107 changes: 53 additions & 54 deletions docs/examples/egg_box.ipynb

Large diffs are not rendered by default.

This file is invalid so it cannot be displayed.
19 changes: 15 additions & 4 deletions jaxns/internals/log_semiring.py
@@ -1,4 +1,4 @@
from typing import Union
from typing import Union, Literal

from jax import numpy as jnp, lax
from jax.scipy.special import logsumexp
@@ -180,7 +180,6 @@ def __repr__(self):
return f"LogSpace({self.log_abs_val})"
return f"LogSpace({self.log_abs_val}, {self.sign})"


def sum(self, axis=-1, keepdims=False):
if not self._naked: # no coefficients
return LogSpace(*logsumexp(self.log_abs_val, b=self.sign, axis=axis, keepdims=keepdims, return_sign=True))
@@ -351,11 +350,23 @@ def is_complex(a):
return a.dtype in [jnp.complex64, jnp.complex128]


def normalise_log_space(x: LogSpace) -> LogSpace:
def normalise_log_space(x: LogSpace, norm_type: Literal['sum', 'max'] = 'sum') -> LogSpace:
"""
Safely normalise a LogSpace, accounting for zero-sum.
Args:
x: LogSpace to normalise
norm_type: 'sum' or 'max' normalisation
Returns:
normalised LogSpace
"""
norm = x.sum()
if norm_type == 'sum':
norm = x.sum()
elif norm_type == 'max':
norm = x.max()
else:
raise ValueError(f"Unknown norm_type {norm_type}")
x /= norm
x = LogSpace(jnp.where(jnp.isneginf(norm.log_abs_val), -jnp.inf, x.log_abs_val))
return x
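
A minimal usage sketch of the new norm_type switch (illustrative, not part of the commit; it assumes only the names defined in this module):

    from jax import numpy as jnp
    from jaxns.internals.log_semiring import LogSpace, normalise_log_space

    x = LogSpace(jnp.asarray([0.0, 1.0, 2.0]))          # represents exp(0), exp(1), exp(2)
    p = normalise_log_space(x)                          # 'sum': entries now sum to 1
    m = normalise_log_space(x, norm_type='max')         # 'max': largest entry becomes 1
    zero = LogSpace(jnp.asarray([-jnp.inf, -jnp.inf]))
    safe = normalise_log_space(zero)                    # zero-sum input stays zero rather than NaN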
7 changes: 4 additions & 3 deletions jaxns/nested_sampler/standard_static.py
@@ -569,9 +569,10 @@ def _to_results(self, termination_reason: IntArray, state: StaticStandardNestedSamplerState
log_L_samples = log_L
dp_mean = LogSpace(per_sample_evidence_stats.log_dZ_mean)
dp_mean = normalise_log_space(dp_mean)
H_mean = LogSpace(jnp.where(jnp.isneginf(dp_mean.log_abs_val),
-jnp.inf,
dp_mean.log_abs_val + log_L_samples)).sum().value - log_Z_mean
H_mean_instable = -((dp_mean * LogSpace(jnp.log(jnp.abs(log_L_samples)), jnp.sign(log_L_samples))).sum().value - log_Z_mean)
# H \approx E[-log(compression)] = E[-log(X)] (More stable than E[log(L) - log(Z)])
H_mean_stable = -((dp_mean * LogSpace(jnp.log(-per_sample_evidence_stats.log_X_mean))).sum().value)
H_mean = jnp.where(jnp.isfinite(H_mean_instable), H_mean_instable, H_mean_stable)
X_mean = LogSpace(per_sample_evidence_stats.log_X_mean)
num_likelihood_evaluations_per_sample = num_likelihood_evaluations
total_num_likelihood_evaluations = jnp.sum(num_likelihood_evaluations_per_sample)
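
Gloss on the change above (illustrative, not part of the commit): the direct estimate -(E[log L] - log Z) becomes non-finite whenever a retained sample has log L = -inf, so the compression-based form H ≈ E[-log X] is used as a fallback. A toy sketch of the fallback pattern, with invented arrays standing in for the per-sample statistics:

    import jax.numpy as jnp

    dp_mean = jnp.asarray([0.2, 0.3, 0.5])         # normalised posterior mass per sample
    log_L = jnp.asarray([-jnp.inf, -1.0, -0.5])    # a -inf poisons the direct estimate
    log_X = jnp.asarray([-0.5, -1.0, -1.5])        # mean log prior volume per sample
    log_Z = -1.0                                   # invented evidence value

    h_direct = -(jnp.sum(dp_mean * log_L) - log_Z)  # signs follow the diff; non-finite here
    h_stable = -jnp.sum(dp_mean * (-log_X))         # always finite
    H = jnp.where(jnp.isfinite(h_direct), h_direct, h_stable)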
416 changes: 247 additions & 169 deletions jaxns/plotting.py

Large diffs are not rendered by default.

74 changes: 63 additions & 11 deletions jaxns/public.py
@@ -3,6 +3,7 @@

import jax.numpy as jnp
import tensorflow_probability.substrates.jax as tfp
from jax import tree_map, core

from jaxns.framework.bases import BaseAbstractModel
from jaxns.internals.types import PRNGKey, IntArray, StaticStandardNestedSamplerState, TerminationCondition, \
@@ -32,22 +33,49 @@ class DefaultNestedSampler:
"""

def __init__(self, model: BaseAbstractModel, max_samples: Union[int, float], num_live_points: Optional[int] = None,
num_parallel_workers: int = 1):
s: Optional[int] = None, k: Optional[int] = None, c: Optional[int] = None,
num_parallel_workers: int = 1,
difficult_model: bool = False):
"""
Initialises the nested sampler.
s, k, c are defined in the paper: https://arxiv.org/abs/2312.11330
Args:
model: a model to perform nested sampling on
max_samples: maximum number of samples to take
num_live_points: number of live points to use. Default is 20 * D * (D/2 + 1).
num_parallel_workers: number of parallel workers to use. Defaults to 1.
"""
self._k = model.U_ndims // 2
self._s = 4
num_live_points: approximate number of live points to use. Default is c * (k + 1).
s: number of slices to use per dimension. Defaults to 4.
k: number of phantom samples to use. Defaults to D/2.
c: number of parallel Markov-chains to use. Defaults to 20 * D.
num_parallel_workers: number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model: if True, uses more robust default settings. Defaults to False.
"""
if difficult_model:
self._s = 10 if s is None else int(s)
else:
self._s = 4 if s is None else int(s)
if self._s <= 0:
raise ValueError(f"Expected s > 0, got s={self._s}")
if difficult_model:
self._k = model.U_ndims if k is None else int(k)
else:
self._k = model.U_ndims // 2 if k is None else int(k)
if not (0 <= self._k < self._s * model.U_ndims):
raise ValueError(f"Expected 0 <= k < s * U_ndims, got k={self._k}, s={self._s}, U_ndims={model.U_ndims}")
if num_live_points is not None:
self._c = max(1, int(num_live_points / (self._k + 1)))
logger.info(f"Number of parallel Markov-chains set to: {self._c}")
else:
self._c = 20 * model.U_ndims
if difficult_model:
self._c = 50 * model.U_ndims if c is None else int(c)
else:
self._c = 20 * model.U_ndims if c is None else int(c)
if self._c <= 0:
raise ValueError(f"Expected c > 0, got c={self._c}")
# Sanity check for max_samples (should be able to at least do one shrinkage)
if max_samples < self._c * (self._k + 1):
logger.warning(f"max_samples={max_samples} is likely too small!")
self._nested_sampler = StandardStaticNestedSampler(
model=model,
num_live_points=self._c,
@@ -86,17 +114,18 @@ def summary(self, results: NestedSamplerResults) -> str:
"""
return summary(results)

def plot_cornerplot(self, results: NestedSamplerResults, vars: Optional[List[str]] = None,
save_name: Optional[str] = None):
def plot_cornerplot(self, results: NestedSamplerResults, variables: Optional[List[str]] = None,
save_name: Optional[str] = None, kde_overlay: bool = False):
"""
Plots a corner plot of the samples.
Args:
results: results of the nested sampling run
vars: variables to plot. If not given, defaults to all variables.
variables: variables to plot. If not given, defaults to all variables.
save_name: if given, saves the plot to the given file name
kde_overlay: if True, overlays a KDE plot on the 1D histograms
"""
plot_cornerplot(results, vars=vars, save_name=save_name)
plot_cornerplot(results, variables=variables, save_name=save_name, kde_overlay=kde_overlay)
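# Usage sketch (illustrative, not part of the commit): with `ns` a
# DefaultNestedSampler and `results` from a completed run,
#   ns.plot_cornerplot(results, variables=['x'], kde_overlay=True)
# restricts the plot to variable 'x' and overlays a KDE on its 1D histogram.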

def plot_diagnostics(self, results: NestedSamplerResults, save_name: Optional[str] = None):
"""
@@ -173,6 +202,29 @@ def to_results(self, termination_reason: IntArray, state: StaticStandardNestedSamplerState
trim=trim
)

@staticmethod
def trim_results(results: NestedSamplerResults) -> NestedSamplerResults:
"""
Trims the results to the number of samples taken. Requires static context.
Args:
results: results to trim
Returns:
trimmed results
"""

if isinstance(results.total_num_samples, core.Tracer):
raise RuntimeError("Tracer detected, but expected imperative context.")

def trim(x):
if x.size > 1:
return x[:results.total_num_samples]
return x

results = tree_map(trim, results)
return results
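
# Usage sketch (illustrative, not part of the commit): outside of any jit/trace
# context, `results = DefaultNestedSampler.trim_results(results)` drops the
# unused preallocated tail from the per-sample arrays.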


class ApproximateNestedSampler(DefaultNestedSampler):
def __init__(self, *args, **kwargs):
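
A constructor sketch tying the new knobs together (illustrative, not part of the commit; `model` stands in for any existing BaseAbstractModel):

    from jaxns import DefaultNestedSampler

    # Paper defaults: s=4 slices per dimension, k=D//2 phantom samples, c=20*D chains.
    ns = DefaultNestedSampler(model=model, max_samples=100_000)

    # Harder posteriors: difficult_model=True switches the defaults to s=10, k=D,
    # c=50*D; s, k and c can still be overridden individually.
    ns_hard = DefaultNestedSampler(model=model, max_samples=1_000_000,
                                   difficult_model=True, s=12)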
10 changes: 10 additions & 0 deletions jaxns/tests/test_utils.py
@@ -1,5 +1,7 @@
import numpy as np
from jax import random, numpy as jnp

from jaxns.plotting import weighted_percentile
from jaxns.utils import resample, _bit_mask


@@ -14,3 +16,11 @@ def test_bit_mask():
assert _bit_mask(1, width=2) == [1, 0]
assert _bit_mask(2, width=2) == [0, 1]
assert _bit_mask(3, width=2) == [1, 1]


def test_weighted_percentile():
# Test the weighted percentile function
samples = np.asarray([1, 2, 3, 4, 5])
log_weights = np.asarray([0, 0, 0, 0, 0])
percentiles = [50]
assert np.allclose(weighted_percentile(samples, log_weights, percentiles), 3.0)
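
An additional check one might add (illustrative, not part of the commit): the weights are passed in log space, so concentrating weight on the largest sample should raise the weighted median above the unweighted value of 3.

    def test_weighted_percentile_shifts_with_weights():
        samples = np.asarray([1., 2., 3., 4., 5.])
        log_weights = np.log(np.asarray([1., 1., 1., 1., 10.]))
        # Most of the mass sits on the last sample, so the 50th percentile
        # should land well above 3.
        assert np.all(weighted_percentile(samples, log_weights, [50]) > 3.0)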
20 changes: 10 additions & 10 deletions jaxns/utils.py
@@ -311,9 +311,9 @@ def _round(v, uncert_v):
if bit == 1:
_print(condition)
_print("--------")
_print(f"likelihood evals: {results.total_num_likelihood_evaluations}")
_print(f"samples: {results.total_num_samples}")
_print(f"phantom samples: {float(results.total_phantom_samples):.1f}")
_print(f"likelihood evals: {int(results.total_num_likelihood_evaluations):d}")
_print(f"samples: {int(results.total_num_samples):d}")
_print(f"phantom samples: {int(results.total_phantom_samples):d}")
_print(
f"likelihood evals / sample: {float(results.total_num_likelihood_evaluations / results.total_num_samples):.1f}"
)
@@ -327,17 +327,17 @@ def _round(v, uncert_v):
# _print("H={} +- {}".format(
# _round(results.H_mean, results.H_uncert), _round(results.H_uncert, results.H_uncert)))
_print(
f"H={_round(results.H_mean, results.H_mean)}"
f"H={_round(results.H_mean, 0.1)}"
)
_print(
f"ESS={float(results.ESS)}"
f"ESS={int(results.ESS):d}"
)
max_like_idx = jnp.argmax(results.log_L_samples)
max_like_idx = np.argmax(results.log_L_samples)
max_like_points = tree_map(lambda x: x[max_like_idx], results.samples)
samples = resample(random.PRNGKey(23426), results.samples, results.log_dp_mean, S=max(10, int(results.ESS)),
replace=True)

max_map_idx = jnp.argmax(results.log_posterior_density)
max_map_idx = np.argmax(results.log_posterior_density)
map_points = tree_map(lambda x: x[max_map_idx], results.samples)

for name in samples.keys():
@@ -351,7 +351,7 @@ def _round(v, uncert_v):
f"{var_name}: mean +- std.dev. | 10%ile / 50%ile / 90%ile | MAP est. | max(L) est."
)
for dim in range(ndims):
_uncert = jnp.std(_samples[:, dim])
_uncert = np.std(_samples[:, dim])
_max_like_point = _max_like_points[dim]
_map_point = _map_points[dim]
# two sig-figs based on uncert
@@ -363,8 +363,8 @@ def _round(ar):
_uncert = _round(_uncert)
_print("{}: {} +- {} | {} / {} / {} | {} | {}".format(
name if ndims == 1 else "{}[{}]".format(name, dim),
_round(jnp.mean(_samples[:, dim])), _uncert,
*[_round(a) for a in jnp.percentile(_samples[:, dim], jnp.asarray([10, 50, 90]))],
_round(np.mean(_samples[:, dim])), _uncert,
*[_round(a) for a in np.percentile(_samples[:, dim], np.asarray([10, 50, 90]))],
_round(_map_point),
_round(_max_like_point)
))
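
The resample call above is the general recipe for turning weighted nested-sampling output into equally-weighted posterior draws; as a standalone sketch (illustrative, not part of the commit; `results` is a NestedSamplerResults):

    from jax import random
    from jaxns.utils import resample

    # S equally-weighted posterior draws, sampled with replacement according
    # to the normalised posterior mass log_dp_mean.
    posterior_samples = resample(random.PRNGKey(42), results.samples,
                                 results.log_dp_mean, S=int(results.ESS),
                                 replace=True)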
1 change: 0 additions & 1 deletion requirements.txt
@@ -4,7 +4,6 @@ chex
typing_extensions
matplotlib
numpy
pytest
scipy
tensorflow_probability
tqdm
3 changes: 1 addition & 2 deletions setup.py
@@ -10,7 +10,6 @@
'typing_extensions',
'matplotlib',
'numpy',
'pytest',
'scipy',
'tensorflow_probability',
'tqdm',
@@ -22,7 +21,7 @@
long_description = fh.read()

setup(name='jaxns',
version='2.3.1',
version='2.3.3',
description='Nested Sampling in JAX',
long_description=long_description,
long_description_content_type="text/markdown",
