Skip to content

Commit

Permalink
Merge pull request #116 from Joshuaalbert/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
Joshuaalbert authored Dec 20, 2023
2 parents 6bc23a5 + 21fa6b9 commit 05bf627
Show file tree
Hide file tree
Showing 7 changed files with 427 additions and 327 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,8 @@ is the best way to achieve speed up.

# Change Log

21 Dec, 2023 -- JAXNS 2.3.4 released. Correction for ESS and logZ uncert. `parameter_estimation` mode.

20 Dec, 2023 -- JAXNS 2.3.2/3 released. Improved default parameters. `difficult_model` mode. Improve plotting.

18 Dec, 2023 -- JAXNS 2.3.1 released. Paper open science release. Default parameters from paper.
Expand Down
116 changes: 56 additions & 60 deletions docs/examples/gaussian_shells.ipynb

Large diffs are not rendered by default.

227 changes: 70 additions & 157 deletions docs/examples/mvn_data_mvn_prior.ipynb

Large diffs are not rendered by default.

385 changes: 283 additions & 102 deletions docs/papers/phantom-powered-nested-sampling/phantom_bias_tradeoff.ipynb

Large diffs are not rendered by default.

12 changes: 9 additions & 3 deletions jaxns/nested_sampler/standard_static.py
Original file line number Diff line number Diff line change
Expand Up @@ -561,15 +561,23 @@ def _to_results(self, termination_reason: IntArray, state: StaticStandardNestedS
)
log_Z_uncert = jnp.sqrt(log_Z_var)

# Correction by sqrt(k+1)
total_phantom_samples = jnp.sum(sample_collection.phantom.astype(int_type))
phantom_fraction = total_phantom_samples / num_samples # k / (k+1)
k = phantom_fraction / (1. - phantom_fraction)
log_Z_uncert = log_Z_uncert * jnp.sqrt(1. + k)

# Kish's ESS = [sum dZ]^2 / [sum dZ^2]
ESS = effective_sample_size(final_evidence_stats.log_Z_mean, final_evidence_stats.log_dZ2_mean)
ESS = ESS / (1. + k)

samples = vmap(self.model.transform)(U_samples)

log_L_samples = log_L
dp_mean = LogSpace(per_sample_evidence_stats.log_dZ_mean)
dp_mean = normalise_log_space(dp_mean)
H_mean_instable = -((dp_mean * LogSpace(jnp.log(jnp.abs(log_L_samples)), jnp.sign(log_L_samples))).sum().value - log_Z_mean)
H_mean_instable = -((dp_mean * LogSpace(jnp.log(jnp.abs(log_L_samples)),
jnp.sign(log_L_samples))).sum().value - log_Z_mean)
# H \approx E[-log(compression)] = E[-log(X)] (More stable than E[log(L) - log(Z)]
H_mean_stable = -((dp_mean * LogSpace(jnp.log(-per_sample_evidence_stats.log_X_mean))).sum().value)
H_mean = jnp.where(jnp.isfinite(H_mean_instable), H_mean_instable, H_mean_stable)
Expand All @@ -582,8 +590,6 @@ def _to_results(self, termination_reason: IntArray, state: StaticStandardNestedS
log_posterior_density = log_L + vmap(self.model.log_prob_prior)(
U_samples)

total_phantom_samples = jnp.sum(sample_collection.phantom.astype(int_type))

return NestedSamplerResults(
log_Z_mean=log_Z_mean, # estimate of log(E[Z])
log_Z_uncert=log_Z_uncert, # estimate of log(StdDev[Z])
Expand Down
10 changes: 6 additions & 4 deletions jaxns/public.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,8 @@ class DefaultNestedSampler:
def __init__(self, model: BaseAbstractModel, max_samples: Union[int, float], num_live_points: Optional[int] = None,
s: Optional[int] = None, k: Optional[int] = None, c: Optional[int] = None,
num_parallel_workers: int = 1,
difficult_model: bool = False):
difficult_model: bool = False,
parameter_estimation: bool = False):
"""
Initialises the nested sampler.
Expand All @@ -50,14 +51,15 @@ def __init__(self, model: BaseAbstractModel, max_samples: Union[int, float], num
c: number of parallel Markov-chains to use. Defaults to 20 * D.
num_parallel_workers: number of parallel workers to use. Defaults to 1. Experimental feature.
difficult_model: if True, uses more robust default settings. Defaults to False.
parameter_estimation: if True, uses more robust default settings for parameter estimation. Defaults to False.
"""
if difficult_model:
self._s = 10 if s is None else int(s)
else:
self._s = 4 if s is None else int(s)
self._s = 5 if s is None else int(s)
if self._s <= 0:
raise ValueError(f"Expected s > 0, got s={self._s}")
if difficult_model:
if parameter_estimation:
self._k = model.U_ndims if k is None else int(k)
else:
self._k = model.U_ndims // 2 if k is None else int(k)
Expand All @@ -70,7 +72,7 @@ def __init__(self, model: BaseAbstractModel, max_samples: Union[int, float], num
if difficult_model:
self._c = 50 * model.U_ndims if c is None else int(c)
else:
self._c = 20 * model.U_ndims if c is None else int(c)
self._c = 30 * model.U_ndims if c is None else int(c)
if self._c <= 0:
raise ValueError(f"Expected c > 0, got c={self._c}")
# Sanity check for max_samples (should be able to at least do one shrinkage)
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
long_description = fh.read()

setup(name='jaxns',
version='2.3.3',
version='2.3.4',
description='Nested Sampling in JAX',
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down

0 comments on commit 05bf627

Please sign in to comment.