Skip to content

Commit

Permalink
Refactor links [no ci]
Browse files Browse the repository at this point in the history
* rename link_functions to links
* separated activation function Ordered from OrderedQuantiles, one for
  generality, the other for automatic smart anchor selection based on
  quantile levels
* introduce link function for learnable positive semi-definite matrices
* full test coverage for links module
  • Loading branch information
han-ol committed Jan 28, 2025
1 parent ecef8ed commit ab821bf
Show file tree
Hide file tree
Showing 10 changed files with 253 additions and 64 deletions.
1 change: 0 additions & 1 deletion bayesflow/link_functions/__init__.py

This file was deleted.

62 changes: 0 additions & 62 deletions bayesflow/link_functions/ordered_quantiles.py

This file was deleted.

3 changes: 3 additions & 0 deletions bayesflow/links/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from .ordered import Ordered
from .ordered_quantiles import OrderedQuantiles
from .positive_semi_definite import PositiveSemiDefinite
43 changes: 43 additions & 0 deletions bayesflow/links/ordered.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
import keras

from bayesflow.utils import keras_kwargs


class Ordered(keras.Layer):
def __init__(self, axis: int, anchor_index: int, **kwargs):
super().__init__(**keras_kwargs(kwargs))
self.axis = axis
self.anchor_index = anchor_index

def build(self, input_shape):
super().build(input_shape)
print("build Ordered()")

assert (
self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1
), "anchor should not be first or last index."
self.group_indeces = dict(
below=list(range(0, self.anchor_index)),
above=list(range(self.anchor_index + 1, input_shape[self.axis])),
)

def call(self, inputs):
# Divide in anchor, below and above
below_inputs = keras.ops.take(inputs, self.group_indeces["below"], axis=self.axis)
anchor_input = keras.ops.take(inputs, self.anchor_index, axis=self.axis)
anchor_input = keras.ops.expand_dims(anchor_input, axis=self.axis)
above_inputs = keras.ops.take(inputs, self.group_indeces["above"], axis=self.axis)

# Apply softplus for positivity and cumulate to ensure ordered quantiles
below = keras.activations.softplus(below_inputs)
above = keras.activations.softplus(above_inputs)

below = anchor_input - keras.ops.flip(keras.ops.cumsum(below, axis=self.axis), self.axis)
above = anchor_input + keras.ops.cumsum(above, axis=self.axis)

# Concatenate and reshape back
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
return x

def compute_output_shape(self, input_shape):
return input_shape
32 changes: 32 additions & 0 deletions bayesflow/links/ordered_quantiles.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import keras

from bayesflow.utils import keras_kwargs

from collections.abc import Sequence

from .ordered import Ordered


class OrderedQuantiles(Ordered):
def __init__(self, q: Sequence[float] = None, axis: int = None, **kwargs):
super().__init__(axis, None, **keras_kwargs(kwargs))
self.q = q

def build(self, input_shape):
if self.axis is None and 1 < len(input_shape) <= 3:
self.axis = -2
elif self.axis is None:
raise AssertionError(
f"Cannot resolve which axis should be ordered automatically from input shape {input_shape}."
)

if self.q is None:
# choose the middle of the specified axis as anchor index
num_quantile_levels = input_shape[self.axis]
self.anchor_index = num_quantile_levels // 2
else:
# choose quantile level closest to median as anchor index
self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5))
assert input_shape[self.axis] == len(self.q)

super().build(input_shape)
14 changes: 14 additions & 0 deletions bayesflow/links/positive_semi_definite.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import keras

from bayesflow.utils import keras_kwargs


class PositiveSemiDefinite(keras.Layer):
def __init__(self, **kwargs):
super().__init__(**keras_kwargs(kwargs))

def call(self, inputs):
return keras.ops.einsum("...ij,...kj->...ik", inputs, inputs)

def compute_output_shape(self, input_shape):
return input_shape
Empty file added tests/test_links/__init__.py
Empty file.
89 changes: 89 additions & 0 deletions tests/test_links/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
import numpy as np
import keras
import pytest


