Commit
Merge pull request #165 from Joshuaalbert/develop
2.5.0
Joshuaalbert authored May 15, 2024
2 parents e74d848 + 2ae9ca3 commit 37019fe
Showing 14 changed files with 220 additions and 40 deletions.
6 changes: 5 additions & 1 deletion README.md
@@ -363,7 +363,11 @@ is the best way to achieve speed up.

# Change Log

22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes a bug where slice sampling was not invariant to monotonic transforms of the likelihood.
15 May, 2024 -- JAXNS 2.5.0 released. Added the ability to handle non-JAX likelihoods, e.g. if you have a simulation
framework with Python bindings you can now use it for likelihoods in JAXNS. Small performance improvements.

22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes a bug where slice sampling was not invariant to monotonic transforms of
the likelihood.

20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes and readability improvements. Added the Empirical special prior.

2 changes: 1 addition & 1 deletion docs/conf.py
@@ -12,7 +12,7 @@
project = "jaxns"
copyright = "2022, Joshua G. Albert"
author = "Joshua G. Albert"
release = "2.4.13"
release = "2.5.0"

# -- General configuration ---------------------------------------------------
# https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration
1 change: 1 addition & 0 deletions jaxns/framework/__init__.py
@@ -1,4 +1,5 @@
from jaxns.framework.model import *
from jaxns.framework.prior import *
from jaxns.framework.special_priors import *
from jaxns.framework.jaxify import *
from jaxns.framework.bases import PriorModelGen, PriorModelType
44 changes: 44 additions & 0 deletions jaxns/framework/jaxify.py
@@ -0,0 +1,44 @@
import warnings
from typing import Callable

import jax
import numpy as np

from jaxns.internals.types import float_type, LikelihoodType

__all__ = [
    'jaxify_likelihood'
]


def jaxify_likelihood(log_likelihood: Callable[..., np.ndarray], vectorised: bool = False) -> LikelihoodType:
    """
    Wraps a non-JAX log-likelihood function so that it can be called from JAX-transformed code.

    Args:
        log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar
            log-likelihood.
        vectorised: if True then `log_likelihood` performs a vectorised computation over leading batch dimensions,
            i.e. if a leading batch dimension is added to all input arguments, then it returns a vector of
            log-likelihoods with the same leading batch dimension.

    Returns:
        A JAX-compatible log-likelihood function.
    """
    warnings.warn(
        "You're using a non-JAX log-likelihood function. This may be slower than a JAX log-likelihood function. "
        "You are responsible for ensuring that the function is deterministic, and you cannot use learnable "
        "parameters in the likelihood call."
    )

    def _casted_log_likelihood(*args) -> np.ndarray:
        return np.asarray(log_likelihood(*args), dtype=float_type)

    def _log_likelihood(*args) -> jax.Array:
        # Define the expected shape & dtype of the callback's output.
        result_shape_dtype = jax.ShapeDtypeStruct(
            shape=(),
            dtype=float_type
        )
        return jax.pure_callback(_casted_log_likelihood, result_shape_dtype, *args, vectorized=vectorised)

    return _log_likelihood
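
For orientation, here is a minimal usage sketch (not part of the commit; the toy NumPy likelihood and the Uniform prior bounds are illustrative assumptions) showing how a plain-Python likelihood can be wrapped and used in a model:

import jax
import numpy as np
import tensorflow_probability.substrates.jax as tfp

from jaxns import Model, Prior
from jaxns.framework.jaxify import jaxify_likelihood

tfpd = tfp.distributions


@jaxify_likelihood
def log_likelihood(x):
    # Any deterministic non-JAX computation, e.g. a call into a simulation
    # framework with Python bindings.
    return -0.5 * np.sum((x - 0.5) ** 2)


def prior_model():
    x = yield Prior(tfpd.Uniform(low=0., high=1.), name='x')
    return x


model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)  # exercises the callback under JAX transforms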
35 changes: 35 additions & 0 deletions jaxns/framework/tests/test_jaxify.py
@@ -0,0 +1,35 @@
import jax
import jax.random
import numpy as np

from jaxns import Prior, Model
from jaxns.framework.jaxify import jaxify_likelihood
from jaxns.framework.tests.test_model import tfpd


def test_jaxify_likelihood():
    def log_likelihood(x, y):
        return np.sum(x, axis=-1) + np.sum(y, axis=-1)

    wrapped_ll = jaxify_likelihood(log_likelihood)
    np.testing.assert_allclose(wrapped_ll(np.array([1, 2]), np.array([3, 4])), 10)

    vmapped_wrapped_ll = jax.vmap(jaxify_likelihood(log_likelihood, vectorised=True))

    np.testing.assert_allclose(vmapped_wrapped_ll(np.array([[1, 2], [2, 2]]), np.array([[3, 4], [4, 4]])),
                               np.array([10, 12]))


def test_jaxify():
    def prior_model():
        x = yield Prior(tfpd.Uniform(), name='x').parametrised()
        return x

    @jaxify_likelihood
    def log_likelihood(x):
        return x

    model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
    model.sanity_check(key=jax.random.PRNGKey(0), S=10)
    assert model.U_ndims == 0
    assert model.num_params == 1
5 changes: 1 addition & 4 deletions jaxns/framework/tests/test_model.py
@@ -138,8 +138,8 @@ def log_likelihood(obj: Obj):
model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)

def test_empty_prior_models():

def test_empty_prior_models():
def prior_model():
return 1.

@@ -148,6 +148,3 @@ def log_likelihood(x):

model = Model(prior_model=prior_model, log_likelihood=log_likelihood)
model.sanity_check(key=jax.random.PRNGKey(0), S=10)



35 changes: 35 additions & 0 deletions jaxns/internals/cumulative_ops.py
@@ -2,13 +2,48 @@

import jax
from jax import lax, numpy as jnp, tree_util
from tensorflow_probability.substrates.jax import math as tfp_math

from jaxns.internals.types import IntArray, int_type

X = TypeVar('X')
V = TypeVar('V')
Y = TypeVar('Y')


def scan_associative_cumulative_op(op: Callable[[X, X], X], init: X, xs: X, pre_op: bool = False) -> Tuple[X, X]:
    """
    Compute a cumulative operation on an array of values using scan_associative.

    Args:
        op: the operation to perform; must be associative.
        init: the initial value.
        xs: the array of values.
        pre_op: if True, the result at each index is the accumulated value *before* that element is applied,
            i.e. the outputs start at `init` and exclude the final accumulated value.

    Returns:
        the final accumulated value, and the result of the cumulative operation applied to the input.
    """

    def associative_op(a, b):
        return op(a, b)

    # Prepare the input array by prepending the initial value
    full_input = jax.tree.map(lambda x, y: jnp.concatenate([x[None], y], axis=0), init, xs)

    # Apply the operation to accumulate results using scan_associative
    scanned_results = tfp_math.scan_associative(associative_op, full_input)

    # The final accumulated value is the last element in the results
    final_accumulate = jax.tree.map(lambda x: x[-1], scanned_results)

    if pre_op:
        scanned_results = jax.tree.map(lambda x: x[:-1], scanned_results)
    else:
        scanned_results = jax.tree.map(lambda x: x[1:], scanned_results)

    return final_accumulate, scanned_results


def cumulative_op_static(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False, unroll: int = 1) -> Tuple[
V, V]:
"""
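For intuition, a small sketch of the intended semantics (my own example, assuming the behaviour defined above): with jnp.add the scan reproduces a cumulative sum, and pre_op shifts the outputs so each entry excludes its own element.

import jax.numpy as jnp

from jaxns.internals.cumulative_ops import scan_associative_cumulative_op

init = jnp.asarray(0.)
xs = jnp.asarray([1., 2., 3., 4.])
final, prefix = scan_associative_cumulative_op(op=jnp.add, init=init, xs=xs)
# final == 10., prefix == [1., 3., 6., 10.]  (matches jnp.cumsum(xs))
final, prefix = scan_associative_cumulative_op(op=jnp.add, init=init, xs=xs, pre_op=True)
# final == 10., prefix == [0., 1., 3., 6.]   (accumulated value before each element)

Because tfp_math.scan_associative is a parallel prefix scan, the cumulative result can be computed in O(log n) depth rather than the O(n) sequential steps of cumulative_op_static; this only holds when op is truly associative.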
9 changes: 5 additions & 4 deletions jaxns/internals/shrinkage_statistics.py
@@ -2,7 +2,7 @@

import jax.numpy as jnp

from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
from jaxns.internals.log_semiring import LogSpace
from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray
@@ -28,8 +28,8 @@ def op(log_X, num_live_points):
next_X_mean = X_mean * T_mean
return next_X_mean.log_abs_val

_, log_X = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type),
xs=live_point_counts.num_live_points)
_, log_X = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type),
xs=live_point_counts.num_live_points)
return log_X


@@ -141,7 +141,8 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_
)
if num_samples is not None:
stop_idx = num_samples
final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, stop_idx=stop_idx)
final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs,
stop_idx=stop_idx)
else:
final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs)
final_evidence_calculation = final_accumulate
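For context, a sketch of why an associative scan fits here (standard nested-sampling shrinkage, not code from the commit): with n_i live points, each dead point shrinks the enclosed prior volume by a mean factor of n_i / (n_i + 1), so log X is a cumulative sum of log-shrinkage factors, exactly the kind of prefix computation scan_associative parallelises.

import jax.numpy as jnp

# Hypothetical illustration with a dwindling live-point count:
num_live_points = jnp.asarray([100., 100., 99., 98.])
log_T = jnp.log(num_live_points) - jnp.log(num_live_points + 1.)  # mean log-shrinkage per step
log_X = jnp.cumsum(log_T)  # enclosed prior volume after each step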
63 changes: 62 additions & 1 deletion jaxns/internals/tests/test_cumulative_ops.py
@@ -1,6 +1,9 @@
import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op
from jaxns.internals.types import float_type, int_type


@@ -19,6 +22,64 @@ def op(accumulate, y):
assert all(result == jnp.asarray([0, 1, 3], float_type))


@pytest.mark.parametrize("binary_op", [jnp.add, jnp.multiply, jnp.minimum, jnp.maximum])
def test_scan_associative_cumulative_op(binary_op):
    def op(accumulate, y):
        return binary_op(accumulate, y)

    init = jnp.asarray(1, float_type)
    xs = jnp.arange(1, 11, dtype=float_type)
    final_accumulate, result = scan_associative_cumulative_op(op=binary_op, init=init, xs=xs)
    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
    assert final_accumulate == final_accumulate_expected
    np.testing.assert_allclose(result, result_expected)

    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
    assert final_accumulate == final_accumulate_expected
    np.testing.assert_allclose(result, result_expected)


@pytest.mark.parametrize("binary_op", [jnp.subtract, jnp.true_divide])
def test_scan_associative_cumulative_op_non_associative(binary_op):
    # subtract and divide are not associative, so the associative scan is
    # expected to disagree with the sequential cumulative op.
    def op(accumulate, y):
        return binary_op(accumulate, y)

    init = jnp.asarray(1, float_type)
    xs = jnp.arange(1, 11, dtype=float_type)
    final_accumulate, result = scan_associative_cumulative_op(op=binary_op, init=init, xs=xs)
    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
    with pytest.raises(AssertionError):
        assert final_accumulate == final_accumulate_expected
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(result, result_expected)

    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
    with pytest.raises(AssertionError):
        assert final_accumulate == final_accumulate_expected
    with pytest.raises(AssertionError):
        np.testing.assert_allclose(result, result_expected)


def test_scan_associative_cumulative_op_with_pytrees():
    # Test with pytrees for xs and ys
    def op(accumulate, y):
        return jax.tree.map(lambda x, y: jnp.add(x, y), accumulate, y)

    init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)}
    xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)}
    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
    assert final_accumulate == {'a': 6, 'b': 15}
    assert all(result['a'] == jnp.asarray([1, 3, 6], float_type))
    assert all(result['b'] == jnp.asarray([4, 9, 15], float_type))

    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
    assert final_accumulate == {'a': 6, 'b': 15}
    assert all(result['a'] == jnp.asarray([0, 1, 3], float_type))
    assert all(result['b'] == jnp.asarray([0, 4, 9], float_type))


def test_cumulative_op_dynamic():
    def op(accumulate, y):
        return accumulate + y
34 changes: 19 additions & 15 deletions jaxns/internals/tests/test_tree_structure.py
@@ -10,6 +10,13 @@
from jaxns.internals.types import StaticStandardNestedSamplerState, StaticStandardSampleCollection


def pytree_assert_equal(a, b):
    print(a)
    print(b)
    for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b)):
        np.testing.assert_allclose(x, y)


def test_naive():
S = SampleTreeGraph(
sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
@@ -29,20 +36,18 @@ def test_basic():
sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
log_L=jnp.asarray([1, 2, 3, 4, 5, 6])
)
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
assert all(
jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
pytree_assert_equal(count_crossed_edges(S), count_old(S))
pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))

S = SampleTreeGraph(
sender_node_idx=jnp.asarray([0, 0, 0, 1, 3, 2]),
log_L=jnp.asarray([1, 2, 3, 4, 6, 5])
)

assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
assert all(
jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
pytree_assert_equal(count_crossed_edges(S), count_old(S))
pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))


def test_with_num_samples():
@@ -57,9 +62,9 @@ def test_with_num_samples():
log_L=jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8])
)

assert all(jax.tree.map(lambda x, y: np.array_equal(x[:num_samples], y),
count_crossed_edges(S1, num_samples),
count_crossed_edges(S2)))
x = jax.tree.map(lambda x: x[:num_samples], count_crossed_edges(S1, num_samples))

pytree_assert_equal(x, count_crossed_edges(S2))

output = count_crossed_edges(S1, num_samples)
print(output)
@@ -89,10 +94,9 @@ def test_random_tree():

plot_tree(S)

assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
assert all(
jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
pytree_assert_equal(count_crossed_edges(S), count_old(S))
pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))

T = count_crossed_edges_less_fast(S)
import pylab as plt
2 changes: 1 addition & 1 deletion jaxns/internals/tree_structure.py
@@ -4,7 +4,7 @@
from jax import numpy as jnp, lax, core
from jax._src.numpy import lax_numpy

from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
from jaxns.internals.maps import remove_chunk_dim
from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \
int_type
3 changes: 1 addition & 2 deletions jaxns/samplers/multi_slice_sampler.py
@@ -258,8 +258,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
final_sample, cumulative_samples = cumulative_op_static(
op=propose_op,
init=init_sample,
xs=random.split(key, self.num_slices),
unroll=2
xs=random.split(key, self.num_slices)
)

# Last sample is the final sample, the rest are potential phantom samples