Merge pull request #160 from Joshuaalbert/develop
Develop
Joshuaalbert authored Apr 22, 2024
2 parents 20fdfca + 649c930 commit e74d848
Showing 26 changed files with 472 additions and 193 deletions.
9 changes: 8 additions & 1 deletion README.md
@@ -167,7 +167,12 @@ ns_jit = jax.jit(ns)
You can inspect the results, and plot them.

```python
from jaxns import summary, plot_diagnostics, plot_cornerplot
from jaxns import summary, plot_diagnostics, plot_cornerplot, save_results, load_results

# Optionally save the results to file
save_results(results, 'results.json')
# To load the results back, use this
results = load_results('results.json')

summary(results)
plot_diagnostics(results)
@@ -358,6 +363,8 @@ is the best way to achieve speed up.

# Change Log

22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes a bug where slice sampling was not invariant to monotonic transforms of the likelihood.

20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes and readability improvements. Added Empirical special prior.

5 Mar, 2024 -- JAXNS 2.4.11/b released. Add `random_init` to parametrised variables. Enable special priors to be
12 changes: 3 additions & 9 deletions docs/conf.py
@@ -3,8 +3,6 @@
# For the full list of built-in configuration values, see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html

import sphinx_rtd_theme

# sys.path.insert(0, os.path.abspath("..")) # add project root to abs path


@@ -14,8 +12,7 @@
project = "jaxns"
copyright = "2022, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.4.12"

release = "2.4.13"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
@@ -57,25 +54,22 @@

autodoc_typehints = "description"


# -- Options for AutoAPI -----------------------------------------------------

autoapi_dirs = ["../jaxns"]
autoapi_root = "api" # where to put the generated files relative to root
autoapi_options = ["members", "undoc-members", "show-inheritance",
"special-members", "imported-members"]
autoapi_options = ["members", "undoc-members", "show-inheritance",
"special-members", "imported-members"]
autoapi_member_order = "bysource" # order members by source code
autoapi_ignore = ["*/tests/*"] # ignore tests
autoapi_template_dir = "_templates/autoapi"
autoapi_python_class_content = "both" # Use both class and __init__ docstrings
autoapi_add_toctree_entry = False


# -- Options for NBSphinx ----------------------------------------------------

nbsphinx_execute = "never" # never execute notebooks (slow) during building


# -- Copy notebooks to docs --------------------------------------------------
# Copies the notebooks from the project directory to the docs directory so that
# they can be parsed by nbsphinx.
4 changes: 2 additions & 2 deletions docs/examples/mvn_data_mvn_prior.ipynb
@@ -452,9 +452,9 @@
"results = ns.to_results(termination_reason=termination_reason, state=state)\n",
"\n",
"# We can always save results to play with later\n",
"ns.save_results(results, 'save.npz')\n",
"ns.save_results(results, 'save.json')\n",
"# loads previous results by uncommenting below\n",
"# results = load_results('save.npz')\n",
"# results = load_results('save.json')\n",
"\n"
]
},
12 changes: 6 additions & 6 deletions jaxns/experimental/evidence_maximisation.py
@@ -172,13 +172,13 @@ def log_evidence(params: hk.MutableParams, data: MStepData):
def loss(params: hk.MutableParams, data: MStepData):
log_Z, grad = jax.value_and_grad(log_evidence, argnums=0)(params, data)
obj = -log_Z
grad = jax.tree_map(jnp.negative, grad)
grad = jax.tree.map(jnp.negative, grad)

# If objective is -+inf, or nan, then the gradient is nan
grad = jax.tree_map(lambda x: jnp.where(jnp.isfinite(obj), x, jnp.zeros_like(x)), grad)
grad = jax.tree.map(lambda x: jnp.where(jnp.isfinite(obj), x, jnp.zeros_like(x)), grad)

# Clip the gradient
grad = jax.tree_map(lambda x: jnp.clip(x, -10, 10), grad)
grad = jax.tree.map(lambda x: jnp.clip(x, -10, 10), grad)

aux = (log_Z,)
if self.verbose:
@@ -247,7 +247,7 @@ def op(log_Z, data):
def loss(params: hk.MutableParams, data: MStepData):
log_Z, grad = jax.value_and_grad(log_evidence, argnums=0)(params, data)
obj = -log_Z
grad = jax.tree_map(jnp.negative, grad)
grad = jax.tree.map(jnp.negative, grad)
aux = (log_Z,)
if self.verbose:
jax.debug.print("log_Z={log_Z}", log_Z=log_Z)
@@ -316,11 +316,11 @@ def _pad_to_n(x, fill_value, dtype):
log_Z = None
while epoch < self.max_num_epochs:
params, (log_Z,) = self._m_step(key=key, params=params, data=data)
l_oo = jax.tree_map(lambda x, y: jnp.max(jnp.abs(x - y)) if np.size(x) > 0 else 0.,
l_oo = jax.tree.map(lambda x, y: jnp.max(jnp.abs(x - y)) if np.size(x) > 0 else 0.,
last_params, params)
last_params = params
p_bar.set_description(f"{desc}: Epoch {epoch}: log_Z={log_Z}, l_oo={l_oo}")
if all(_l_oo < self.gtol for _l_oo in jax.tree_leaves(l_oo)):
if all(_l_oo < self.gtol for _l_oo in jax.tree.leaves(l_oo)):
break
epoch += 1

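The edits in this file move from the deprecated `jax.tree_map` / `jax.tree_leaves` aliases to the `jax.tree.map` / `jax.tree.leaves` namespace. A minimal sketch of the gradient-sanitisation pattern used in `loss` above, assuming a JAX version that exposes the `jax.tree` namespace; the pytree contents and clip bounds here are illustrative only:

```python
import jax
import jax.numpy as jnp

# Illustrative gradient pytree; in the real code this comes from jax.value_and_grad.
grad = {"loc": jnp.array([1.0, jnp.nan]), "scale": jnp.array([100.0])}
obj = jnp.asarray(jnp.inf)  # pretend the objective overflowed

# Negate every leaf (maximising log_Z is minimising -log_Z).
grad = jax.tree.map(jnp.negative, grad)

# Zero the gradient wherever the objective is not finite.
grad = jax.tree.map(lambda x: jnp.where(jnp.isfinite(obj), x, jnp.zeros_like(x)), grad)

# Clip each leaf to keep optimiser steps bounded.
grad = jax.tree.map(lambda x: jnp.clip(x, -10, 10), grad)

# Convergence-style check over all leaves, as in the epoch loop above.
converged = all(bool(jnp.all(jnp.abs(g) <= 10)) for g in jax.tree.leaves(grad))
```
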
27 changes: 21 additions & 6 deletions jaxns/experimental/global_optimisation.py
@@ -1,10 +1,11 @@
import io
from typing import NamedTuple, Optional, Union, TextIO, Tuple, List

import jax
import jax.nn
import jax.numpy as jnp
import numpy as np
from jax import lax, random, pmap, tree_map
from jax import lax, random, pmap
from jax._src.lax import parallel
from jax._src.scipy.special import logit
from jaxopt import NonlinearCG
@@ -134,16 +135,20 @@ def _set_done_bit(bit_done, bit_reason, done, termination_reason):
# relative spread of log-likelihood values below threshold
max_log_L = jnp.max(state.samples.log_L)
min_log_L = jnp.min(state.samples.log_L)
diff_log_L = jnp.abs(max_log_L - min_log_L)
reached_rtol = diff_log_L <= 0.5 * term_cond.rtol * jnp.abs(max_log_L + min_log_L)
diff_log_L = jnp.abs(max_log_L - min_log_L) # NaN = inf - inf
diff_log_L = jnp.where(jnp.isnan(diff_log_L), jnp.inf, diff_log_L)
mean_log_L = 0.5 * jnp.abs(max_log_L + min_log_L) # = inf - inf
mean_log_L = jnp.where(jnp.isnan(mean_log_L), jnp.inf, mean_log_L)
reached_rtol = diff_log_L <= term_cond.rtol * mean_log_L
done, termination_reason = _set_done_bit(reached_rtol, 2,
done=done, termination_reason=termination_reason)

if term_cond.atol is not None:
# absolute spread of log-likelihood values below threshold
max_log_L = jnp.max(state.samples.log_L)
min_log_L = jnp.min(state.samples.log_L)
diff_log_L = jnp.abs(max_log_L - min_log_L)
diff_log_L = jnp.abs(max_log_L - min_log_L) # NaN = inf - inf
diff_log_L = jnp.where(jnp.isnan(diff_log_L), jnp.inf, diff_log_L)
reached_atol = diff_log_L <= term_cond.atol
done, termination_reason = _set_done_bit(reached_atol, 3,
done=done, termination_reason=termination_reason)
@@ -227,7 +232,7 @@ def _repeat(x):
return jnp.repeat(x, (k + 1), axis=0)

fake_state = fake_state._replace(
sample_collection=tree_map(_repeat, fake_state.sample_collection)
sample_collection=jax.tree.map(_repeat, fake_state.sample_collection)
)

fake_state, fake_termination_register = _inter_sync_shrinkage_process(
@@ -253,7 +258,7 @@ def _select(x):
return x[choose_idx, jnp.arange(num_samples)] # [N, ...]

fake_state = fake_state._replace(
sample_collection=tree_map(
sample_collection=jax.tree.map(
_select,
fake_state.sample_collection
)
@@ -363,7 +368,17 @@ def _to_results(self, termination_reason: IntArray, state: GlobalOptimisationStat
max_log_L = state.samples.log_L[best_idx]
min_log_L = jnp.min(state.samples.log_L)
relative_spread = 2. * jnp.abs(max_log_L - min_log_L) / jnp.abs(max_log_L + min_log_L)
relative_spread = jnp.where(
jnp.isnan(relative_spread),
jnp.asarray(jnp.inf, relative_spread.dtype),
relative_spread
)
absolute_spread = jnp.abs(max_log_L - min_log_L)
absolute_spread = jnp.where(
jnp.isnan(absolute_spread),
jnp.asarray(jnp.inf, absolute_spread.dtype),
absolute_spread
)
return GlobalOptimisationResults(
U_solution=state.samples.U_sample[best_idx],
X_solution=X_solution,
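
The termination changes above guard against `NaN` produced by `inf - inf` when the sampled log-likelihoods are infinite: any `NaN` spread is replaced with `inf` so the rtol/atol comparisons stay well-defined. A minimal sketch of that pattern with made-up values and tolerances:

```python
import jax.numpy as jnp

# Degenerate case that previously produced NaN: every sample has log_L = +inf.
log_L = jnp.array([jnp.inf, jnp.inf, jnp.inf])
rtol, atol = 1e-3, 1e-2  # illustrative tolerances

max_log_L = jnp.max(log_L)
min_log_L = jnp.min(log_L)

diff_log_L = jnp.abs(max_log_L - min_log_L)  # inf - inf -> NaN
diff_log_L = jnp.where(jnp.isnan(diff_log_L), jnp.inf, diff_log_L)

mean_log_L = 0.5 * jnp.abs(max_log_L + min_log_L)  # NaN when the infinities have opposite signs
mean_log_L = jnp.where(jnp.isnan(mean_log_L), jnp.inf, mean_log_L)

# With the guards in place, neither comparison propagates NaN.
reached_rtol = diff_log_L <= rtol * mean_log_L
reached_atol = diff_log_L <= atol
```
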
6 changes: 3 additions & 3 deletions jaxns/experimental/tests/test_evidence_maximisation.py
@@ -35,7 +35,7 @@ def change(x, y):
return True

assert all(
change(p, p_) for p, p_ in zip(jax.tree_util.tree_leaves(model.params), jax.tree_util.tree_leaves(params)))
change(p, p_) for p, p_ in zip(jax.tree.leaves(model.params), jax.tree.leaves(params)))


def test_basic_zero_size_param():
Expand All @@ -52,7 +52,7 @@ def log_likelihood(y, z, sigma):
model = Model(prior_model=prior_model, log_likelihood=log_likelihood)

em = EvidenceMaximisation(model=model, ns_kwargs=dict(max_samples=1e5))
assert any(np.size(p) == 0 for p in jax.tree_util.tree_leaves(model.params))
assert any(np.size(p) == 0 for p in jax.tree.leaves(model.params))

ns_results, params = em.train(num_steps=1)

@@ -62,4 +62,4 @@ def change(x, y):
return True

assert all(
change(p, p_) for p, p_ in zip(jax.tree_util.tree_leaves(model.params), jax.tree_util.tree_leaves(params)))
change(p, p_) for p, p_ in zip(jax.tree.leaves(model.params), jax.tree.leaves(params)))
52 changes: 50 additions & 2 deletions jaxns/framework/special_priors.py
@@ -10,7 +10,8 @@
from jaxns.framework.bases import BaseAbstractPrior
from jaxns.framework.prior import SingularPrior, prior_to_parametrised_singular
from jaxns.internals.log_semiring import cumulative_logsumexp
from jaxns.internals.types import FloatArray, IntArray, BoolArray, float_type, int_type
from jaxns.internals.types import FloatArray, IntArray, BoolArray, float_type, int_type, UType, RandomVariableType, \
MeasureType

tfpd = tfp.distributions

@@ -20,7 +21,9 @@
"Categorical",
"ForcedIdentifiability",
"Poisson",
"UnnormalisedDirichlet"
"UnnormalisedDirichlet",
"Empirical",
"TruncationWrapper"
]


@@ -492,3 +495,48 @@ def _quantile(self, U):
x = jax.vmap(lambda u, per: jnp.interp(u * 100., self._q, per), in_axes=(0, 1))(U_flat, self._percentiles)
x = lax.reshape(x, U.shape)
return x


class TruncationWrapper(SpecialPrior):
"""
Wraps another prior to make it truncated.
For a truncated distribution the quantile transforms to:
Q_truncated(p) = Q_untruncated(p * (F_untruncated(high) - F_untruncated(low)) + F_untruncated(low))
And the CDF transforms to:
F_truncated(x) = (F_untruncated(x) - F_untruncated(low)) / (F_untruncated(high) - F_untruncated(low))
"""

def __init__(self, prior: BaseAbstractPrior, low: Union[jax.Array, float], high: Union[jax.Array, float],
name: Optional[str] = None):
super(TruncationWrapper, self).__init__(name=name)
self.prior = prior
self.low = jnp.minimum(low, high)
self.high = jnp.maximum(low, high)
self.cdf_low = self.prior._inverse(self.low)
self.cdf_diff = self.prior._inverse(self.high) - self.prior._inverse(self.low)

def _inverse(self, X: RandomVariableType) -> UType:
return jnp.clip((self.prior._inverse(X) - self.cdf_low) / jnp.maximum(self.cdf_diff, 1e-6),
0., 1.)

def _forward(self, U: UType) -> RandomVariableType:
return jnp.clip(self.prior._forward(jnp.clip(U * self.cdf_diff + self.cdf_low,
0., 1.)),
self.low, self.high)

def _log_prob(self, X: RandomVariableType) -> MeasureType:
outside_mask = jnp.bitwise_or(X < self.low, X > self.high)
return jnp.where(outside_mask, -jnp.inf, self.prior._log_prob(X) - jnp.log(self.cdf_diff))

def _base_shape(self) -> Tuple[int, ...]:
return self.prior._base_shape()

def _shape(self) -> Tuple[int, ...]:
return self.prior._shape()

def _dtype(self) -> jnp.dtype:
return self.prior._dtype()
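
For context, a minimal numeric sketch of the inverse-CDF truncation that the new `TruncationWrapper` docstring describes, for a standard normal truncated to [0, 1]; `jax.scipy.special.ndtr`/`ndtri` stand in for the wrapped prior's CDF and quantile, and the helper names are illustrative only:

```python
import jax.numpy as jnp
from jax.scipy.special import ndtr, ndtri  # standard-normal CDF and its inverse

low, high = 0.0, 1.0
cdf_low = ndtr(low)                # F_untruncated(low)
cdf_diff = ndtr(high) - ndtr(low)  # F_untruncated(high) - F_untruncated(low)

def truncated_quantile(p):
    # Q_truncated(p) = Q_untruncated(p * cdf_diff + cdf_low)
    return ndtri(p * cdf_diff + cdf_low)

def truncated_cdf(x):
    # F_truncated(x) = (F_untruncated(x) - cdf_low) / cdf_diff
    return (ndtr(x) - cdf_low) / cdf_diff

p = jnp.linspace(0.01, 0.99, 5)
x = truncated_quantile(p)       # all values lie inside [0, 1]
p_roundtrip = truncated_cdf(x)  # recovers p up to float error
```

The tests added in the next file exercise the same round-trip on the wrapper itself.
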
77 changes: 76 additions & 1 deletion jaxns/framework/tests/test_prior.py
@@ -11,7 +11,7 @@
from jaxns.framework.ops import parse_prior, prepare_input, compute_log_likelihood
from jaxns.framework.prior import Prior, InvalidPriorName
from jaxns.framework.special_priors import Bernoulli, Categorical, Poisson, Beta, ForcedIdentifiability, \
UnnormalisedDirichlet, _poisson_quantile_bisection, _poisson_quantile, Empirical
UnnormalisedDirichlet, _poisson_quantile_bisection, _poisson_quantile, Empirical, TruncationWrapper
from jaxns.framework.wrapped_tfp_distribution import InvalidDistribution, distribution_chain
from jaxns.internals.types import float_type

@@ -343,3 +343,78 @@ def test_empirical():
# print(u)
assert jnp.allclose(u, u_input)
assert u.shape[1:] == prior.base_shape

def test_truncation_wrapper():
prior = Prior(tfpd.Normal(loc=jnp.zeros(5), scale=jnp.ones(5)))
trancated_prior = TruncationWrapper(prior=prior, low=0., high=1.)

x = trancated_prior.forward(jnp.ones(trancated_prior.base_shape, float_type))
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= 0.)
assert jnp.all(x <= 1.)
assert x.shape == (5,)

x = trancated_prior.forward(jnp.zeros(trancated_prior.base_shape, float_type))
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= 0.)
assert jnp.all(x <= 1.)
assert x.shape == (5,)

u_input = vmap(lambda key: random.uniform(key, shape=trancated_prior.base_shape))(random.split(random.PRNGKey(42), 1000))
x = vmap(lambda u: trancated_prior.forward(u))(u_input)
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= 0.)
assert jnp.all(x <= 1.)

u = vmap(lambda x: trancated_prior.inverse(x))(x)
np.testing.assert_allclose(u, u_input, atol=5e-7)

prior = Prior(tfpd.Normal(loc=jnp.zeros(5), scale=jnp.ones(5)))
trancated_prior = TruncationWrapper(prior=prior, low=-jnp.inf, high=1.)

x = trancated_prior.forward(jnp.ones(trancated_prior.base_shape, float_type))
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= -jnp.inf)
assert jnp.all(x <= 1.)
assert x.shape == (5,)

x = trancated_prior.forward(jnp.zeros(trancated_prior.base_shape, float_type))
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= -jnp.inf)
assert jnp.all(x <= 1.)
assert x.shape == (5,)

u_input = vmap(lambda key: random.uniform(key, shape=trancated_prior.base_shape))(
random.split(random.PRNGKey(42), 1000))
x = vmap(lambda u: trancated_prior.forward(u))(u_input)
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= -jnp.inf)
assert jnp.all(x <= 1.)

u = vmap(lambda x: trancated_prior.inverse(x))(x)
np.testing.assert_allclose(u, u_input, atol=5e-7)

prior = Prior(tfpd.Normal(loc=jnp.zeros(5), scale=0.01*jnp.ones(5)))
trancated_prior = TruncationWrapper(prior=prior, low=0., high=1.)

x = trancated_prior.forward(jnp.ones(trancated_prior.base_shape, float_type))
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= 0.)
assert jnp.all(x <= 1.)
assert x.shape == (5,)

x = trancated_prior.forward(jnp.zeros(trancated_prior.base_shape, float_type))
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= 0.)
assert jnp.all(x <= 1.)
assert x.shape == (5,)

u_input = vmap(lambda key: random.uniform(key, shape=trancated_prior.base_shape))(
random.split(random.PRNGKey(42), 1000))
x = vmap(lambda u: trancated_prior.forward(u))(u_input)
assert jnp.all(jnp.bitwise_not(jnp.isnan(x)))
assert jnp.all(x >= 0.)
assert jnp.all(x <= 1.)

u = vmap(lambda x: trancated_prior.inverse(x))(x)
np.testing.assert_allclose(u, u_input, atol=5e-7)