From ce3436210dc725fba84133539d4e0d5144689ec7 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 1 May 2024 12:49:38 +0200 Subject: [PATCH 1/3] * Implement and use scan_associative_cumulative_op --- jaxns/experimental/evidence_maximisation.py | 4 +-- jaxns/internals/cumulative_ops.py | 34 ++++++++++++++++++++ jaxns/internals/shrinkage_statistics.py | 11 ++++--- jaxns/internals/tests/test_cumulative_ops.py | 17 +++++++++- jaxns/internals/tree_structure.py | 4 +-- jaxns/nested_sampler/standard_static.py | 4 +-- jaxns/samplers/multi_slice_sampler.py | 7 ++-- jaxns/samplers/uni_slice_sampler.py | 23 +++++++------ jaxns/utils.py | 4 +-- 9 files changed, 78 insertions(+), 30 deletions(-) diff --git a/jaxns/experimental/evidence_maximisation.py b/jaxns/experimental/evidence_maximisation.py index 299ed811..289a814d 100644 --- a/jaxns/experimental/evidence_maximisation.py +++ b/jaxns/experimental/evidence_maximisation.py @@ -12,7 +12,7 @@ from jaxopt import NonlinearCG, ArmijoSGD from tqdm import tqdm -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace from jaxns.internals.logging import logger @@ -241,7 +241,7 @@ def op(log_Z, data): log_dZ = model.forward(data.U_samples) + data.log_weights return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val - log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) + log_Z, _ = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) return log_Z def loss(params: hk.MutableParams, data: MStepData): diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py index 95a1c596..75fcb71a 100644 --- a/jaxns/internals/cumulative_ops.py +++ b/jaxns/internals/cumulative_ops.py @@ -2,6 +2,7 @@ import jax from jax import lax, numpy as jnp, tree_util +from tensorflow_probability.substrates.jax import math as tfp_math from jaxns.internals.types import IntArray, int_type @@ -9,6 +10,39 @@ Y = TypeVar('Y') +def scan_associative_cumulative_op(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False) -> Tuple[V, V]: + """ + Compute a cumulative operation on an array of values using scan_associative. + + Args: + op: the operation to perform, must be associative. + init: the initial value. + xs: the array of values. + + Returns: + the final accumulated value, and the result of the cumulative operation applied on input + """ + + def associative_op(a, b): + return op(a, b) + + # Prepare the input array by prepending the initial value + full_input = jax.tree.map(lambda x, y: jnp.concatenate([x[None], y], axis=0), init, xs) + + # Apply the operation to accumulate results using scan_associative + scanned_results = tfp_math.scan_associative(associative_op, full_input) + + # The final accumulated value is the last element in the results + final_accumulate = scanned_results[-1] + + if pre_op: + scanned_results = scanned_results[:-1] + else: + scanned_results = scanned_results[1:] + + return final_accumulate, scanned_results + + def cumulative_op_static(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False, unroll: int = 1) -> Tuple[ V, V]: """ diff --git a/jaxns/internals/shrinkage_statistics.py b/jaxns/internals/shrinkage_statistics.py index 1f9f7b1d..40dc8ab4 100644 --- a/jaxns/internals/shrinkage_statistics.py +++ b/jaxns/internals/shrinkage_statistics.py @@ -2,7 +2,7 @@ import jax.numpy as jnp -from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray @@ -28,8 +28,8 @@ def op(log_X, num_live_points): next_X_mean = X_mean * T_mean return next_X_mean.log_abs_val - _, log_X = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), - xs=live_point_counts.num_live_points) + _, log_X = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type), + xs=live_point_counts.num_live_points) return log_X @@ -141,9 +141,10 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_ ) if num_samples is not None: stop_idx = num_samples - final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, stop_idx=stop_idx) + final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, + stop_idx=stop_idx) else: - final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs) + final_accumulate, result = scan_associative_cumulative_op(op=_update_evidence_calc_op, init=init, xs=xs) final_evidence_calculation = final_accumulate per_sample_evidence_calculation = result return final_evidence_calculation, per_sample_evidence_calculation diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py index ade57694..509da75a 100644 --- a/jaxns/internals/tests/test_cumulative_ops.py +++ b/jaxns/internals/tests/test_cumulative_ops.py @@ -1,6 +1,6 @@ from jax import numpy as jnp -from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic +from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op from jaxns.internals.types import float_type, int_type @@ -19,6 +19,21 @@ def op(accumulate, y): assert all(result == jnp.asarray([0, 1, 3], float_type)) +def test_scan_associative_cumulative_op(): + def op(accumulate, y): + return accumulate + y + + init = jnp.asarray(0, float_type) + xs = jnp.asarray([1, 2, 3], float_type) + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs) + assert final_accumulate == 6 + assert all(result == jnp.asarray([1, 3, 6], float_type)) + + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) + assert final_accumulate == 6 + assert all(result == jnp.asarray([0, 1, 3], float_type)) + + def test_cumulative_op_dynamic(): def op(accumulate, y): return accumulate + y diff --git a/jaxns/internals/tree_structure.py b/jaxns/internals/tree_structure.py index dbad50fd..a603b3bd 100644 --- a/jaxns/internals/tree_structure.py +++ b/jaxns/internals/tree_structure.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, lax, core from jax._src.numpy import lax_numpy -from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op from jaxns.internals.maps import remove_chunk_dim from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \ int_type @@ -82,7 +82,7 @@ def op(crossed_edges, last_node): empty_fill=jnp.asarray(fake_edges, out_degree.dtype) ) else: - _, crossed_edges_sorted = cumulative_op_static( + _, crossed_edges_sorted = scan_associative_cumulative_op( op=op, init=jnp.asarray(1, out_degree.dtype), xs=sort_idx, diff --git a/jaxns/nested_sampler/standard_static.py b/jaxns/nested_sampler/standard_static.py index 7219e16b..0ea6cb09 100644 --- a/jaxns/nested_sampler/standard_static.py +++ b/jaxns/nested_sampler/standard_static.py @@ -6,7 +6,7 @@ from jax._src.lax import parallel from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace, normalise_log_space from jaxns.internals.logging import logger from jaxns.internals.shrinkage_statistics import compute_evidence_stats, init_evidence_calc, \ @@ -211,7 +211,7 @@ def body(carry: CarryType, unused_X: IntArray) -> Tuple[CarryType, ResultType]: # Update termination register _n = init_state.front_idx.size _num_samples = _n - evidence_calc_with_remaining, _ = cumulative_op_static( + evidence_calc_with_remaining, _ = scan_associative_cumulative_op( op=_update_evidence_calc_op, init=out_carry.evidence_calc, xs=EvidenceUpdateVariables( diff --git a/jaxns/samplers/multi_slice_sampler.py b/jaxns/samplers/multi_slice_sampler.py index b4ff0112..18323ff5 100644 --- a/jaxns/samplers/multi_slice_sampler.py +++ b/jaxns/samplers/multi_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, int_type, StaticStandardNestedSamplerState, \ UType, \ IntArray, float_type, StaticStandardSampleCollection @@ -255,11 +255,10 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = cumulative_op_static( + final_sample, cumulative_samples = scan_associative_cumulative_op( op=propose_op, init=init_sample, - xs=random.split(key, self.num_slices), - unroll=2 + xs=random.split(key, self.num_slices) ) # Last sample is the final sample, the rest are potential phantom samples diff --git a/jaxns/samplers/uni_slice_sampler.py b/jaxns/samplers/uni_slice_sampler.py index a0da6942..072a50f1 100644 --- a/jaxns/samplers/uni_slice_sampler.py +++ b/jaxns/samplers/uni_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, float_type, int_type, \ StaticStandardNestedSamplerState, \ IntArray, UType, StaticStandardSampleCollection @@ -83,8 +83,7 @@ def _pick_point_in_interval(key: PRNGKey, point_U0: FloatArray, direction: Float return point_U, t -def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, log_L_proposal: FloatArray, - log_L_constraint: FloatArray, log_L0: FloatArray, +def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, midpoint_shrink: bool) -> Tuple[FloatArray, FloatArray]: """ Not successful proposal, so shrink, optionally apply exponential shrinkage. @@ -151,9 +150,6 @@ def body(carry: Carry) -> Carry: t=carry.t, left=carry.left, right=carry.right, - log_L_proposal=carry.log_L, - log_L_constraint=log_L_constraint, - log_L0=seed_point.log_L0, midpoint_shrink=midpoint_shrink ) point_U, t = _pick_point_in_interval( @@ -232,11 +228,11 @@ def body(carry: Carry) -> Carry: class UniDimSliceSampler(BaseAbstractMarkovSampler): """ - Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples. + Slice sampler for a single dimension. Produces correlated samples. """ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: int, midpoint_shrink: bool, - perfect: bool, gradient_slice: bool = False): + perfect: bool, gradient_slice: bool = False, adaptive_shrink: bool = False): """ Unidimensional slice sampler. @@ -250,7 +246,8 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: perfect: if true then perform exponential shrinkage from maximal bounds, requiring no step-out procedure. Otherwise, uses a doubling procedure (exponentially finding bracket). Note: Perfect is a misnomer, as perfection also depends on the number of slices between acceptance. - gradient_slice: if true then always slice along gradient direction. + gradient_slice: if true then always slice along increasing gradient direction. + adaptive_shrink: if true then shrink interval to random point in interval, rather than midpoint. """ super().__init__(model=model) if num_slices < 1: @@ -264,6 +261,9 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: self.midpoint_shrink = bool(midpoint_shrink) self.perfect = bool(perfect) self.gradient_slice = bool(gradient_slice) + self.adaptive_shrink = bool(adaptive_shrink) + if self.adaptive_shrink: + raise NotImplementedError("Adaptive shrinkage not implemented.") if not self.perfect: raise ValueError("Only perfect slice sampler is implemented.") @@ -338,11 +338,10 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = cumulative_op_static( + final_sample, cumulative_samples = scan_associative_cumulative_op( op=propose_op, init=init_sample, - xs=random.split(key, self.num_slices), - unroll=2 + xs=random.split(key, self.num_slices) ) # Last sample is the final sample, the rest are potential phantom samples diff --git a/jaxns/utils.py b/jaxns/utils.py index d849f5b6..b7be003e 100644 --- a/jaxns/utils.py +++ b/jaxns/utils.py @@ -8,7 +8,7 @@ from jax import numpy as jnp, vmap, random, jit, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace from jaxns.internals.maps import prepare_func_args from jaxns.internals.namedtuple_utils import serialise_namedtuple, deserialise_namedtuple @@ -444,7 +444,7 @@ def accumulate_op(accumulate, y): def single_log_Z_sample(key: PRNGKey) -> FloatArray: init = (jnp.asarray(-jnp.inf, log_L_samples.dtype), jnp.asarray(0., log_L_samples.dtype)) xs = (random.split(key, num_live_points_per_sample.shape[0]), num_live_points_per_sample, log_L_samples) - final_accumulate, _ = cumulative_op_static(accumulate_op, init=init, xs=xs) + final_accumulate, _ = scan_associative_cumulative_op(accumulate_op, init=init, xs=xs) (log_Z, _) = final_accumulate return log_Z From 4dff5f40f0e46de7cda046656b64a93a34851beb Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 1 May 2024 12:56:23 +0200 Subject: [PATCH 2/3] * Handle pytrees in scan_associative properly --- jaxns/internals/cumulative_ops.py | 6 +++--- jaxns/internals/tests/test_cumulative_ops.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py index 75fcb71a..dacc23ad 100644 --- a/jaxns/internals/cumulative_ops.py +++ b/jaxns/internals/cumulative_ops.py @@ -33,12 +33,12 @@ def associative_op(a, b): scanned_results = tfp_math.scan_associative(associative_op, full_input) # The final accumulated value is the last element in the results - final_accumulate = scanned_results[-1] + final_accumulate = jax.tree.map(lambda x: x[-1], scanned_results) if pre_op: - scanned_results = scanned_results[:-1] + scanned_results = jax.tree.map(lambda x: x[:-1], scanned_results) else: - scanned_results = scanned_results[1:] + scanned_results = jax.tree.map(lambda x: x[1:], scanned_results) return final_accumulate, scanned_results diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py index 509da75a..5cd9d159 100644 --- a/jaxns/internals/tests/test_cumulative_ops.py +++ b/jaxns/internals/tests/test_cumulative_ops.py @@ -1,3 +1,4 @@ +import jax from jax import numpy as jnp from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op @@ -33,6 +34,22 @@ def op(accumulate, y): assert final_accumulate == 6 assert all(result == jnp.asarray([0, 1, 3], float_type)) + # Test with pytrees for xs and ys + def op(accumulate, y): + return jax.tree.map(lambda x, y: x + y, accumulate, y) + + init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)} + xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)} + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs) + assert final_accumulate == {'a': 6, 'b': 15} + assert all(result['a'] == jnp.asarray([1, 3, 6], float_type)) + assert all(result['b'] == jnp.asarray([4, 9, 15], float_type)) + + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) + assert final_accumulate == {'a': 6, 'b': 15} + assert all(result['a'] == jnp.asarray([0, 1, 3], float_type)) + assert all(result['b'] == jnp.asarray([0, 4, 9], float_type)) + def test_cumulative_op_dynamic(): def op(accumulate, y): From 182ac1e42e85dafb1b5e6c7fc5d60fe710bbbbc4 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 8 May 2024 21:50:32 +0200 Subject: [PATCH 3/3] * Fix, only a small subset of ops are associative. --- jaxns/experimental/evidence_maximisation.py | 4 +- jaxns/internals/cumulative_ops.py | 3 +- jaxns/internals/shrinkage_statistics.py | 4 +- jaxns/internals/tests/test_cumulative_ops.py | 49 ++++++++++++++++---- jaxns/internals/tests/test_tree_structure.py | 34 ++++++++------ jaxns/internals/tree_structure.py | 4 +- jaxns/nested_sampler/standard_static.py | 4 +- jaxns/samplers/multi_slice_sampler.py | 4 +- jaxns/samplers/uni_slice_sampler.py | 4 +- jaxns/utils.py | 4 +- 10 files changed, 74 insertions(+), 40 deletions(-) diff --git a/jaxns/experimental/evidence_maximisation.py b/jaxns/experimental/evidence_maximisation.py index 289a814d..299ed811 100644 --- a/jaxns/experimental/evidence_maximisation.py +++ b/jaxns/experimental/evidence_maximisation.py @@ -12,7 +12,7 @@ from jaxopt import NonlinearCG, ArmijoSGD from tqdm import tqdm -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.log_semiring import LogSpace from jaxns.internals.logging import logger @@ -241,7 +241,7 @@ def op(log_Z, data): log_dZ = model.forward(data.U_samples) + data.log_weights return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val - log_Z, _ = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) + log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) return log_Z def loss(params: hk.MutableParams, data: MStepData): diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py index dacc23ad..c202b09a 100644 --- a/jaxns/internals/cumulative_ops.py +++ b/jaxns/internals/cumulative_ops.py @@ -6,11 +6,12 @@ from jaxns.internals.types import IntArray, int_type +X = TypeVar('X') V = TypeVar('V') Y = TypeVar('Y') -def scan_associative_cumulative_op(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False) -> Tuple[V, V]: +def scan_associative_cumulative_op(op: Callable[[X, X], X], init: X, xs: X, pre_op: bool = False) -> Tuple[X, X]: """ Compute a cumulative operation on an array of values using scan_associative. diff --git a/jaxns/internals/shrinkage_statistics.py b/jaxns/internals/shrinkage_statistics.py index 40dc8ab4..685f362d 100644 --- a/jaxns/internals/shrinkage_statistics.py +++ b/jaxns/internals/shrinkage_statistics.py @@ -2,7 +2,7 @@ import jax.numpy as jnp -from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static from jaxns.internals.log_semiring import LogSpace from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray @@ -144,7 +144,7 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_ final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, stop_idx=stop_idx) else: - final_accumulate, result = scan_associative_cumulative_op(op=_update_evidence_calc_op, init=init, xs=xs) + final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs) final_evidence_calculation = final_accumulate per_sample_evidence_calculation = result return final_evidence_calculation, per_sample_evidence_calculation diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py index 5cd9d159..6214b37d 100644 --- a/jaxns/internals/tests/test_cumulative_ops.py +++ b/jaxns/internals/tests/test_cumulative_ops.py @@ -1,4 +1,6 @@ import jax +import numpy as np +import pytest from jax import numpy as jnp from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op @@ -20,23 +22,50 @@ def op(accumulate, y): assert all(result == jnp.asarray([0, 1, 3], float_type)) -def test_scan_associative_cumulative_op(): +@pytest.mark.parametrize("binary_op", [jnp.add, jnp.multiply, jnp.minimum, jnp.maximum]) +def test_scan_associative_cumulative_op(binary_op): def op(accumulate, y): - return accumulate + y + return binary_op(accumulate, y) - init = jnp.asarray(0, float_type) - xs = jnp.asarray([1, 2, 3], float_type) - final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs) - assert final_accumulate == 6 - assert all(result == jnp.asarray([1, 3, 6], float_type)) + init = jnp.asarray(1, float_type) + xs = jnp.arange(1, 11, dtype=float_type) + final_accumulate, result = scan_associative_cumulative_op(op=binary_op, init=init, xs=xs) + final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs) + assert final_accumulate == final_accumulate_expected + np.testing.assert_allclose(result, result_expected) final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) - assert final_accumulate == 6 - assert all(result == jnp.asarray([0, 1, 3], float_type)) + final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True) + assert final_accumulate == final_accumulate_expected + np.testing.assert_allclose(result, result_expected) + + +@pytest.mark.parametrize("binary_op", [jnp.subtract, jnp.true_divide]) +def test_scan_associative_cumulative_op(binary_op): + def op(accumulate, y): + return binary_op(accumulate, y) + + init = jnp.asarray(1, float_type) + xs = jnp.arange(1, 11, dtype=float_type) + final_accumulate, result = scan_associative_cumulative_op(op=binary_op, init=init, xs=xs) + final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs) + with pytest.raises(AssertionError): + assert final_accumulate == final_accumulate_expected + with pytest.raises(AssertionError): + np.testing.assert_allclose(result, result_expected) + + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) + final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True) + with pytest.raises(AssertionError): + assert final_accumulate == final_accumulate_expected + with pytest.raises(AssertionError): + np.testing.assert_allclose(result, result_expected) + +def test_scan_associative_cumulative_op_with_pytrees(): # Test with pytrees for xs and ys def op(accumulate, y): - return jax.tree.map(lambda x, y: x + y, accumulate, y) + return jax.tree.map(lambda x, y: jnp.add(x, y), accumulate, y) init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)} xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)} diff --git a/jaxns/internals/tests/test_tree_structure.py b/jaxns/internals/tests/test_tree_structure.py index bd60c88f..03ea09e6 100644 --- a/jaxns/internals/tests/test_tree_structure.py +++ b/jaxns/internals/tests/test_tree_structure.py @@ -10,6 +10,13 @@ from jaxns.internals.types import StaticStandardNestedSamplerState, StaticStandardSampleCollection +def pytree_assert_equal(a, b): + print(a) + print(b) + for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b)): + np.testing.assert_allclose(x, y) + + def test_naive(): S = SampleTreeGraph( sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]), @@ -29,20 +36,18 @@ def test_basic(): sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]), log_L=jnp.asarray([1, 2, 3, 4, 5, 6]) ) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S))) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S))) - assert all( - jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S))) + pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S)) + pytree_assert_equal(count_crossed_edges(S), count_old(S)) + pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S)) S = SampleTreeGraph( sender_node_idx=jnp.asarray([0, 0, 0, 1, 3, 2]), log_L=jnp.asarray([1, 2, 3, 4, 6, 5]) ) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S))) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S))) - assert all( - jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S))) + pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S)) + pytree_assert_equal(count_crossed_edges(S), count_old(S)) + pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S)) def test_with_num_samples(): @@ -57,9 +62,9 @@ def test_with_num_samples(): log_L=jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8]) ) - assert all(jax.tree.map(lambda x, y: np.array_equal(x[:num_samples], y), - count_crossed_edges(S1, num_samples), - count_crossed_edges(S2))) + x = jax.tree.map(lambda x: x[:num_samples], count_crossed_edges(S1, num_samples)) + + pytree_assert_equal(x, count_crossed_edges(S2)) output = count_crossed_edges(S1, num_samples) print(output) @@ -89,10 +94,9 @@ def test_random_tree(): plot_tree(S) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S))) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S))) - assert all( - jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S))) + pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S)) + pytree_assert_equal(count_crossed_edges(S), count_old(S)) + pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S)) T = count_crossed_edges_less_fast(S) import pylab as plt diff --git a/jaxns/internals/tree_structure.py b/jaxns/internals/tree_structure.py index a603b3bd..bda36c8e 100644 --- a/jaxns/internals/tree_structure.py +++ b/jaxns/internals/tree_structure.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, lax, core from jax._src.numpy import lax_numpy -from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static from jaxns.internals.maps import remove_chunk_dim from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \ int_type @@ -82,7 +82,7 @@ def op(crossed_edges, last_node): empty_fill=jnp.asarray(fake_edges, out_degree.dtype) ) else: - _, crossed_edges_sorted = scan_associative_cumulative_op( + _, crossed_edges_sorted = cumulative_op_static( op=op, init=jnp.asarray(1, out_degree.dtype), xs=sort_idx, diff --git a/jaxns/nested_sampler/standard_static.py b/jaxns/nested_sampler/standard_static.py index 0ea6cb09..7219e16b 100644 --- a/jaxns/nested_sampler/standard_static.py +++ b/jaxns/nested_sampler/standard_static.py @@ -6,7 +6,7 @@ from jax._src.lax import parallel from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.log_semiring import LogSpace, normalise_log_space from jaxns.internals.logging import logger from jaxns.internals.shrinkage_statistics import compute_evidence_stats, init_evidence_calc, \ @@ -211,7 +211,7 @@ def body(carry: CarryType, unused_X: IntArray) -> Tuple[CarryType, ResultType]: # Update termination register _n = init_state.front_idx.size _num_samples = _n - evidence_calc_with_remaining, _ = scan_associative_cumulative_op( + evidence_calc_with_remaining, _ = cumulative_op_static( op=_update_evidence_calc_op, init=out_carry.evidence_calc, xs=EvidenceUpdateVariables( diff --git a/jaxns/samplers/multi_slice_sampler.py b/jaxns/samplers/multi_slice_sampler.py index 18323ff5..b7ee78c8 100644 --- a/jaxns/samplers/multi_slice_sampler.py +++ b/jaxns/samplers/multi_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, int_type, StaticStandardNestedSamplerState, \ UType, \ IntArray, float_type, StaticStandardSampleCollection @@ -255,7 +255,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = scan_associative_cumulative_op( + final_sample, cumulative_samples = cumulative_op_static( op=propose_op, init=init_sample, xs=random.split(key, self.num_slices) diff --git a/jaxns/samplers/uni_slice_sampler.py b/jaxns/samplers/uni_slice_sampler.py index 072a50f1..799ee7b7 100644 --- a/jaxns/samplers/uni_slice_sampler.py +++ b/jaxns/samplers/uni_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, float_type, int_type, \ StaticStandardNestedSamplerState, \ IntArray, UType, StaticStandardSampleCollection @@ -338,7 +338,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = scan_associative_cumulative_op( + final_sample, cumulative_samples = cumulative_op_static( op=propose_op, init=init_sample, xs=random.split(key, self.num_slices) diff --git a/jaxns/utils.py b/jaxns/utils.py index b7be003e..d849f5b6 100644 --- a/jaxns/utils.py +++ b/jaxns/utils.py @@ -8,7 +8,7 @@ from jax import numpy as jnp, vmap, random, jit, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.log_semiring import LogSpace from jaxns.internals.maps import prepare_func_args from jaxns.internals.namedtuple_utils import serialise_namedtuple, deserialise_namedtuple @@ -444,7 +444,7 @@ def accumulate_op(accumulate, y): def single_log_Z_sample(key: PRNGKey) -> FloatArray: init = (jnp.asarray(-jnp.inf, log_L_samples.dtype), jnp.asarray(0., log_L_samples.dtype)) xs = (random.split(key, num_live_points_per_sample.shape[0]), num_live_points_per_sample, log_L_samples) - final_accumulate, _ = scan_associative_cumulative_op(accumulate_op, init=init, xs=xs) + final_accumulate, _ = cumulative_op_static(accumulate_op, init=init, xs=xs) (log_Z, _) = final_accumulate return log_Z