@pytest.fixture()
def batch_size():
return 16


@pytest.fixture()
def num_variables():
return 10


@pytest.fixture()
def generic_preactivation(batch_size):
return keras.ops.ones((batch_size, 4, 4))


@pytest.fixture()
def ordered():
from bayesflow.links import Ordered

return Ordered(axis=1, anchor_index=2)


@pytest.fixture()
def ordered_quantiles():
from bayesflow.links import OrderedQuantiles

return OrderedQuantiles()


@pytest.fixture()
def positive_semi_definite():
from bayesflow.links import PositiveSemiDefinite

return PositiveSemiDefinite()


@pytest.fixture()
def linear():
return keras.layers.Activation("linear")


@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_semi_definite", "linear"], scope="function")
def link(request):
return request.getfixturevalue(request.param)


@pytest.fixture()
def num_quantiles():
return 19


@pytest.fixture()
def quantiles_np(num_quantiles):
return np.linspace(0, 1, num_quantiles + 2)[1:-1]


@pytest.fixture()
def quantiles_py(quantiles_np):
return list(quantiles_np)


@pytest.fixture()
def quantiles_keras(quantiles_np):
return keras.ops.convert_to_tensor(quantiles_np)


@pytest.fixture()
def none():
return None


@pytest.fixture(params=["quantiles_np", "quantiles_py", "quantiles_keras", "none"], scope="function")
def quantiles(request):
return request.getfixturevalue(request.param)


@pytest.fixture()
def unordered(batch_size, num_quantiles, num_variables):
return keras.random.normal((batch_size, num_quantiles, num_variables))


@pytest.fixture()
def random_matrix_batch(batch_size, num_variables):
return keras.random.normal((batch_size, num_variables, num_variables))
71 changes: 71 additions & 0 deletions tests/test_links/test_links.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import numpy as np
import pytest


def test_link_output(link, generic_preactivation):
output_shape = link.compute_output_shape(generic_preactivation.shape)
output = link(generic_preactivation)

assert output_shape == output.shape


def test_invalid_shape_for_ordered_quantiles(ordered_quantiles, batch_size, num_quantiles, num_variables):
with pytest.raises(AssertionError) as excinfo:
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables))

assert "resolve which axis should be ordered automatically" in str(excinfo)


@pytest.mark.parametrize("axis", [1, 2])
def test_invalid_shape_for_ordered_quantiles_with_specified_axis(
ordered_quantiles, axis, batch_size, num_quantiles, num_variables
):
ordered_quantiles.axis = axis
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables))


def check_ordering(output, axis):
assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}."
for i in range(output.ndim):
if i != axis % output.ndim:
assert not np.all(
np.diff(output, axis=i) > 0
), f"is ordered along axis which is not meant to be ordered: {i}."


@pytest.mark.parametrize("axis", [0, 1, 2])
def test_ordering(axis, unordered):
from bayesflow.links import Ordered

activation = Ordered(axis=axis, anchor_index=5)

output = activation(unordered)

check_ordering(output, axis)


def test_quantile_ordering(quantiles, unordered):
from bayesflow.links import OrderedQuantiles

activation = OrderedQuantiles(q=quantiles)

activation.build(unordered.shape)
axis = activation.axis

output = activation(unordered)

check_ordering(output, axis)


def test_positive_semi_definite(random_matrix_batch):
from bayesflow.links import PositiveSemiDefinite

activation = PositiveSemiDefinite()

output = activation(random_matrix_batch)

eigenvalues = np.linalg.eig(output).eigenvalues

assert np.all(eigenvalues.real > 0) and np.all(
np.isclose(eigenvalues.imag, 0)
), "output is not positive semi-definite."
2 changes: 1 addition & 1 deletion tests/test_scores/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def batch_size():

@pytest.fixture()
def num_variables():
return 4
return 10


@pytest.fixture()
Expand Down

0 comments on commit ab821bf

Please sign in to comment.