-
Notifications
You must be signed in to change notification settings - Fork 59
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* rename link_functions to links * separated activation function Ordered from OrderedQuantiles, one for generality, the other for automatic smart anchor selection based on quantile levels * introduce link function for learnable positive semi-definite matrices * full test coverage for links module
- Loading branch information
Showing
10 changed files
with
253 additions
and
64 deletions.
There are no files selected for viewing
This file was deleted.
Oops, something went wrong.
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .ordered import Ordered | ||
from .ordered_quantiles import OrderedQuantiles | ||
from .positive_semi_definite import PositiveSemiDefinite |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
import keras | ||
|
||
from bayesflow.utils import keras_kwargs | ||
|
||
|
||
class Ordered(keras.Layer): | ||
def __init__(self, axis: int, anchor_index: int, **kwargs): | ||
super().__init__(**keras_kwargs(kwargs)) | ||
self.axis = axis | ||
self.anchor_index = anchor_index | ||
|
||
def build(self, input_shape): | ||
super().build(input_shape) | ||
print("build Ordered()") | ||
|
||
assert ( | ||
self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1 | ||
), "anchor should not be first or last index." | ||
self.group_indeces = dict( | ||
below=list(range(0, self.anchor_index)), | ||
above=list(range(self.anchor_index + 1, input_shape[self.axis])), | ||
) | ||
|
||
def call(self, inputs): | ||
# Divide in anchor, below and above | ||
below_inputs = keras.ops.take(inputs, self.group_indeces["below"], axis=self.axis) | ||
anchor_input = keras.ops.take(inputs, self.anchor_index, axis=self.axis) | ||
anchor_input = keras.ops.expand_dims(anchor_input, axis=self.axis) | ||
above_inputs = keras.ops.take(inputs, self.group_indeces["above"], axis=self.axis) | ||
|
||
# Apply softplus for positivity and cumulate to ensure ordered quantiles | ||
below = keras.activations.softplus(below_inputs) | ||
above = keras.activations.softplus(above_inputs) | ||
|
||
below = anchor_input - keras.ops.flip(keras.ops.cumsum(below, axis=self.axis), self.axis) | ||
above = anchor_input + keras.ops.cumsum(above, axis=self.axis) | ||
|
||
# Concatenate and reshape back | ||
x = keras.ops.concatenate([below, anchor_input, above], self.axis) | ||
return x | ||
|
||
def compute_output_shape(self, input_shape): | ||
return input_shape |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import keras | ||
|
||
from bayesflow.utils import keras_kwargs | ||
|
||
from collections.abc import Sequence | ||
|
||
from .ordered import Ordered | ||
|
||
|
||
class OrderedQuantiles(Ordered): | ||
def __init__(self, q: Sequence[float] = None, axis: int = None, **kwargs): | ||
super().__init__(axis, None, **keras_kwargs(kwargs)) | ||
self.q = q | ||
|
||
def build(self, input_shape): | ||
if self.axis is None and 1 < len(input_shape) <= 3: | ||
self.axis = -2 | ||
elif self.axis is None: | ||
raise AssertionError( | ||
f"Cannot resolve which axis should be ordered automatically from input shape {input_shape}." | ||
) | ||
|
||
if self.q is None: | ||
# choose the middle of the specified axis as anchor index | ||
num_quantile_levels = input_shape[self.axis] | ||
self.anchor_index = num_quantile_levels // 2 | ||
else: | ||
# choose quantile level closest to median as anchor index | ||
self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5)) | ||
assert input_shape[self.axis] == len(self.q) | ||
|
||
super().build(input_shape) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
import keras | ||
|
||
from bayesflow.utils import keras_kwargs | ||
|
||
|
||
class PositiveSemiDefinite(keras.Layer): | ||
def __init__(self, **kwargs): | ||
super().__init__(**keras_kwargs(kwargs)) | ||
|
||
def call(self, inputs): | ||
return keras.ops.einsum("...ij,...kj->...ik", inputs, inputs) | ||
|
||
def compute_output_shape(self, input_shape): | ||
return input_shape |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,89 @@ | ||
import numpy as np | ||
import keras | ||
import pytest | ||
|
||
|
||
@pytest.fixture() | ||
def batch_size(): | ||
return 16 | ||
|
||
|
||
@pytest.fixture() | ||
def num_variables(): | ||
return 10 | ||
|
||
|
||
@pytest.fixture() | ||
def generic_preactivation(batch_size): | ||
return keras.ops.ones((batch_size, 4, 4)) | ||
|
||
|
||
@pytest.fixture() | ||
def ordered(): | ||
from bayesflow.links import Ordered | ||
|
||
return Ordered(axis=1, anchor_index=2) | ||
|
||
|
||
@pytest.fixture() | ||
def ordered_quantiles(): | ||
from bayesflow.links import OrderedQuantiles | ||
|
||
return OrderedQuantiles() | ||
|
||
|
||
@pytest.fixture() | ||
def positive_semi_definite(): | ||
from bayesflow.links import PositiveSemiDefinite | ||
|
||
return PositiveSemiDefinite() | ||
|
||
|
||
@pytest.fixture() | ||
def linear(): | ||
return keras.layers.Activation("linear") | ||
|
||
|
||
@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_semi_definite", "linear"], scope="function") | ||
def link(request): | ||
return request.getfixturevalue(request.param) | ||
|
||
|
||
@pytest.fixture() | ||
def num_quantiles(): | ||
return 19 | ||
|
||
|
||
@pytest.fixture() | ||
def quantiles_np(num_quantiles): | ||
return np.linspace(0, 1, num_quantiles + 2)[1:-1] | ||
|
||
|
||
@pytest.fixture() | ||
def quantiles_py(quantiles_np): | ||
return list(quantiles_np) | ||
|
||
|
||
@pytest.fixture() | ||
def quantiles_keras(quantiles_np): | ||
return keras.ops.convert_to_tensor(quantiles_np) | ||
|
||
|
||
@pytest.fixture() | ||
def none(): | ||
return None | ||
|
||
|
||
@pytest.fixture(params=["quantiles_np", "quantiles_py", "quantiles_keras", "none"], scope="function") | ||
def quantiles(request): | ||
return request.getfixturevalue(request.param) | ||
|
||
|
||
@pytest.fixture() | ||
def unordered(batch_size, num_quantiles, num_variables): | ||
return keras.random.normal((batch_size, num_quantiles, num_variables)) | ||
|
||
|
||
@pytest.fixture() | ||
def random_matrix_batch(batch_size, num_variables): | ||
return keras.random.normal((batch_size, num_variables, num_variables)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,71 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
|
||
def test_link_output(link, generic_preactivation): | ||
output_shape = link.compute_output_shape(generic_preactivation.shape) | ||
output = link(generic_preactivation) | ||
|
||
assert output_shape == output.shape | ||
|
||
|
||
def test_invalid_shape_for_ordered_quantiles(ordered_quantiles, batch_size, num_quantiles, num_variables): | ||
with pytest.raises(AssertionError) as excinfo: | ||
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables)) | ||
|
||
assert "resolve which axis should be ordered automatically" in str(excinfo) | ||
|
||
|
||
@pytest.mark.parametrize("axis", [1, 2]) | ||
def test_invalid_shape_for_ordered_quantiles_with_specified_axis( | ||
ordered_quantiles, axis, batch_size, num_quantiles, num_variables | ||
): | ||
ordered_quantiles.axis = axis | ||
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables)) | ||
|
||
|
||
def check_ordering(output, axis): | ||
assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}." | ||
for i in range(output.ndim): | ||
if i != axis % output.ndim: | ||
assert not np.all( | ||
np.diff(output, axis=i) > 0 | ||
), f"is ordered along axis which is not meant to be ordered: {i}." | ||
|
||
|
||
@pytest.mark.parametrize("axis", [0, 1, 2]) | ||
def test_ordering(axis, unordered): | ||
from bayesflow.links import Ordered | ||
|
||
activation = Ordered(axis=axis, anchor_index=5) | ||
|
||
output = activation(unordered) | ||
|
||
check_ordering(output, axis) | ||
|
||
|
||
def test_quantile_ordering(quantiles, unordered): | ||
from bayesflow.links import OrderedQuantiles | ||
|
||
activation = OrderedQuantiles(q=quantiles) | ||
|
||
activation.build(unordered.shape) | ||
axis = activation.axis | ||
|
||
output = activation(unordered) | ||
|
||
check_ordering(output, axis) | ||
|
||
|
||
def test_positive_semi_definite(random_matrix_batch): | ||
from bayesflow.links import PositiveSemiDefinite | ||
|
||
activation = PositiveSemiDefinite() | ||
|
||
output = activation(random_matrix_batch) | ||
|
||
eigenvalues = np.linalg.eig(output).eigenvalues | ||
|
||
assert np.all(eigenvalues.real > 0) and np.all( | ||
np.isclose(eigenvalues.imag, 0) | ||
), "output is not positive semi-definite." |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -9,7 +9,7 @@ def batch_size(): | |
|
||
@pytest.fixture() | ||
def num_variables(): | ||
return 4 | ||
return 10 | ||
|
||
|
||
@pytest.fixture() | ||
|