Merge pull request #164 from Joshuaalbert/feature-163
Implement and use scan_associative_cumulative_op
Joshuaalbert authored May 8, 2024
2 parents 649c930 + 182ac1e commit 23a4f0e
Showing 7 changed files with 132 additions and 33 deletions.
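
For context: the new helper builds on TFP's scan_associative, which evaluates all prefix reductions of an associative binary op with O(log n) parallel depth, instead of the O(n) sequential depth of lax.scan. A minimal sketch of that primitive (editorial example, not part of this diff):

import jax.numpy as jnp
from tensorflow_probability.substrates.jax import math as tfp_math

xs = jnp.arange(1., 6.)
# Computes all prefix sums; matches jnp.cumsum but with logarithmic dependency depth.
cumsum = tfp_math.scan_associative(jnp.add, xs)
assert jnp.allclose(cumsum, jnp.cumsum(xs))  # [1., 3., 6., 10., 15.]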
35 changes: 35 additions & 0 deletions jaxns/internals/cumulative_ops.py
@@ -2,13 +2,48 @@

import jax
from jax import lax, numpy as jnp, tree_util
from tensorflow_probability.substrates.jax import math as tfp_math

from jaxns.internals.types import IntArray, int_type

X = TypeVar('X')
V = TypeVar('V')
Y = TypeVar('Y')


def scan_associative_cumulative_op(op: Callable[[X, X], X], init: X, xs: X, pre_op: bool = False) -> Tuple[X, X]:
"""
Compute a cumulative operation on an array of values using scan_associative.
Args:
op: the operation to perform, must be associative.
init: the initial value.
xs: the array of values.
Returns:
the final accumulated value, and the result of the cumulative operation applied on input
"""

def associative_op(a, b):
return op(a, b)

# Prepare the input array by prepending the initial value
full_input = jax.tree.map(lambda x, y: jnp.concatenate([x[None], y], axis=0), init, xs)

# Apply the operation to accumulate results using scan_associative
scanned_results = tfp_math.scan_associative(associative_op, full_input)

# The final accumulated value is the last element in the results
final_accumulate = jax.tree.map(lambda x: x[-1], scanned_results)

if pre_op:
scanned_results = jax.tree.map(lambda x: x[:-1], scanned_results)
else:
scanned_results = jax.tree.map(lambda x: x[1:], scanned_results)

return final_accumulate, scanned_results
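
Usage sketch of the helper above (editorial example; the expected values follow from prepending init and slicing, as in the code):

import jax.numpy as jnp
from jaxns.internals.cumulative_ops import scan_associative_cumulative_op

init = jnp.asarray(0.)
xs = jnp.asarray([1., 2., 3.])

# Default: accumulation *after* each input.
final, result = scan_associative_cumulative_op(op=jnp.add, init=init, xs=xs)
# final == 6.0, result == [1., 3., 6.]

# pre_op=True shifts by one: the accumulated value *before* each input.
final, result = scan_associative_cumulative_op(op=jnp.add, init=init, xs=xs, pre_op=True)
# final == 6.0, result == [0., 1., 3.]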


def cumulative_op_static(op: Callable[[V, Y], V], init: V, xs: Y, pre_op: bool = False, unroll: int = 1) -> Tuple[
V, V]:
"""
9 changes: 5 additions & 4 deletions jaxns/internals/shrinkage_statistics.py
@@ -2,7 +2,7 @@

import jax.numpy as jnp

from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
from jaxns.internals.log_semiring import LogSpace
from jaxns.internals.tree_structure import SampleTreeGraph, count_crossed_edges
from jaxns.internals.types import MeasureType, EvidenceCalculation, float_type, IntArray, FloatArray
@@ -28,8 +28,8 @@ def op(log_X, num_live_points):
next_X_mean = X_mean * T_mean
return next_X_mean.log_abs_val

_, log_X = cumulative_op_static(op=op, init=jnp.asarray(-jnp.inf, float_type),
xs=live_point_counts.num_live_points)
_, log_X = scan_associative_cumulative_op(op=op, init=jnp.asarray(-jnp.inf, float_type),
xs=live_point_counts.num_live_points)
return log_X
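
The op above multiplies the running enclosed prior volume X by the per-step shrinkage in log space. Assuming T_mean is the expected shrinkage n_k / (n_k + 1) with n_k live points (defined outside this hunk), the cumulative result is

    log X_k = log X_{k-1} + log(n_k / (n_k + 1)) = sum_{i=1..k} log(n_i / (n_i + 1)),

a prefix sum in log space, which is the pattern an associative scan can evaluate in parallel.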


@@ -141,7 +141,8 @@ def compute_evidence_stats(log_L: MeasureType, num_live_points: FloatArray, num_
)
if num_samples is not None:
stop_idx = num_samples
final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs, stop_idx=stop_idx)
final_accumulate, result = cumulative_op_dynamic(op=_update_evidence_calc_op, init=init, xs=xs,
stop_idx=stop_idx)
else:
final_accumulate, result = cumulative_op_static(op=_update_evidence_calc_op, init=init, xs=xs)
final_evidence_calculation = final_accumulate
63 changes: 62 additions & 1 deletion jaxns/internals/tests/test_cumulative_ops.py
@@ -1,6 +1,9 @@
import jax
import numpy as np
import pytest
from jax import numpy as jnp

from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic, scan_associative_cumulative_op
from jaxns.internals.types import float_type, int_type


@@ -19,6 +22,64 @@ def op(accumulate, y):
assert all(result == jnp.asarray([0, 1, 3], float_type))


@pytest.mark.parametrize("binary_op", [jnp.add, jnp.multiply, jnp.minimum, jnp.maximum])
def test_scan_associative_cumulative_op(binary_op):
def op(accumulate, y):
return binary_op(accumulate, y)

init = jnp.asarray(1, float_type)
xs = jnp.arange(1, 11, dtype=float_type)
final_accumulate, result = scan_associative_cumulative_op(op=binary_op, init=init, xs=xs)
final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
assert final_accumulate == final_accumulate_expected
np.testing.assert_allclose(result, result_expected)

final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
assert final_accumulate == final_accumulate_expected
np.testing.assert_allclose(result, result_expected)


@pytest.mark.parametrize("binary_op", [jnp.subtract, jnp.true_divide])
def test_scan_associative_cumulative_op_non_associative(binary_op):
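    # Subtraction and division are not associative, e.g. (1 - 2) - 3 != 1 - (2 - 3),
    # so the parallel scan is expected to disagree with the sequential cumulative op.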
def op(accumulate, y):
return binary_op(accumulate, y)

init = jnp.asarray(1, float_type)
xs = jnp.arange(1, 11, dtype=float_type)
final_accumulate, result = scan_associative_cumulative_op(op=binary_op, init=init, xs=xs)
final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs)
with pytest.raises(AssertionError):
assert final_accumulate == final_accumulate_expected
with pytest.raises(AssertionError):
np.testing.assert_allclose(result, result_expected)

final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
final_accumulate_expected, result_expected = cumulative_op_static(op=op, init=init, xs=xs, pre_op=True)
with pytest.raises(AssertionError):
assert final_accumulate == final_accumulate_expected
with pytest.raises(AssertionError):
np.testing.assert_allclose(result, result_expected)


def test_scan_associative_cumulative_op_with_pytrees():
# Test with pytrees for xs and ys
def op(accumulate, y):
return jax.tree.map(lambda x, y: jnp.add(x, y), accumulate, y)

init = {'a': jnp.asarray(0, float_type), 'b': jnp.asarray(0, float_type)}
xs = {'a': jnp.asarray([1, 2, 3], float_type), 'b': jnp.asarray([4, 5, 6], float_type)}
final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs)
assert final_accumulate == {'a': 6, 'b': 15}
assert all(result['a'] == jnp.asarray([1, 3, 6], float_type))
assert all(result['b'] == jnp.asarray([4, 9, 15], float_type))

final_accumulate, result = scan_associative_cumulative_op(op=op, init=init, xs=xs, pre_op=True)
assert final_accumulate == {'a': 6, 'b': 15}
assert all(result['a'] == jnp.asarray([0, 1, 3], float_type))
assert all(result['b'] == jnp.asarray([0, 4, 9], float_type))


def test_cumulative_op_dynamic():
def op(accumulate, y):
return accumulate + y
34 changes: 19 additions & 15 deletions jaxns/internals/tests/test_tree_structure.py
@@ -10,6 +10,13 @@
from jaxns.internals.types import StaticStandardNestedSamplerState, StaticStandardSampleCollection


def pytree_assert_equal(a, b):
    for x, y in zip(jax.tree.leaves(a), jax.tree.leaves(b)):
        np.testing.assert_allclose(x, y)


def test_naive():
S = SampleTreeGraph(
sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
@@ -29,20 +36,18 @@ def test_basic():
sender_node_idx=jnp.asarray([0, 0, 0, 1, 2, 3]),
log_L=jnp.asarray([1, 2, 3, 4, 5, 6])
)
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
assert all(
jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
pytree_assert_equal(count_crossed_edges(S), count_old(S))
pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))

S = SampleTreeGraph(
sender_node_idx=jnp.asarray([0, 0, 0, 1, 3, 2]),
log_L=jnp.asarray([1, 2, 3, 4, 6, 5])
)

assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
assert all(
jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
pytree_assert_equal(count_crossed_edges(S), count_old(S))
pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))


def test_with_num_samples():
@@ -57,9 +62,9 @@ def test_with_num_samples():
log_L=jnp.asarray([1, 2, 3, 4, 5, 6, 7, 8])
)

assert all(jax.tree.map(lambda x, y: np.array_equal(x[:num_samples], y),
count_crossed_edges(S1, num_samples),
count_crossed_edges(S2)))
x = jax.tree.map(lambda x: x[:num_samples], count_crossed_edges(S1, num_samples))

pytree_assert_equal(x, count_crossed_edges(S2))

output = count_crossed_edges(S1, num_samples)
print(output)
@@ -89,10 +94,9 @@ def test_random_tree():

plot_tree(S)

assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_intervals_naive(S)))
assert all(jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_old(S)))
assert all(
jax.tree.map(lambda x, y: np.array_equal(x, y), count_crossed_edges(S), count_crossed_edges_less_fast(S)))
pytree_assert_equal(count_crossed_edges(S), count_intervals_naive(S))
pytree_assert_equal(count_crossed_edges(S), count_old(S))
pytree_assert_equal(count_crossed_edges(S), count_crossed_edges_less_fast(S))

T = count_crossed_edges_less_fast(S)
import pylab as plt
2 changes: 1 addition & 1 deletion jaxns/internals/tree_structure.py
@@ -4,7 +4,7 @@
from jax import numpy as jnp, lax, core
from jax._src.numpy import lax_numpy

from jaxns.internals.cumulative_ops import cumulative_op_static, cumulative_op_dynamic
from jaxns.internals.cumulative_ops import cumulative_op_dynamic, scan_associative_cumulative_op, cumulative_op_static
from jaxns.internals.maps import remove_chunk_dim
from jaxns.internals.types import MeasureType, IntArray, float_type, FloatArray, StaticStandardNestedSamplerState, \
int_type
3 changes: 1 addition & 2 deletions jaxns/samplers/multi_slice_sampler.py
@@ -258,8 +258,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
final_sample, cumulative_samples = cumulative_op_static(
op=propose_op,
init=init_sample,
xs=random.split(key, self.num_slices),
unroll=2
xs=random.split(key, self.num_slices)
)

# Last sample is the final sample, the rest are potential phantom samples
19 changes: 9 additions & 10 deletions jaxns/samplers/uni_slice_sampler.py
@@ -83,8 +83,7 @@ def _pick_point_in_interval(key: PRNGKey, point_U0: FloatArray, direction: Float
return point_U, t


def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray, log_L_proposal: FloatArray,
log_L_constraint: FloatArray, log_L0: FloatArray,
def _shrink_interval(key: PRNGKey, t: FloatArray, left: FloatArray, right: FloatArray,
midpoint_shrink: bool) -> Tuple[FloatArray, FloatArray]:
"""
Not successful proposal, so shrink, optionally apply exponential shrinkage.
@@ -151,9 +150,6 @@ def body(carry: Carry) -> Carry:
t=carry.t,
left=carry.left,
right=carry.right,
log_L_proposal=carry.log_L,
log_L_constraint=log_L_constraint,
log_L0=seed_point.log_L0,
midpoint_shrink=midpoint_shrink
)
point_U, t = _pick_point_in_interval(
@@ -232,11 +228,11 @@ def body(carry: Carry) -> Carry:

class UniDimSliceSampler(BaseAbstractMarkovSampler):
"""
Slice sampler for a single dimension. Produces correlated (non-i.i.d.) samples.
Slice sampler for a single dimension. Produces correlated samples.
"""

def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save: int, midpoint_shrink: bool,
perfect: bool, gradient_slice: bool = False):
perfect: bool, gradient_slice: bool = False, adaptive_shrink: bool = False):
"""
Unidimensional slice sampler.
@@ -250,7 +246,8 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save:
perfect: if true then perform exponential shrinkage from maximal bounds, requiring no step-out procedure.
            Otherwise, uses a doubling procedure (exponentially expanding the bracket).
            Note: Perfect is a misnomer, as perfection also depends on the number of slices between acceptances.
gradient_slice: if true then always slice along gradient direction.
        gradient_slice: if true then always slice along the direction of increasing gradient.
        adaptive_shrink: if true then shrink the interval to a random point in the interval, rather than the midpoint.
"""
super().__init__(model=model)
if num_slices < 1:
Expand All @@ -264,6 +261,9 @@ def __init__(self, model: BaseAbstractModel, num_slices: int, num_phantom_save:
self.midpoint_shrink = bool(midpoint_shrink)
self.perfect = bool(perfect)
self.gradient_slice = bool(gradient_slice)
self.adaptive_shrink = bool(adaptive_shrink)
if self.adaptive_shrink:
raise NotImplementedError("Adaptive shrinkage not implemented.")
if not self.perfect:
raise ValueError("Only perfect slice sampler is implemented.")

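A hypothetical construction under the new signature (editorial sketch; model is assumed to be some BaseAbstractModel instance, and the argument constraints follow from the checks above):

from jaxns.samplers.uni_slice_sampler import UniDimSliceSampler

sampler = UniDimSliceSampler(
    model=model,            # assumed: a BaseAbstractModel instance
    num_slices=25,          # assumed value; must be >= 1
    num_phantom_save=0,
    midpoint_shrink=True,
    perfect=True,           # required: only the perfect variant is implemented
    gradient_slice=False,
    adaptive_shrink=False,  # True currently raises NotImplementedError
)
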
@@ -341,8 +341,7 @@ def propose_op(sample: Sample, key: PRNGKey) -> Sample:
final_sample, cumulative_samples = cumulative_op_static(
op=propose_op,
init=init_sample,
xs=random.split(key, self.num_slices),
unroll=2
xs=random.split(key, self.num_slices)
)

# Last sample is the final sample, the rest are potential phantom samples
