From ce3436210dc725fba84133539d4e0d5144689ec7 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 1 May 2024 12:49:38 +0200 Subject: [PATCH 1/4] * Implement and use scan_associative_cumulative_op --- jaxns/experimental/evidence_maximisation.py | 4 +-- jaxns/internals/cumulative_ops.py | 34 ++++++++++++++++++++ jaxns/internals/shrinkage_statistics.py | 11 ++++--- jaxns/internals/tests/test_cumulative_ops.py | 17 +++++++++- jaxns/internals/tree_structure.py | 4 +-- jaxns/nested_sampler/standard_static.py | 4 +-- jaxns/samplers/multi_slice_sampler.py | 7 ++-- jaxns/samplers/uni_slice_sampler.py | 23 +++++++------ jaxns/utils.py | 4 +-- 9 files changed, 78 insertions(+), 30 deletions(-) diff --git a/jaxns/experimental/evidence_maximisation.py b/jaxns/experimental/evidence_maximisation.py index 299ed811..289a814d 100644 --- a/jaxns/experimental/evidence_maximisation.py +++ b/jaxns/experimental/evidence_maximisation.py @@ -12,7 +12,7 @@ from jaxopt import NonlinearCG, ArmijoSGD from tqdm import tqdm -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace from jaxns.internals.logging import logger @@ -241,7 +241,7 @@ def op(log_Z, data): log_dZ = model.forward(data.U_samples) + data.log_weights return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val - log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) + log_Z, _ = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) return log_Z def loss(params: hk.MutableParams, data: MStepData): diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py index 95a1c596..75fcb71a 100644 --- a/jaxns/internals/cumulative_ops.py +++ b/jaxns/internals/cumulative_ops.py @@ -2,6 +2,7 @@ import jax from jax import lax, numpy as jnp, tree_util +from tensorflow_probability.substrates.jax import math as tfp_math from jaxns.internals.types import IntArray, int_type @@ -9,6 +10,39 @@ Y = TypeVar('Y') +def scan_associative_cumulative_op(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False) -> Tuple[V, V]: + """ + Compute a cumulative operation on an array of values using scan_associative. + + Args: + op: the operation to perform, must be associative. + init: the initial value. + xs: the array of values. 
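+        pre_op: if True, the returned cumulative values are the accumulations *before* each
+            element is applied, i.e. starting at ``init`` and excluding the final value.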
+ + Returns: + the final accumulated value, and the result of the cumulative operation applied on input + """ + + def associative_op(a, b): + return op(a, b) + + # Prepare the input array by prepending the initial value + full_input = jax.tree.map(lambda x, y: jnp.concatenate([x[None], y], axis=0), init, xs) + + # Apply the operation to accumulate results using scan_associative + scanned_results = tfp_math.scan_associative(associative_op, full_input) + + # The final accumulated value is the last element in the results + final_accumulate = scanned_results[-1] + + if pre_op: + scanned_results = scanned_results[:-1] + else: + scanned_results = scanned_results[1:] + + return final_accumulate, scanned_results + + def cumulative_op_static(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False, unroll: int = 1) -> Tuple[ V, V]: """ diff --git a/jaxns/internals/shrinkage_statistics.py b/jaxns/internals/shrinkage_statistics.py index 1f9f7b1d..40dc8ab4 100644 --- a/jaxns/internals/shrinkage_statistics.py +++ b/jaxns/internals/shrinkage_statistics.py @@ -2,7 +2,7 @@ import jax.numpy as jnp -from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray @@ -28,8 +28,8 @@ def op(log_X, num_live_points): next_X_mean = X_mean * T_mean return next_X_mean.log_abs_val - _, log_X = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), - xs=live_point_counts.num_live_points) + _, log_X = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type), + xs=live_point_counts.num_live_points) return log_X @@ -141,9 +141,10 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_ ) if num_samples is not None: stop_idx = num_samples - final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, stop_idx=stop_idx) + final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, + stop_idx=stop_idx) else: - final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs) + final_accumulate, result = scan_associative_cumulative_op(op=_update_evidence_calc_op, init=init, xs=xs) final_evidence_calculation = final_accumulate per_sample_evidence_calculation = result return final_evidence_calculation, per_sample_evidence_calculation diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py index ade57694..509da75a 100644 --- a/jaxns/internals/tests/test_cumulative_ops.py +++ b/jaxns/internals/tests/test_cumulative_ops.py @@ -1,6 +1,6 @@ from jax import numpy as jnp -from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic +from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op from jaxns.internals.types import float_type, int_type @@ -19,6 +19,21 @@ def op(accumulate, y): assert all(result == jnp.asarray([0, 1, 3], float_type)) +def test_scan_associative_cumulative_op(): + def op(accumulate, y): + return accumulate + y + + init = jnp.asarray(0, float_type) + xs = jnp.asarray([1, 2, 3], float_type) + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs) 
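+    # Cumulative sum of [1, 2, 3] from init=0: intermediate results [1, 3, 6], final value 6.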
+ assert final_accumulate == 6 + assert all(result == jnp.asarray([1, 3, 6], float_type)) + + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) + assert final_accumulate == 6 + assert all(result == jnp.asarray([0, 1, 3], float_type)) + + def test_cumulative_op_dynamic(): def op(accumulate, y): return accumulate + y diff --git a/jaxns/internals/tree_structure.py b/jaxns/internals/tree_structure.py index dbad50fd..a603b3bd 100644 --- a/jaxns/internals/tree_structure.py +++ b/jaxns/internals/tree_structure.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, lax, core from jax._src.numpy import lax_numpy -from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op from jaxns.internals.maps import remove_chunk_dim from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \ int_type @@ -82,7 +82,7 @@ def op(crossed_edges, last_node): empty_fill=jnp.asarray(fake_edges, out_degree.dtype) ) else: - _, crossed_edges_sorted = cumulative_op_static( + _, crossed_edges_sorted = scan_associative_cumulative_op( op=op, init=jnp.asarray(1, out_degree.dtype), xs=sort_idx, diff --git a/jaxns/nested_sampler/standard_static.py b/jaxns/nested_sampler/standard_static.py index 7219e16b..0ea6cb09 100644 --- a/jaxns/nested_sampler/standard_static.py +++ b/jaxns/nested_sampler/standard_static.py @@ -6,7 +6,7 @@ from jax._src.lax import parallel from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace, normalise_log_space from jaxns.internals.logging import logger from jaxns.internals.shrinkage_statistics import compute_evidence_stats, init_evidence_calc, \ @@ -211,7 +211,7 @@ def body(carry: CarryType, unused_X: IntArray) -> Tuple[CarryType, ResultType]: # Update termination register _n = init_state.front_idx.size _num_samples = _n - evidence_calc_with_remaining, _ = cumulative_op_static( + evidence_calc_with_remaining, _ = scan_associative_cumulative_op( op=_update_evidence_calc_op, init=out_carry.evidence_calc, xs=EvidenceUpdateVariables( diff --git a/jaxns/samplers/multi_slice_sampler.py b/jaxns/samplers/multi_slice_sampler.py index b4ff0112..18323ff5 100644 --- a/jaxns/samplers/multi_slice_sampler.py +++ b/jaxns/samplers/multi_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, int_type, StaticStandardNestedSamplerState, \ UType, \ IntArray, float_type, StaticStandardSampleCollection @@ -255,11 +255,10 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = cumulative_op_static( + final_sample, cumulative_samples = scan_associative_cumulative_op( op=propose_op, init=init_sample, - xs=random.split(key, self.num_slices), - unroll=2 + xs=random.split(key, self.num_slices) ) # Last sample is the final sample, the rest are potential phantom samples diff --git a/jaxns/samplers/uni_slice_sampler.py 
b/jaxns/samplers/uni_slice_sampler.py index a0da6942..072a50f1 100644 --- a/jaxns/samplers/uni_slice_sampler.py +++ b/jaxns/samplers/uni_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, float_type, int_type, \ StaticStandardNestedSamplerState, \ IntArray, UType, StaticStandardSampleCollection @@ -83,8 +83,7 @@ def _pick_point_in_interval(key: PRNGKey, point_U0: FloatArray, direction: Float return point_U, t -def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, log_L_proposal: FloatArray, - log_L_constraint: FloatArray, log_L0: FloatArray, +def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, midpoint_shrink: bool) -> Tuple[FloatArray, FloatArray]: """ Not successful proposal, so shrink, optionally apply exponential shrinkage. @@ -151,9 +150,6 @@ def body(carry: Carry) -> Carry: t=carry.t, left=carry.left, right=carry.right, - log_L_proposal=carry.log_L, - log_L_constraint=log_L_constraint, - log_L0=seed_point.log_L0, midpoint_shrink=midpoint_shrink ) point_U, t = _pick_point_in_interval( @@ -232,11 +228,11 @@ def body(carry: Carry) -> Carry: class UniDimSliceSampler(BaseAbstractMarkovSampler): """ - Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples. + Slice sampler for a single dimension. Produces correlated samples. """ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: int, midpoint_shrink: bool, - perfect: bool, gradient_slice: bool = False): + perfect: bool, gradient_slice: bool = False, adaptive_shrink: bool = False): """ Unidimensional slice sampler. @@ -250,7 +246,8 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: perfect: if true then perform exponential shrinkage from maximal bounds, requiring no step-out procedure. Otherwise, uses a doubling procedure (exponentially finding bracket). Note: Perfect is a misnomer, as perfection also depends on the number of slices between acceptance. - gradient_slice: if true then always slice along gradient direction. + gradient_slice: if true then always slice along increasing gradient direction. + adaptive_shrink: if true then shrink interval to random point in interval, rather than midpoint. 
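+                Not yet implemented: enabling this raises ``NotImplementedError``.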
""" super().__init__(model=model) if num_slices < 1: @@ -264,6 +261,9 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: self.midpoint_shrink = bool(midpoint_shrink) self.perfect = bool(perfect) self.gradient_slice = bool(gradient_slice) + self.adaptive_shrink = bool(adaptive_shrink) + if self.adaptive_shrink: + raise NotImplementedError("Adaptive shrinkage not implemented.") if not self.perfect: raise ValueError("Only perfect slice sampler is implemented.") @@ -338,11 +338,10 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = cumulative_op_static( + final_sample, cumulative_samples = scan_associative_cumulative_op( op=propose_op, init=init_sample, - xs=random.split(key, self.num_slices), - unroll=2 + xs=random.split(key, self.num_slices) ) # Last sample is the final sample, the rest are potential phantom samples diff --git a/jaxns/utils.py b/jaxns/utils.py index d849f5b6..b7be003e 100644 --- a/jaxns/utils.py +++ b/jaxns/utils.py @@ -8,7 +8,7 @@ from jax import numpy as jnp, vmap, random, jit, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import cumulative_op_static +from jaxns.internals.cumulative_ops import scan_associative_cumulative_op from jaxns.internals.log_semiring import LogSpace from jaxns.internals.maps import prepare_func_args from jaxns.internals.namedtuple_utils import serialise_namedtuple, deserialise_namedtuple @@ -444,7 +444,7 @@ def accumulate_op(accumulate, y): def single_log_Z_sample(key: PRNGKey) -> FloatArray: init = (jnp.asarray(-jnp.inf, log_L_samples.dtype), jnp.asarray(0., log_L_samples.dtype)) xs = (random.split(key, num_live_points_per_sample.shape[0]), num_live_points_per_sample, log_L_samples) - final_accumulate, _ = cumulative_op_static(accumulate_op, init=init, xs=xs) + final_accumulate, _ = scan_associative_cumulative_op(accumulate_op, init=init, xs=xs) (log_Z, _) = final_accumulate return log_Z From 4dff5f40f0e46de7cda046656b64a93a34851beb Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 1 May 2024 12:56:23 +0200 Subject: [PATCH 2/4] * Handle pytrees in scan_associative properly --- jaxns/internals/cumulative_ops.py | 6 +++--- jaxns/internals/tests/test_cumulative_ops.py | 17 +++++++++++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py index 75fcb71a..dacc23ad 100644 --- a/jaxns/internals/cumulative_ops.py +++ b/jaxns/internals/cumulative_ops.py @@ -33,12 +33,12 @@ def associative_op(a, b): scanned_results = tfp_math.scan_associative(associative_op, full_input) # The final accumulated value is the last element in the results - final_accumulate = scanned_results[-1] + final_accumulate = jax.tree.map(lambda x: x[-1], scanned_results) if pre_op: - scanned_results = scanned_results[:-1] + scanned_results = jax.tree.map(lambda x: x[:-1], scanned_results) else: - scanned_results = scanned_results[1:] + scanned_results = jax.tree.map(lambda x: x[1:], scanned_results) return final_accumulate, scanned_results diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py index 509da75a..5cd9d159 100644 --- a/jaxns/internals/tests/test_cumulative_ops.py +++ b/jaxns/internals/tests/test_cumulative_ops.py @@ -1,3 +1,4 @@ +import jax from jax import numpy as jnp from jaxns.internals.cumulative_ops import cumulative_op_static, 
cumulative_op_dynamic, scan_associative_cumulative_op @@ -33,6 +34,22 @@ def op(accumulate, y): assert final_accumulate == 6 assert all(result == jnp.asarray([0, 1, 3], float_type)) + # Test with pytrees for xs and ys + def op(accumulate, y): + return jax.tree.map(lambda x, y: x + y, accumulate, y) + + init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)} + xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)} + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs) + assert final_accumulate == {'a': 6, 'b': 15} + assert all(result['a'] == jnp.asarray([1, 3, 6], float_type)) + assert all(result['b'] == jnp.asarray([4, 9, 15], float_type)) + + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) + assert final_accumulate == {'a': 6, 'b': 15} + assert all(result['a'] == jnp.asarray([0, 1, 3], float_type)) + assert all(result['b'] == jnp.asarray([0, 4, 9], float_type)) + def test_cumulative_op_dynamic(): def op(accumulate, y): From 182ac1e42e85dafb1b5e6c7fc5d60fe710bbbbc4 Mon Sep 17 00:00:00 2001 From: joshuaalbert Date: Wed, 8 May 2024 21:50:32 +0200 Subject: [PATCH 3/4] * Fix, only a small subset of ops are associative. --- jaxns/experimental/evidence_maximisation.py | 4 +- jaxns/internals/cumulative_ops.py | 3 +- jaxns/internals/shrinkage_statistics.py | 4 +- jaxns/internals/tests/test_cumulative_ops.py | 49 ++++++++++++++++---- jaxns/internals/tests/test_tree_structure.py | 34 ++++++++------ jaxns/internals/tree_structure.py | 4 +- jaxns/nested_sampler/standard_static.py | 4 +- jaxns/samplers/multi_slice_sampler.py | 4 +- jaxns/samplers/uni_slice_sampler.py | 4 +- jaxns/utils.py | 4 +- 10 files changed, 74 insertions(+), 40 deletions(-) diff --git a/jaxns/experimental/evidence_maximisation.py b/jaxns/experimental/evidence_maximisation.py index 289a814d..299ed811 100644 --- a/jaxns/experimental/evidence_maximisation.py +++ b/jaxns/experimental/evidence_maximisation.py @@ -12,7 +12,7 @@ from jaxopt import NonlinearCG, ArmijoSGD from tqdm import tqdm -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.log_semiring import LogSpace from jaxns.internals.logging import logger @@ -241,7 +241,7 @@ def op(log_Z, data): log_dZ = model.forward(data.U_samples) + data.log_weights return (LogSpace(log_Z) + LogSpace(log_dZ)).log_abs_val - log_Z, _ = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) + log_Z, _ = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type), xs=data) return log_Z def loss(params: hk.MutableParams, data: MStepData): diff --git a/jaxns/internals/cumulative_ops.py b/jaxns/internals/cumulative_ops.py index dacc23ad..c202b09a 100644 --- a/jaxns/internals/cumulative_ops.py +++ b/jaxns/internals/cumulative_ops.py @@ -6,11 +6,12 @@ from jaxns.internals.types import IntArray, int_type +X = TypeVar('X') V = TypeVar('V') Y = TypeVar('Y') -def scan_associative_cumulative_op(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False) -> Tuple[V, V]: +def scan_associative_cumulative_op(op: Callable[[X, X], X], init: X, xs: X, pre_op: bool = False) -> Tuple[X, X]: """ Compute a cumulative operation on an array of values using scan_associative. 
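+
+    NOTE: ``op`` must be truly associative, i.e. ``op(op(a, b), c) == op(a, op(b, c))``,
+    otherwise the result will not match the sequential ``cumulative_op_static``.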
diff --git a/jaxns/internals/shrinkage_statistics.py b/jaxns/internals/shrinkage_statistics.py
index 40dc8ab4..685f362d 100644
--- a/jaxns/internals/shrinkage_statistics.py
+++ b/jaxns/internals/shrinkage_statistics.py
@@ -2,7 +2,7 @@
 
 import jax.numpy as jnp
 
-from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op
+from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
 from jaxns.internals.log_semiring import LogSpace
 from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
 from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray
@@ -144,7 +144,7 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_
         final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs,
                                                          stop_idx=stop_idx)
     else:
-        final_accumulate, result = scan_associative_cumulative_op(op=_update_evidence_calc_op, init=init, xs=xs)
+        final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs)
     final_evidence_calculation = final_accumulate
     per_sample_evidence_calculation = result
     return final_evidence_calculation, per_sample_evidence_calculation
diff --git a/jaxns/internals/tests/test_cumulative_ops.py b/jaxns/internals/tests/test_cumulative_ops.py
index 5cd9d159..6214b37d 100644
--- a/jaxns/internals/tests/test_cumulative_ops.py
+++ b/jaxns/internals/tests/test_cumulative_ops.py
@@ -1,4 +1,6 @@
 import jax
+import numpy as np
+import pytest
 from jax import numpy as jnp
 
 from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op
@@ -20,23 +22,50 @@ def op(accumulate, y):
     assert all(result == jnp.asarray([0, 1, 3], float_type))
 
 
-def test_scan_associative_cumulative_op():
+@pytest.mark.parametrize("binary_op", [jnp.add, jnp.multiply, jnp.minimum, jnp.maximum])
+def test_scan_associative_cumulative_op(binary_op):
     def op(accumulate, y):
-        return accumulate + y
+        return binary_op(accumulate, y)
 
-    init = jnp.asarray(0, float_type)
-    xs = jnp.asarray([1, 2, 3], float_type)
-    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
-    assert final_accumulate == 6
-    assert all(result == jnp.asarray([1, 3, 6], float_type))
+    init = jnp.asarray(1, float_type)
+    xs = jnp.arange(1, 11, dtype=float_type)
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
+    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
+    assert final_accumulate == final_accumulate_expected
+    np.testing.assert_allclose(result, result_expected)
 
     final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
-    assert final_accumulate == 6
-    assert all(result == jnp.asarray([0, 1, 3], float_type))
+    final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
+    assert final_accumulate == final_accumulate_expected
+    np.testing.assert_allclose(result, result_expected)
+
+
+@pytest.mark.parametrize("binary_op", [jnp.subtract, jnp.true_divide])
+def test_scan_associative_cumulative_op_non_associative(binary_op):
+    def op(accumulate, y):
+        return binary_op(accumulate, y)
+
+    init = jnp.asarray(1, float_type)
+    xs = jnp.arange(1, 11, dtype=float_type)
+    final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
+    final_accumulate_expected, result_expected =
cumulative_op_static(op=op, init=init, xs=xs) + with pytest.raises(AssertionError): + assert final_accumulate == final_accumulate_expected + with pytest.raises(AssertionError): + np.testing.assert_allclose(result, result_expected) + + final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True) + final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True) + with pytest.raises(AssertionError): + assert final_accumulate == final_accumulate_expected + with pytest.raises(AssertionError): + np.testing.assert_allclose(result, result_expected) + +def test_scan_associative_cumulative_op_with_pytrees(): # Test with pytrees for xs and ys def op(accumulate, y): - return jax.tree.map(lambda x, y: x + y, accumulate, y) + return jax.tree.map(lambda x, y: jnp.add(x, y), accumulate, y) init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)} xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)} diff --git a/jaxns/internals/tests/test_tree_structure.py b/jaxns/internals/tests/test_tree_structure.py index bd60c88f..03ea09e6 100644 --- a/jaxns/internals/tests/test_tree_structure.py +++ b/jaxns/internals/tests/test_tree_structure.py @@ -10,6 +10,13 @@ from jaxns.internals.types import StaticStandardNestedSamplerState, StaticStandardSampleCollection +def pytree_assert_equal(a, b): + print(a) + print(b) + for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b)): + np.testing.assert_allclose(x, y) + + def test_naive(): S = SampleTreeGraph( sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]), @@ -29,20 +36,18 @@ def test_basic(): sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]), log_L=jnp.asarray([1, 2, 3, 4, 5, 6]) ) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S))) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S))) - assert all( - jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S))) + pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S)) + pytree_assert_equal(count_crossed_edges(S), count_old(S)) + pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S)) S = SampleTreeGraph( sender_node_idx=jnp.asarray([0, 0, 0, 1, 3, 2]), log_L=jnp.asarray([1, 2, 3, 4, 6, 5]) ) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S))) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S))) - assert all( - jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S))) + pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S)) + pytree_assert_equal(count_crossed_edges(S), count_old(S)) + pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S)) def test_with_num_samples(): @@ -57,9 +62,9 @@ def test_with_num_samples(): log_L=jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8]) ) - assert all(jax.tree.map(lambda x, y: np.array_equal(x[:num_samples], y), - count_crossed_edges(S1, num_samples), - count_crossed_edges(S2))) + x = jax.tree.map(lambda x: x[:num_samples], count_crossed_edges(S1, num_samples)) + + pytree_assert_equal(x, count_crossed_edges(S2)) output = count_crossed_edges(S1, num_samples) print(output) @@ -89,10 +94,9 @@ def test_random_tree(): plot_tree(S) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), 
count_intervals_naive(S))) - assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S))) - assert all( - jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S))) + pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S)) + pytree_assert_equal(count_crossed_edges(S), count_old(S)) + pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S)) T = count_crossed_edges_less_fast(S) import pylab as plt diff --git a/jaxns/internals/tree_structure.py b/jaxns/internals/tree_structure.py index a603b3bd..bda36c8e 100644 --- a/jaxns/internals/tree_structure.py +++ b/jaxns/internals/tree_structure.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, lax, core from jax._src.numpy import lax_numpy -from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static from jaxns.internals.maps import remove_chunk_dim from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \ int_type @@ -82,7 +82,7 @@ def op(crossed_edges, last_node): empty_fill=jnp.asarray(fake_edges, out_degree.dtype) ) else: - _, crossed_edges_sorted = scan_associative_cumulative_op( + _, crossed_edges_sorted = cumulative_op_static( op=op, init=jnp.asarray(1, out_degree.dtype), xs=sort_idx, diff --git a/jaxns/nested_sampler/standard_static.py b/jaxns/nested_sampler/standard_static.py index 0ea6cb09..7219e16b 100644 --- a/jaxns/nested_sampler/standard_static.py +++ b/jaxns/nested_sampler/standard_static.py @@ -6,7 +6,7 @@ from jax._src.lax import parallel from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.log_semiring import LogSpace, normalise_log_space from jaxns.internals.logging import logger from jaxns.internals.shrinkage_statistics import compute_evidence_stats, init_evidence_calc, \ @@ -211,7 +211,7 @@ def body(carry: CarryType, unused_X: IntArray) -> Tuple[CarryType, ResultType]: # Update termination register _n = init_state.front_idx.size _num_samples = _n - evidence_calc_with_remaining, _ = scan_associative_cumulative_op( + evidence_calc_with_remaining, _ = cumulative_op_static( op=_update_evidence_calc_op, init=out_carry.evidence_calc, xs=EvidenceUpdateVariables( diff --git a/jaxns/samplers/multi_slice_sampler.py b/jaxns/samplers/multi_slice_sampler.py index 18323ff5..b7ee78c8 100644 --- a/jaxns/samplers/multi_slice_sampler.py +++ b/jaxns/samplers/multi_slice_sampler.py @@ -4,7 +4,7 @@ from jax import numpy as jnp, random, lax from jaxns.framework.bases import BaseAbstractModel -from jaxns.internals.cumulative_ops import scan_associative_cumulative_op +from jaxns.internals.cumulative_ops import cumulative_op_static from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, int_type, StaticStandardNestedSamplerState, \ UType, \ IntArray, float_type, StaticStandardSampleCollection @@ -255,7 +255,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample: log_L=seed_point.log_L0, num_likelihood_evaluations=jnp.asarray(0, int_type) ) - final_sample, cumulative_samples = scan_associative_cumulative_op( + final_sample, cumulative_samples = cumulative_op_static( op=propose_op, init=init_sample, xs=random.split(key, self.num_slices) diff --git 
a/jaxns/samplers/uni_slice_sampler.py b/jaxns/samplers/uni_slice_sampler.py
index 072a50f1..799ee7b7 100644
--- a/jaxns/samplers/uni_slice_sampler.py
+++ b/jaxns/samplers/uni_slice_sampler.py
@@ -4,7 +4,7 @@
 from jax import numpy as jnp, random, lax
 
 from jaxns.framework.bases import BaseAbstractModel
-from jaxns.internals.cumulative_ops import scan_associative_cumulative_op
+from jaxns.internals.cumulative_ops import cumulative_op_static
 from jaxns.internals.types import PRNGKey, FloatArray, BoolArray, Sample, float_type, int_type, \
     StaticStandardNestedSamplerState, \
     IntArray, UType, StaticStandardSampleCollection
@@ -338,7 +338,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
             log_L=seed_point.log_L0,
             num_likelihood_evaluations=jnp.asarray(0, int_type)
         )
-        final_sample, cumulative_samples = scan_associative_cumulative_op(
+        final_sample, cumulative_samples = cumulative_op_static(
             op=propose_op,
             init=init_sample,
             xs=random.split(key, self.num_slices)
diff --git a/jaxns/utils.py b/jaxns/utils.py
index b7be003e..d849f5b6 100644
--- a/jaxns/utils.py
+++ b/jaxns/utils.py
@@ -8,7 +8,7 @@
 from jax import numpy as jnp, vmap, random, jit, lax
 
 from jaxns.framework.bases import BaseAbstractModel
-from jaxns.internals.cumulative_ops import scan_associative_cumulative_op
+from jaxns.internals.cumulative_ops import cumulative_op_static
 from jaxns.internals.log_semiring import LogSpace
 from jaxns.internals.maps import prepare_func_args
 from jaxns.internals.namedtuple_utils import serialise_namedtuple, deserialise_namedtuple
@@ -444,7 +444,7 @@ def accumulate_op(accumulate, y):
     def single_log_Z_sample(key: PRNGKey) -> FloatArray:
         init = (jnp.asarray(-jnp.inf, log_L_samples.dtype), jnp.asarray(0., log_L_samples.dtype))
         xs = (random.split(key, num_live_points_per_sample.shape[0]), num_live_points_per_sample, log_L_samples)
-        final_accumulate, _ = scan_associative_cumulative_op(accumulate_op, init=init, xs=xs)
+        final_accumulate, _ = cumulative_op_static(accumulate_op, init=init, xs=xs)
         (log_Z, _) = final_accumulate
         return log_Z
 
From 7f8bff695cc29073cc4e5f3208511e53eabf821e Mon Sep 17 00:00:00 2001
From: joshuaalbert
Date: Wed, 15 May 2024 11:57:15 +0200
Subject: [PATCH 4/4] * implement #59

* bump to 2.5.0
---
 README.md                            |  6 +++-
 docs/conf.py                         |  2 +-
 jaxns/framework/__init__.py          |  1 +
 jaxns/framework/jaxify.py            | 44 ++++++++++++++++++++++++++++
 jaxns/framework/tests/test_jaxify.py | 35 ++++++++++++++++++++++
 jaxns/framework/tests/test_model.py  |  5 +---
 setup.py                             |  2 +-
 7 files changed, 88 insertions(+), 7 deletions(-)
 create mode 100644 jaxns/framework/jaxify.py
 create mode 100644 jaxns/framework/tests/test_jaxify.py

diff --git a/README.md b/README.md
index e3eebeb9..f5929a4e 100644
--- a/README.md
+++ b/README.md
@@ -363,7 +363,11 @@ is the best way to achieve speed up.
 
 # Change Log
 
-22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes bug where slice sampling not invariant to monotonic transforms of likelihod.
+15 May, 2024 -- JAXNS 2.5.0 released. Added ability to handle non-JAX likelihoods, e.g. if you have a simulation
+framework with Python bindings you can now use it for likelihoods in JAXNS. Small performance improvements.
+
+22 Apr, 2024 -- JAXNS 2.4.13 released. Fixes bug where slice sampling was not invariant to monotonic transforms of
+likelihood.
 
 20 Mar, 2024 -- JAXNS 2.4.12 released. Minor bug fixes, and readability improvements. Added Empirial special prior.
diff --git a/docs/conf.py b/docs/conf.py index e481d7e8..5fdd9ef5 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -12,7 +12,7 @@ project = "jaxns" copyright = "2022, Joshua G. Albert" author = "Joshua G. Albert" -release = "2.4.13" +release = "2.5.0" # -- General configuration --------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#general-configuration diff --git a/jaxns/framework/__init__.py b/jaxns/framework/__init__.py index 91bc7c35..184e7793 100644 --- a/jaxns/framework/__init__.py +++ b/jaxns/framework/__init__.py @@ -1,4 +1,5 @@ from jaxns.framework.model import * from jaxns.framework.prior import * from jaxns.framework.special_priors import * +from jaxns.framework.jaxify import * from jaxns.framework.bases import PriorModelGen, PriorModelType \ No newline at end of file diff --git a/jaxns/framework/jaxify.py b/jaxns/framework/jaxify.py new file mode 100644 index 00000000..6f223244 --- /dev/null +++ b/jaxns/framework/jaxify.py @@ -0,0 +1,44 @@ +import warnings +from typing import Callable + +import jax +import numpy as np + +from jaxns.internals.types import float_type, LikelihoodType + +__all__ = [ + 'jaxify_likelihood' +] + +def jaxify_likelihood(log_likelihood: Callable[..., np.ndarray], vectorised: bool = False) -> LikelihoodType: + """ + Wraps a non-JAX log likelihood function. + + Args: + log_likelihood: a non-JAX log-likelihood function, which accepts a number of arguments and returns a scalar + log-likelihood. + vectorised: if True then the `log_likelihood` performs a vectorised computation for leading batch dimensions, + i.e. if a leading batch dimension is added to all input arguments, then it returns a vector of + log-likelihoods with the same leading batch dimension. + + Returns: + A JAX-compatible log-likelihood function. + """ + warnings.warn( + "You're using a non-JAX log-likelihood function. This may be slower than a JAX log-likelihood function. " + "Also, you are responsible for ensuring that the function is deterministic. " + "Also, you cannot use learnable parameters in the likelihood call." + ) + + def _casted_log_likelihood(*args) -> np.ndarray: + return np.asarray(log_likelihood(*args), dtype=float_type) + + def _log_likelihood(*args) -> jax.Array: + # Define the expected shape & dtype of output. 
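+        # pure_callback never traces into the Python function, so the output
+        # shape/dtype must be declared up front.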
+ result_shape_dtype = jax.ShapeDtypeStruct( + shape=(), + dtype=float_type + ) + return jax.pure_callback(_casted_log_likelihood, result_shape_dtype, *args, vectorized=vectorised) + + return _log_likelihood diff --git a/jaxns/framework/tests/test_jaxify.py b/jaxns/framework/tests/test_jaxify.py new file mode 100644 index 00000000..b124c5f7 --- /dev/null +++ b/jaxns/framework/tests/test_jaxify.py @@ -0,0 +1,35 @@ +import jax +import jax.random +import numpy as np + +from jaxns import Prior, Model +from jaxns.framework.jaxify import jaxify_likelihood +from jaxns.framework.tests.test_model import tfpd + + +def test_jaxify_likelihood(): + def log_likelihood(x, y): + return np.sum(x, axis=-1) + np.sum(y, axis=-1) + + wrapped_ll = jaxify_likelihood(log_likelihood) + np.testing.assert_allclose(wrapped_ll(np.array([1, 2]), np.array([3, 4])), 10) + + vmaped_wrapped_ll = jax.vmap(jaxify_likelihood(log_likelihood, vectorised=True)) + + np.testing.assert_allclose(vmaped_wrapped_ll(np.array([[1, 2], [2, 2]]), np.array([[3, 4], [4, 4]])), + np.array([10, 12])) + + +def test_jaxify(): + def prior_model(): + x = yield Prior(tfpd.Uniform(), name='x').parametrised() + return x + + @jaxify_likelihood + def log_likelihood(x): + return x + + model = Model(prior_model=prior_model, log_likelihood=log_likelihood) + model.sanity_check(key=jax.random.PRNGKey(0), S=10) + assert model.U_ndims == 0 + assert model.num_params == 1 diff --git a/jaxns/framework/tests/test_model.py b/jaxns/framework/tests/test_model.py index 84eb468a..22dcce23 100644 --- a/jaxns/framework/tests/test_model.py +++ b/jaxns/framework/tests/test_model.py @@ -138,8 +138,8 @@ def log_likelihood(obj: Obj): model = Model(prior_model=prior_model, log_likelihood=log_likelihood) model.sanity_check(key=jax.random.PRNGKey(0), S=10) -def test_empty_prior_models(): +def test_empty_prior_models(): def prior_model(): return 1. @@ -148,6 +148,3 @@ def log_likelihood(x): model = Model(prior_model=prior_model, log_likelihood=log_likelihood) model.sanity_check(key=jax.random.PRNGKey(0), S=10) - - - diff --git a/setup.py b/setup.py index f329dde1..2b72fc0d 100755 --- a/setup.py +++ b/setup.py @@ -20,7 +20,7 @@ long_description = fh.read() setup(name='jaxns', - version='2.4.13', + version='2.5.0', description='Nested Sampling in JAX', long_description=long_description, long_description_content_type="text/markdown",