Commit 0bf9fc1

Merge pull request #170 from Joshuaalbert/develop

Develop

Joshuaalbert authored May 27, 2024
2 parents 37019fe + 8ab2230 commit 0bf9fc1
Showing 13 changed files with 188 additions and 48 deletions.
6 changes: 4 additions & 2 deletions README.md
@@ -363,13 +363,15 @@ is the best way to achieve speed up.

# Change Log

27 May, 2024 -- JAXNS 2.5.1 released. Fixed minor accuracy degradation introduced in 2.4.13.

15 May, 2024 -- JAXNS 2.5.0 released. Added ability to handle non-JAX likelihoods, e.g. if you have a simulation
framework with python bindings you can now use it for likelihoods in JAXNS. Small performance improvements.
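
A short sketch of what "non-JAX likelihood" support can look like in practice: one plausible bridging mechanism is `jax.pure_callback`, which lets traced JAX code call back into ordinary Python. This is an illustrative sketch only, not necessarily JAXNS's actual interface; `numpy_log_likelihood` and the scalar result shape are assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np


def numpy_log_likelihood(x: np.ndarray) -> np.ndarray:
    # Hypothetical stand-in for a non-JAX simulator with Python bindings.
    return np.asarray(-0.5 * np.sum(np.square(x)), dtype=np.float32)


def log_likelihood(x):
    # pure_callback re-enters Python from traced code; the output
    # shape/dtype must be declared up front.
    return jax.pure_callback(
        numpy_log_likelihood,
        jax.ShapeDtypeStruct((), jnp.float32),
        x,
    )


print(jax.jit(log_likelihood)(jnp.ones(3)))  # -1.5
```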

22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes bug where slice sampling not invariant to monotonic transforms of
-likelihod.
+likelihood.

-20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes, and readability improvements. Added Empirial special prior.
+20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes, and readability improvements. Added Empirical special prior.

5 Mar, 2024 -- JAXNS 2.4.11/b released. Add `random_init` to parametrised variables. Enable special priors to be
parametrised.
23 changes: 6 additions & 17 deletions benchmarks/gh117.py → benchmarks/gh117/main.py
@@ -4,7 +4,8 @@
import tensorflow_probability.substrates.jax as tfp
from jax import random

-from jaxns import Model, Prior
+from jaxns import Model, Prior, DefaultNestedSampler


tfpd = tfp.distributions

@@ -20,7 +21,6 @@ def prior_model():
    model = Model(prior_model=prior_model,
                  log_likelihood=log_likelihood)

-    from jaxns import DefaultNestedSampler

    # Create the nested sampler class. In this case without any tuning.
    exact_ns = DefaultNestedSampler(model=model, max_samples=max_samples)
@@ -29,7 +29,7 @@ def prior_model():
    return termination_reason


-def performance_benchmark():
+def main():
    max_samples = int(1e7)
    m = 10
    run_model_aot = jax.jit(lambda: run_model(max_samples=max_samples)).lower().compile()
@@ -47,25 +47,14 @@ def performance_benchmark():
print(f"The best 3 of {m} runs took {best_3:.5f} seconds.")


# _inter_sync_shrinkage_process unroll=1
# get_sample_from_seed unroll=1
# Avg. time taken: 4.79353 seconds.
# The best 3 of 10 runs took 4.63075 seconds.

# _inter_sync_shrinkage_process unroll=2
# get_sample_from_seed unroll=1
# Avg. time taken: 5.04382 seconds.
# The best 3 of 10 runs took 4.74833 seconds.

# _inter_sync_shrinkage_process unroll=1
# get_sample_from_seed unroll=2
# Before fix
# Avg. time taken: 4.40303 seconds.
# The best 3 of 10 runs took 4.37935 seconds.

# With fix
# After fix
# Avg. time taken: 0.00562 seconds.
# The best 3 of 10 runs took 0.00478 seconds.


if __name__ == '__main__':
-    performance_benchmark()
+    main()
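
A note on the timing pattern used above: `jax.jit(...).lower(...).compile()` is JAX's ahead-of-time compilation path, so the timed loop measures pure execution with tracing and compilation already paid for. A minimal self-contained illustration:

```python
import jax
import jax.numpy as jnp

f = jax.jit(lambda x: jnp.sum(x ** 2))
f_aot = f.lower(jnp.ones(8)).compile()  # trace + compile once, up front
print(f_aot(jnp.ones(8)))  # subsequent calls skip tracing/compilation
```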
@@ -4,7 +4,7 @@
declare -a jaxns_versions=("2.3.0" "2.3.1" "2.3.2" "2.3.4" "2.4.0" "2.4.1")

# Path to your benchmark script
benchmark_script="gh117.py"
benchmark_script="main.py"

# Name for the conda environment
conda_env_name="jaxns_benchmarks_env"
93 changes: 93 additions & 0 deletions benchmarks/gh168/main.py
@@ -0,0 +1,93 @@
import time

import jax
import jax.numpy as jnp
import numpy as np
import pkg_resources
import tensorflow_probability.substrates.jax as tfp
from jax._src.scipy.linalg import solve_triangular

from jaxns import Model, Prior, DefaultNestedSampler

tfpd = tfp.distributions


def run_model(key):
    def log_normal(x, mean, cov):
        L = jnp.linalg.cholesky(cov)
        dx = x - mean
        dx = solve_triangular(L, dx, lower=True)
        return -0.5 * x.size * jnp.log(2. * jnp.pi) - jnp.sum(jnp.log(jnp.diag(L))) \
            - 0.5 * dx @ dx

    ndims = 8
    prior_mu = 15 * jnp.ones(ndims)
    prior_cov = jnp.diag(jnp.ones(ndims)) ** 2

    data_mu = jnp.zeros(ndims)
    data_cov = jnp.diag(jnp.ones(ndims)) ** 2
    data_cov = jnp.where(data_cov == 0., 0.99, data_cov)

    log_Z_true = log_normal(data_mu, prior_mu, prior_cov + data_cov)
    # not super happy with this being 1.58 and being off by like 0.1. Probably related to the ESS.
    post_mu = prior_cov @ jnp.linalg.inv(prior_cov + data_cov) @ data_mu + data_cov @ jnp.linalg.inv(
        prior_cov + data_cov) @ prior_mu

    # print(f"True post mu:{post_mu}")
    # print(f"True log Z: {log_Z_true}")

    def prior_model():
        x = yield Prior(
            tfpd.MultivariateNormalTriL(loc=prior_mu, scale_tril=jnp.linalg.cholesky(prior_cov)),
            name='x')
        return x

    def log_likelihood(x):
        return tfpd.MultivariateNormalTriL(loc=data_mu, scale_tril=jnp.linalg.cholesky(data_cov)).log_prob(x)

    model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

    ns = DefaultNestedSampler(model=model, max_samples=100000, verbose=False)

    termination_reason, state = ns(key)
    results = ns.to_results(termination_reason=termination_reason, state=state, trim=False)
    return results.log_Z_mean - log_Z_true, results.log_Z_uncert


def main():
    jaxns_version = pkg_resources.get_distribution("jaxns").version
    m = 10
    run_model_aot = jax.jit(run_model).lower(jax.random.PRNGKey(0)).compile()
    dt = []

    errors = []
    uncerts = []

    for i in range(m):
        t0 = time.time()
        log_Z_error, log_Z_uncert = run_model_aot(jax.random.PRNGKey(i))
        log_Z_error.block_until_ready()
        t1 = time.time()
        dt.append(t1 - t0)
        errors.append(log_Z_error)
        uncerts.append(log_Z_uncert)
    total_time = sum(dt)
    best_3 = sum(sorted(dt)[:3]) / 3.
    # print(f"Errors: {errors}")
    # print(f"Uncerts: {uncerts}")
    print(f"JAXNS {jaxns_version}\n"
          f"\tMean error: {np.mean(errors)}\n"
          f"\tMean uncert: {np.mean(uncerts)}\n"
          f"Avg. time taken: {total_time / m:.5f} seconds.\n"
          f"The best 3 of {m} runs took {best_3:.5f} seconds.")

    with open('results', 'a') as fp:
        fp.write(f"{jaxns_version},{np.mean(errors)},{np.mean(uncerts)},{total_time / m},{best_3}\n")


# Before fix
# 2.5.0,2.851858615875244,0.3351728320121765,0.7272443532943725,0.7075355052947998

# After fix
# 2.4.12,0.42119064927101135,0.3309990465641022,1.001870584487915,0.9800511995951334
if __name__ == '__main__':
    main()
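
For reference, the analytic quantities `log_Z_true` and `post_mu` computed in `run_model` follow from standard Gaussian conjugacy: with prior N(mu_0, Sigma_0) and likelihood N(x; mu_d, Sigma_d),

```latex
\[
Z = \int \mathcal{N}(x \mid \mu_0, \Sigma_0)\,
         \mathcal{N}(\mu_d \mid x, \Sigma_d)\, dx
  = \mathcal{N}(\mu_d \mid \mu_0, \Sigma_0 + \Sigma_d),
\]
\[
\mu_{\mathrm{post}} = \Sigma_0 (\Sigma_0 + \Sigma_d)^{-1} \mu_d
                    + \Sigma_d (\Sigma_0 + \Sigma_d)^{-1} \mu_0 .
\]
```

The benchmark reports `results.log_Z_mean - log_Z_true` as the error, so a correct sampler should land within a few multiples of `log_Z_uncert` of zero.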
8 changes: 8 additions & 0 deletions benchmarks/gh168/results
@@ -0,0 +1,8 @@
2.4.6,0.42119064927101135,0.3309990465641022,1.0280845165252686,0.9580310185750326
2.4.7,0.42119064927101135,0.3309990465641022,0.9979925632476807,0.961020310719808
2.4.8,0.42119064927101135,0.3309990465641022,1.006078839302063,0.9840319156646729
2.4.10,0.42119064927101135,0.3309990465641022,0.9860331773757934,0.9623709519704183
2.4.11,0.42119064927101135,0.3309990465641022,1.037106418609619,1.0193532307942708
2.4.12,0.42119064927101135,0.3309990465641022,1.001870584487915,0.9800511995951334
2.4.13,2.851858615875244,0.3351728320121765,0.7272443532943725,0.7075355052947998
2.5.0,2.851858615875244,0.3351728320121765,0.7534827709197998,0.739771842956543
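
Reading the table: the columns follow the `fp.write` line in `main.py`, i.e. jaxns version, mean log-evidence error, mean reported log-evidence uncertainty, average run time (seconds), and mean of the best three run times. The jump in mean error from roughly 0.42 to 2.85 between 2.4.12 and 2.4.13 is the accuracy degradation that the 2.5.1 changelog entry above records as fixed.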
36 changes: 36 additions & 0 deletions benchmarks/gh168/run_benchmark.sh
@@ -0,0 +1,36 @@
#!/bin/bash

# Array of jaxns versions to be installed
declare -a jaxns_versions=("2.4.6" "2.4.7" "2.4.8" "2.4.10" "2.4.11" "2.4.12" "2.4.13" "2.5.0")

# Path to your benchmark script
benchmark_script="main.py"

# Name for the conda environment
conda_env_name="jaxns_benchmarks_env"

# Function to create Conda environment and install jaxns
create_and_activate_env() {
    local version=$1
    echo "Creating Conda environment for jaxns version $version with Python 3.11..."
    conda create --name $conda_env_name python=3.11 -y
    eval "$(conda shell.bash hook)"
    conda activate $conda_env_name
    pip install jaxns==$version
}

# Function to tear down Conda environment
tear_down_env() {
    echo "Tearing down Conda environment..."
    conda deactivate
    conda env remove --name $conda_env_name
}

# Main loop to install each version, run benchmark, and tear down env
for version in "${jaxns_versions[@]}"; do
    create_and_activate_env $version
    python $benchmark_script
    tear_down_env
done

echo "Benchmarking complete."
2 changes: 1 addition & 1 deletion docs/conf.py
@@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2022, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.5.0"
release = "2.5.1"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
5 changes: 2 additions & 3 deletions jaxns/nested_sampler/standard_static.py
@@ -775,9 +775,8 @@ def replica(key: PRNGKey) -> Tuple[StaticStandardNestedSamplerState, IntArray]:
        if self.num_parallel_workers > 1:
            # We need to do a final sampling run to make all the chains consistent,
            # to a likelihood contour (i.e. standardise on L(X)). Would mean that some workers are idle.
-            target_log_L_contour = jnp.max(
-                parallel.all_gather(termination_register.log_L_contour, 'i')
-            )
+            target_log_L_contour = parallel.pmax(termination_register.log_L_contour, 'i')

            termination_cond = TerminationCondition(
                dlogZ=jnp.asarray(0., float_type),
                log_L_contour=target_log_L_contour,
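
The change above replaces an `all_gather` followed by a local `jnp.max` with a single `pmax` collective, which computes the same maximum with less communication. A standalone sketch of the equivalence, independent of JAXNS internals:

```python
import jax
import jax.numpy as jnp

x = jnp.arange(float(jax.device_count()))

# Old pattern: gather every participant's value, then reduce locally.
gathered = jax.pmap(lambda v: jnp.max(jax.lax.all_gather(v, 'i')), axis_name='i')(x)

# New pattern: the collective performs the reduction directly.
direct = jax.pmap(lambda v: jax.lax.pmax(v, 'i'), axis_name='i')(x)

assert (gathered == direct).all()
```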
4 changes: 2 additions & 2 deletions jaxns/plotting.py
@@ -216,7 +216,7 @@ def plot_cornerplot(results: NestedSamplerResults, variables: Optional[List[str]
            # Plot the 2D histogram, over ranges set by the 1_per and 99_per of each parameter
            ranges = [param_limits[parameters[col]], param_limits[parameters[row]]]
            ax.hist2d(_samples[:, 1], _samples[:, 0], bins=(nbins, nbins), density=True,
-                      cmap=plt.cm.get_cmap('bone_r'),
+                      cmap="bone_r",
                      weights=_weights, range=ranges)

            if kde_overlay:  # Put KDE contour on the 2D histograms
@@ -432,7 +432,7 @@ def add_colorbar_to_axes(ax, cmap, norm=None, vmin=None, vmax=None, label=None):
    cax = divider.append_axes('right', size='5%', pad=0.05)
    if norm is None:
        norm = plt.Normalize(vmin=vmin, vmax=vmax)
-    sm = plt.cm.ScalarMappable(norm, cmap=plt.cm.get_cmap(cmap))
+    sm = plt.cm.ScalarMappable(norm, cmap=plt.colormaps.get_cmap(cmap))
    if label is None:
        ax.figure.colorbar(sm, cax=cax, orientation='vertical')
    else:
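
Both plotting edits retire `plt.cm.get_cmap`, which Matplotlib deprecated in 3.7 in favour of the colormap registry; `hist2d` also accepts a colormap name directly, hence the plain `"bone_r"` string. The registry equivalent, as used in the second hunk:

```python
import matplotlib.pyplot as plt

cmap = plt.colormaps.get_cmap("bone_r")  # registry lookup replacing plt.cm.get_cmap
```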
2 changes: 1 addition & 1 deletion jaxns/public.py
@@ -95,7 +95,7 @@ def __init__(self, model: BaseAbstractModel,
                model=model,
                num_slices=model.U_ndims * self._s,
                num_phantom_save=self._k,
-                midpoint_shrink=True,
+                midpoint_shrink=not difficult_model,
                perfect=True
            ),
            init_efficiency_threshold=init_efficiency_threshold,
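
Read together with the slice-sampler diff below: constructing the sampler with `difficult_model=True` now disables the midpoint-shrink heuristic. The flag itself is presumably defined in the folded part of this hunk; the shrinkage behaviour it controls is the `alpha` logic in `uni_slice_sampler.py` below.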
29 changes: 21 additions & 8 deletions jaxns/samplers/uni_slice_sampler.py
@@ -75,7 +75,8 @@ def _pick_point_in_interval(key: PRNGKey, point_U0: FloatArray, direction: Float
        point_U: [D]
        t: selection point between [left, right]
    """
-    t = random.uniform(key, minval=left, maxval=right, dtype=float_type)
+    u = random.uniform(key, dtype=float_type)
+    t = left + u * (right - left)
    point_U = point_U0 + t * direction
    # close_to_zero = (left >= -10*jnp.finfo(left.dtype).eps) & (right <= 10*jnp.finfo(right.dtype).eps)
    # point_U = jnp.where(close_to_zero, point_U0, point_U)
@@ -84,7 +85,7 @@


def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray,
-                     midpoint_shrink: bool) -> Tuple[FloatArray, FloatArray]:
+                     midpoint_shrink: bool, alpha: jax.Array) -> Tuple[FloatArray, FloatArray]:
    """
    Not successful proposal, so shrink, optionally apply exponential shrinkage.
    """
@@ -97,13 +98,15 @@ def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: Float
    # Therefore, it must only use the knowledge of ordering of the likelihoods.
    # Basic version: shrink to midpoint of interval, i.e. alpha = 0.5.
    # Extended version: shrink to random point in interval.
-    alpha = random.uniform(key)
+    # do_midpoint_shrink = random.uniform(key) < 0.5
+    # alpha = 1 # 0.8 # random.uniform(key)
    left = jnp.where((t < 0.), alpha * left, left)
    right = jnp.where((t > 0.), alpha * right, right)
    return left, right


-def _new_proposal(key: PRNGKey, seed_point: SeedPoint, midpoint_shrink: bool, perfect: bool,
+def _new_proposal(key: PRNGKey, seed_point: SeedPoint, midpoint_shrink: bool, alpha: jax.Array,
+                  perfect: bool,
                  gradient_slice: bool,
                  log_L_constraint: FloatArray,
                  model: BaseAbstractModel) -> Tuple[FloatArray, FloatArray, IntArray]:
@@ -150,7 +153,8 @@ def body(carry: Carry) -> Carry:
            t=carry.t,
            left=carry.left,
            right=carry.right,
-            midpoint_shrink=midpoint_shrink
+            midpoint_shrink=midpoint_shrink,
+            alpha=alpha
        )
        point_U, t = _pick_point_in_interval(
            key=t_key,
@@ -312,14 +316,19 @@ def get_seed_point(self, key: PRNGKey, sampler_state: SamplerState,
    def get_sample_from_seed(self, key: PRNGKey, seed_point: SeedPoint, log_L_constraint: FloatArray,
                             sampler_state: SamplerState) -> Tuple[Sample, Sample]:

-        def propose_op(sample: Sample, key: PRNGKey) -> Sample:
+        class XType(NamedTuple):
+            key: jax.Array
+            alpha: jax.Array
+
+        def propose_op(sample: Sample, x: XType) -> Sample:
            U_sample, log_L, num_likelihood_evaluations = _new_proposal(
-                key=key,
+                key=x.key,
                seed_point=SeedPoint(
                    U0=sample.U_sample,
                    log_L0=sample.log_L
                ),
                midpoint_shrink=self.midpoint_shrink,
+                alpha=x.alpha,
                perfect=self.perfect,
                gradient_slice=self.gradient_slice,
                log_L_constraint=log_L_constraint,
Expand All @@ -338,10 +347,14 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
                log_L=seed_point.log_L0,
                num_likelihood_evaluations=jnp.asarray(0, int_type)
            )
+        xs = XType(
+            key=random.split(key, self.num_slices),
+            alpha=jnp.linspace(0.5, 1., self.num_slices)
+        )
        final_sample, cumulative_samples = cumulative_op_static(
            op=propose_op,
            init=init_sample,
-            xs=random.split(key, self.num_slices)
+            xs=xs
        )

        # Last sample is the final sample, the rest are potential phantom samples
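
`cumulative_op_static` is internal to JAXNS and not shown in this diff; as a rough functional model (an assumption, not the real implementation) it behaves like a `lax.scan` fold that returns the final carry together with every intermediate. The sketch below mimics how the new `XType` xs threads one RNG key and one shrink factor into each of the `num_slices` proposals, with `alpha` ramping from 0.5 (halve the interval on each rejection) to 1.0 (no extra shrinkage):

```python
import jax
import jax.numpy as jnp


def cumulative_op_static_sketch(op, init, xs):
    # Assumed behaviour of jaxns' cumulative_op_static: fold `op` over the
    # leading axis of the pytree `xs`, keeping every intermediate carry.
    def body(carry, x):
        new = op(carry, x)
        return new, new

    return jax.lax.scan(body, init, xs)


num_slices = 4
keys = jax.random.split(jax.random.PRNGKey(0), num_slices)
alphas = jnp.linspace(0.5, 1., num_slices)  # aggressive -> no extra shrinkage

# Toy op standing in for propose_op: ignore the key, scale the carry by alpha.
final, history = cumulative_op_static_sketch(
    op=lambda carry, x: carry * x[1],
    init=jnp.asarray(1.),
    xs=(keys, alphas),
)
print(final, history)
```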
24 changes: 12 additions & 12 deletions jaxns/tests/conftest.py
@@ -321,23 +321,23 @@ def multiellipsoidal_mvn_run_results(basic_mvn_model):

@pytest.fixture(scope='package')
def all_run_results(
-        # basic_run_results,
+        basic_run_results,
        basic_with_obj_run_results,
-        # basic2_run_results,
-        # basic3_run_results,
-        # plateau_run_results,
-        # basic_mvn_run_results,
+        basic2_run_results,
+        basic3_run_results,
+        plateau_run_results,
+        basic_mvn_run_results,
        # basic_mvn_run_results_parallel,
-        # multiellipsoidal_mvn_run_results
+        multiellipsoidal_mvn_run_results
):
    # Return tuples with names
    return [
-        # ('basic', basic_run_results),
+        ('basic', basic_run_results),
        ('basic_with_obj', basic_with_obj_run_results),
-        # ('basic2', basic2_run_results),
-        # ('basic3', basic3_run_results),
-        # ('plateau', plateau_run_results),
-        # ('basic_mvn', basic_mvn_run_results),
+        ('basic2', basic2_run_results),
+        ('basic3', basic3_run_results),
+        ('plateau', plateau_run_results),
+        ('basic_mvn', basic_mvn_run_results),
        # ('basic_mvn_parallel', basic_mvn_run_results_parallel),
-        # ('multiellipsoidal_mvn', multiellipsoidal_mvn_run_results)
+        ('multiellipsoidal_mvn', multiellipsoidal_mvn_run_results)
    ]
