diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py
index 95a1c596..c202b09a 100644
--- a/jaxns/internals/cumulative_ops.py
+++ b/jaxns/internals/cumulative_ops.py
@@ -2,13 +2,49 @@ import jax
 from jax import lax, numpy as jnp, tree_util
+from tensorflow_probability.substrates.jax import math as tfp_math
 
 from jaxns.internals.types import IntArray, int_type
 
+X = TypeVar('X')
 V = TypeVar('V')
 Y = TypeVar('Y')
 
 
+def scan_associative_cumulative_op(op: Callable[[X, X], X], init: X, xs: X, pre_op: bool = False) -> Tuple[X, X]:
+    """
+    Compute a cumulative operation on an array of values using scan_associative.
+
+    Args:
+        op: the operation to perform, must be associative.
+        init: the initial value.
+        xs: the array of values.
+        pre_op: if True, return the accumulant before each element is applied, rather than after.
+
+    Returns:
+        the final accumulated value, and the result of the cumulative operation applied to the input.
+    """
+
+    def associative_op(a, b):
+        return op(a, b)
+
+    # Prepare the input array by prepending the initial value
+    full_input = jax.tree.map(lambda x, y: jnp.concatenate([x[None], y], axis=0), init, xs)
+
+    # Apply the operation to accumulate results using scan_associative
+    scanned_results = tfp_math.scan_associative(associative_op, full_input)
+
+    # The final accumulated value is the last element in the results
+    final_accumulate = jax.tree.map(lambda x: x[-1], scanned_results)
+
+    if pre_op:
+        scanned_results = jax.tree.map(lambda x: x[:-1], scanned_results)
+    else:
+        scanned_results = jax.tree.map(lambda x: x[1:], scanned_results)
+
+    return final_accumulate, scanned_results
+
+
 def cumulative_op_static(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False, unroll: int = 1) -> Tuple[
     V, V]:
     """
diff --git a/jaxns/internals/shrinkage_statistics.py b/jaxns/internals/shrinkage_statistics.py
index 1f9f7b1d..685f362d 100644
--- a/jaxns/internals/shrinkage_statistics.py
+++ b/jaxns/internals/shrinkage_statistics.py
@@ -2,7 +2,7 @@
 import jax.numpy as jnp
 
-from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
+from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
 from jaxns.internals.log_semiring import LogSpace
 from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
 from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray
@@ -28,8 +28,8 @@ def op(log_X, num_live_points):
         next_X_mean = X_mean * T_mean
         return next_X_mean.log_abs_val
 
-    _, log_X = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type),
-                                    xs=live_point_counts.num_live_points)
+    _, log_X = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type),
+                                              xs=live_point_counts.num_live_points)
     return log_X
 
 
@@ -141,7 +141,8 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_
     )
     if num_samples is not None:
         stop_idx = num_samples
-        final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, stop_idx=stop_idx)
+        final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs,
+                                                         stop_idx=stop_idx)
     else:
         final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs)
     final_evidence_calculation = final_accumulate
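Usage sketch (not part of the patch): a minimal check of the new scan_associative_cumulative_op, assuming it behaves as in the hunk above; the values below are illustrative.

    import jax.numpy as jnp
    from jaxns.internals.cumulative_ops import scan_associative_cumulative_op, cumulative_op_static

    init = jnp.asarray(0.0)
    xs = jnp.arange(1.0, 6.0)  # [1, 2, 3, 4, 5]

    # The parallel (scan_associative) and sequential (lax.scan) variants agree
    # for associative ops such as addition; the former runs in O(log n) depth.
    final, cumulative = scan_associative_cumulative_op(op=jnp.add, init=init, xs=xs)
    # final == 15.0, cumulative == [1, 3, 6, 10, 15]
    final_ref, cumulative_ref = cumulative_op_static(op=jnp.add, init=init, xs=xs)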
diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py
index ade57694..6214b37d 100644
--- a/jaxns/internals/tests/test_cumulative_ops.py
+++ b/jaxns/internals/tests/test_cumulative_ops.py
@@ -1,6 +1,9 @@
+import jax
+import numpy as np
+import pytest
 from jax import numpy as jnp
 
-from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
+from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op
 from jaxns.internals.types import float_type, int_type
 
 
@@ -19,6 +22,64 @@ def op(accumulate, y):
     assert all(result == jnp.asarray([0, 1, 3], float_type))
 
 
+@pytest.mark.parametrize("binary_op", [jnp.add, jnp.multiply, jnp.minimum, jnp.maximum])
+def test_scan_associative_cumulative_op(binary_op):
+    def op(accumulate, y):
+        return binary_op(accumulate, y)
+
+    init = jnp.asarray(1, float_type)
+    xs = jnp.arange(1, 11, dtype=float_type)
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
+    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
+    assert final_accumulate == final_accumulate_expected
+    np.testing.assert_allclose(result, result_expected)
+
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
+    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
+    assert final_accumulate == final_accumulate_expected
+    np.testing.assert_allclose(result, result_expected)
+
+
+@pytest.mark.parametrize("binary_op", [jnp.subtract, jnp.true_divide])
+def test_scan_associative_cumulative_op_non_associative(binary_op):
+    def op(accumulate, y):
+        return binary_op(accumulate, y)
+
+    init = jnp.asarray(1, float_type)
+    xs = jnp.arange(1, 11, dtype=float_type)
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
+    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
+    with pytest.raises(AssertionError):
+        assert final_accumulate == final_accumulate_expected
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(result, result_expected)
+
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
+    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
+    with pytest.raises(AssertionError):
+        assert final_accumulate == final_accumulate_expected
+    with pytest.raises(AssertionError):
+        np.testing.assert_allclose(result, result_expected)
+
+
+def test_scan_associative_cumulative_op_with_pytrees():
+    # Test with pytrees for xs and ys
+    def op(accumulate, y):
+        return jax.tree.map(lambda x, y: jnp.add(x, y), accumulate, y)
+
+    init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)}
+    xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)}
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
+    assert final_accumulate == {'a': 6, 'b': 15}
+    assert all(result['a'] == jnp.asarray([1, 3, 6], float_type))
+    assert all(result['b'] == jnp.asarray([4, 9, 15], float_type))
+
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
+    assert final_accumulate == {'a': 6, 'b': 15}
+    assert all(result['a'] == jnp.asarray([0, 1, 3], float_type))
+    assert all(result['b'] == jnp.asarray([0, 4, 9], float_type))
+
+
 def test_cumulative_op_dynamic():
     def op(accumulate, y):
         return accumulate + y
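Why the subtract/divide cases above are expected to diverge: scan_associative is free to regroup applications of op, which only preserves the sequential result when op is associative. A one-line sketch of the failing property for subtraction:

    import jax.numpy as jnp

    a, b, c = jnp.asarray(1.0), jnp.asarray(2.0), jnp.asarray(3.0)
    left = jnp.subtract(jnp.subtract(a, b), c)   # (1 - 2) - 3 == -4
    right = jnp.subtract(a, jnp.subtract(b, c))  # 1 - (2 - 3) == 2
    assert left != right  # subtraction is not associative, so regrouping changes the result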
diff --git a/jaxns/internals/tests/test_tree_structure.py b/jaxns/internals/tests/test_tree_structure.py
index bd60c88f..03ea09e6 100644
--- a/jaxns/internals/tests/test_tree_structure.py
+++ b/jaxns/internals/tests/test_tree_structure.py
@@ -10,6 +10,11 @@
 from jaxns.internals.types import StaticStandardNestedSamplerState, StaticStandardSampleCollection
 
 
+def pytree_assert_equal(a, b):
+    for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b)):
+        np.testing.assert_allclose(x, y)
+
+
 def test_naive():
     S = SampleTreeGraph(
         sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
@@ -29,20 +34,18 @@ def test_basic():
         sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
         log_L=jnp.asarray([1, 2, 3, 4, 5, 6])
     )
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
-    assert all(
-        jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
+    pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
+    pytree_assert_equal(count_crossed_edges(S), count_old(S))
+    pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))
 
     S = SampleTreeGraph(
         sender_node_idx=jnp.asarray([0, 0, 0, 1, 3, 2]),
         log_L=jnp.asarray([1, 2, 3, 4, 6, 5])
     )
 
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
-    assert all(
-        jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
+    pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
+    pytree_assert_equal(count_crossed_edges(S), count_old(S))
+    pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))
 
 
 def test_with_num_samples():
@@ -57,9 +60,9 @@
         log_L=jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8])
     )
 
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x[:num_samples], y),
-                            count_crossed_edges(S1, num_samples),
-                            count_crossed_edges(S2)))
+    x = jax.tree.map(lambda x: x[:num_samples], count_crossed_edges(S1, num_samples))
+
+    pytree_assert_equal(x, count_crossed_edges(S2))
 
     output = count_crossed_edges(S1, num_samples)
     print(output)
@@ -89,10 +92,9 @@ def test_random_tree():
 
     plot_tree(S)
 
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
-    assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
-    assert all(
-        jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
+    pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
+    pytree_assert_equal(count_crossed_edges(S), count_old(S))
+    pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))
 
     T = count_crossed_edges_less_fast(S)
     import pylab as plt
diff --git a/jaxns/internals/tree_structure.py b/jaxns/internals/tree_structure.py
index dbad50fd..bda36c8e 100644
--- a/jaxns/internals/tree_structure.py
+++ b/jaxns/internals/tree_structure.py
@@ -4,7 +4,7 @@
 from jax import numpy as jnp, lax, core
 from jax._src.numpy import lax_numpy
 
-from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
+from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
 from jaxns.internals.maps import remove_chunk_dim
 from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \
     int_type
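A standalone sketch of how the new pytree_assert_equal helper compares results (the dict pytree below is hypothetical). Note the helper compares flattened leaves, so both arguments are assumed to share the same tree structure:

    import jax
    import numpy as np

    def pytree_assert_equal(a, b):
        # Compare corresponding leaves; raises AssertionError on any mismatch.
        for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b)):
            np.testing.assert_allclose(x, y)

    pytree_assert_equal({'a': np.array([1.0, 2.0])}, {'a': np.array([1.0, 2.0])})  # passes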
diff --git a/jaxns/samplers/multi_slice_sampler.py b/jaxns/samplers/multi_slice_sampler.py
index b4ff0112..b7ee78c8 100644
--- a/jaxns/samplers/multi_slice_sampler.py
+++ b/jaxns/samplers/multi_slice_sampler.py
@@ -258,8 +258,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
         final_sample, cumulative_samples = cumulative_op_static(
             op=propose_op,
             init=init_sample,
-            xs=random.split(key, self.num_slices),
-            unroll=2
+            xs=random.split(key, self.num_slices)
         )
 
         # Last sample is the final sample, the rest are potential phantom samples
diff --git a/jaxns/samplers/uni_slice_sampler.py b/jaxns/samplers/uni_slice_sampler.py
index a0da6942..799ee7b7 100644
--- a/jaxns/samplers/uni_slice_sampler.py
+++ b/jaxns/samplers/uni_slice_sampler.py
@@ -83,8 +83,7 @@ def _pick_point_in_interval(key: PRNGKey, point_U0: FloatArray, direction: Float
     return point_U, t
 
 
-def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, log_L_proposal: FloatArray,
-                     log_L_constraint: FloatArray, log_L0: FloatArray,
+def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray,
                      midpoint_shrink: bool) -> Tuple[FloatArray, FloatArray]:
     """
     Not successful proposal, so shrink, optionally apply exponential shrinkage.
@@ -151,9 +150,6 @@ def body(carry: Carry) -> Carry:
             t=carry.t,
             left=carry.left,
             right=carry.right,
-            log_L_proposal=carry.log_L,
-            log_L_constraint=log_L_constraint,
-            log_L0=seed_point.log_L0,
             midpoint_shrink=midpoint_shrink
         )
         point_U, t = _pick_point_in_interval(
@@ -232,11 +228,11 @@ def body(carry: Carry) -> Carry:
 
 class UniDimSliceSampler(BaseAbstractMarkovSampler):
     """
-    Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples.
+    Slice sampler for a single dimension. Produces correlated samples.
     """
 
     def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: int, midpoint_shrink: bool,
-                 perfect: bool, gradient_slice: bool = False):
+                 perfect: bool, gradient_slice: bool = False, adaptive_shrink: bool = False):
         """
         Unidimensional slice sampler.
 
@@ -250,7 +246,8 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save:
             perfect: if true then perform exponential shrinkage from maximal bounds, requiring no step-out procedure.
                 Otherwise, uses a doubling procedure (exponentially finding bracket).
                 Note: Perfect is a misnomer, as perfection also depends on the number of slices between acceptance.
-            gradient_slice: if true then always slice along gradient direction.
+            gradient_slice: if true then always slice along the direction of increasing gradient.
+            adaptive_shrink: if true then shrink the interval to a random point within it, rather than to the midpoint.
         """
         super().__init__(model=model)
         if num_slices < 1:
@@ -264,6 +261,9 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save:
         self.midpoint_shrink = bool(midpoint_shrink)
         self.perfect = bool(perfect)
         self.gradient_slice = bool(gradient_slice)
+        self.adaptive_shrink = bool(adaptive_shrink)
+        if self.adaptive_shrink:
+            raise NotImplementedError("Adaptive shrinkage not implemented.")
         if not self.perfect:
             raise ValueError("Only perfect slice sampler is implemented.")
 
@@ -341,8 +341,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
         final_sample, cumulative_samples = cumulative_op_static(
             op=propose_op,
             init=init_sample,
-            xs=random.split(key, self.num_slices),
-            unroll=2
+            xs=random.split(key, self.num_slices)
        )

         # Last sample is the final sample, the rest are potential phantom samples
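For context on the unroll removal above: the samplers chain proposals through cumulative_op_static, consuming one PRNG key per slice. A toy sketch of that pattern; propose_op here is a hypothetical stand-in for the samplers' actual proposal step:

    from jax import random, numpy as jnp
    from jaxns.internals.cumulative_ops import cumulative_op_static

    def propose_op(sample, key):
        # Hypothetical proposal: perturb the running sample with Gaussian noise.
        return sample + random.normal(key)

    num_slices = 5
    keys = random.split(random.PRNGKey(42), num_slices)
    final_sample, cumulative_samples = cumulative_op_static(
        op=propose_op, init=jnp.asarray(0.0), xs=keys)
    # final_sample equals cumulative_samples[-1]; the earlier entries are the
    # potential phantom samples mentioned in the diff.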