Skip to content

Commit ab821bf

Browse files
committed
Refactor links [no ci]
* rename link_functions to links * separated activation function Ordered from OrderedQuantiles, one for generality, the other for automatic smart anchor selection based on quantile levels * introduce link function for learnable positive semi-definite matrices * full test coverage for links module
1 parent ecef8ed commit ab821bf

File tree

10 files changed

+253
-64
lines changed

10 files changed

+253
-64
lines changed

bayesflow/link_functions/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

bayesflow/link_functions/ordered_quantiles.py

Lines changed: 0 additions & 62 deletions
This file was deleted.

bayesflow/links/__init__.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
from .ordered import Ordered
2+
from .ordered_quantiles import OrderedQuantiles
3+
from .positive_semi_definite import PositiveSemiDefinite

bayesflow/links/ordered.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
6+
class Ordered(keras.Layer):
7+
def __init__(self, axis: int, anchor_index: int, **kwargs):
8+
super().__init__(**keras_kwargs(kwargs))
9+
self.axis = axis
10+
self.anchor_index = anchor_index
11+
12+
def build(self, input_shape):
13+
super().build(input_shape)
14+
print("build Ordered()")
15+
16+
assert (
17+
self.anchor_index % input_shape[self.axis] != 0 and self.anchor_index != -1
18+
), "anchor should not be first or last index."
19+
self.group_indeces = dict(
20+
below=list(range(0, self.anchor_index)),
21+
above=list(range(self.anchor_index + 1, input_shape[self.axis])),
22+
)
23+
24+
def call(self, inputs):
25+
# Divide in anchor, below and above
26+
below_inputs = keras.ops.take(inputs, self.group_indeces["below"], axis=self.axis)
27+
anchor_input = keras.ops.take(inputs, self.anchor_index, axis=self.axis)
28+
anchor_input = keras.ops.expand_dims(anchor_input, axis=self.axis)
29+
above_inputs = keras.ops.take(inputs, self.group_indeces["above"], axis=self.axis)
30+
31+
# Apply softplus for positivity and cumulate to ensure ordered quantiles
32+
below = keras.activations.softplus(below_inputs)
33+
above = keras.activations.softplus(above_inputs)
34+
35+
below = anchor_input - keras.ops.flip(keras.ops.cumsum(below, axis=self.axis), self.axis)
36+
above = anchor_input + keras.ops.cumsum(above, axis=self.axis)
37+
38+
# Concatenate and reshape back
39+
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
40+
return x
41+
42+
def compute_output_shape(self, input_shape):
43+
return input_shape

bayesflow/links/ordered_quantiles.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
from collections.abc import Sequence
6+
7+
from .ordered import Ordered
8+
9+
10+
class OrderedQuantiles(Ordered):
11+
def __init__(self, q: Sequence[float] = None, axis: int = None, **kwargs):
12+
super().__init__(axis, None, **keras_kwargs(kwargs))
13+
self.q = q
14+
15+
def build(self, input_shape):
16+
if self.axis is None and 1 < len(input_shape) <= 3:
17+
self.axis = -2
18+
elif self.axis is None:
19+
raise AssertionError(
20+
f"Cannot resolve which axis should be ordered automatically from input shape {input_shape}."
21+
)
22+
23+
if self.q is None:
24+
# choose the middle of the specified axis as anchor index
25+
num_quantile_levels = input_shape[self.axis]
26+
self.anchor_index = num_quantile_levels // 2
27+
else:
28+
# choose quantile level closest to median as anchor index
29+
self.anchor_index = keras.ops.argmin(keras.ops.abs(keras.ops.convert_to_tensor(self.q) - 0.5))
30+
assert input_shape[self.axis] == len(self.q)
31+
32+
super().build(input_shape)
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
6+
class PositiveSemiDefinite(keras.Layer):
7+
def __init__(self, **kwargs):
8+
super().__init__(**keras_kwargs(kwargs))
9+
10+
def call(self, inputs):
11+
return keras.ops.einsum("...ij,...kj->...ik", inputs, inputs)
12+
13+
def compute_output_shape(self, input_shape):
14+
return input_shape

tests/test_links/__init__.py

Whitespace-only changes.

tests/test_links/conftest.py

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,89 @@
1+
import numpy as np
2+
import keras
3+
import pytest
4+
5+
6+
@pytest.fixture()
7+
def batch_size():
8+
return 16
9+
10+
11+
@pytest.fixture()
12+
def num_variables():
13+
return 10
14+
15+
16+
@pytest.fixture()
17+
def generic_preactivation(batch_size):
18+
return keras.ops.ones((batch_size, 4, 4))
19+
20+
21+
@pytest.fixture()
22+
def ordered():
23+
from bayesflow.links import Ordered
24+
25+
return Ordered(axis=1, anchor_index=2)
26+
27+
28+
@pytest.fixture()
29+
def ordered_quantiles():
30+
from bayesflow.links import OrderedQuantiles
31+
32+
return OrderedQuantiles()
33+
34+
35+
@pytest.fixture()
36+
def positive_semi_definite():
37+
from bayesflow.links import PositiveSemiDefinite
38+
39+
return PositiveSemiDefinite()
40+
41+
42+
@pytest.fixture()
43+
def linear():
44+
return keras.layers.Activation("linear")
45+
46+
47+
@pytest.fixture(params=["ordered", "ordered_quantiles", "positive_semi_definite", "linear"], scope="function")
48+
def link(request):
49+
return request.getfixturevalue(request.param)
50+
51+
52+
@pytest.fixture()
53+
def num_quantiles():
54+
return 19
55+
56+
57+
@pytest.fixture()
58+
def quantiles_np(num_quantiles):
59+
return np.linspace(0, 1, num_quantiles + 2)[1:-1]
60+
61+
62+
@pytest.fixture()
63+
def quantiles_py(quantiles_np):
64+
return list(quantiles_np)
65+
66+
67+
@pytest.fixture()
68+
def quantiles_keras(quantiles_np):
69+
return keras.ops.convert_to_tensor(quantiles_np)
70+
71+
72+
@pytest.fixture()
73+
def none():
74+
return None
75+
76+
77+
@pytest.fixture(params=["quantiles_np", "quantiles_py", "quantiles_keras", "none"], scope="function")
78+
def quantiles(request):
79+
return request.getfixturevalue(request.param)
80+
81+
82+
@pytest.fixture()
83+
def unordered(batch_size, num_quantiles, num_variables):
84+
return keras.random.normal((batch_size, num_quantiles, num_variables))
85+
86+
87+
@pytest.fixture()
88+
def random_matrix_batch(batch_size, num_variables):
89+
return keras.random.normal((batch_size, num_variables, num_variables))

tests/test_links/test_links.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
import numpy as np
2+
import pytest
3+
4+
5+
def test_link_output(link, generic_preactivation):
6+
output_shape = link.compute_output_shape(generic_preactivation.shape)
7+
output = link(generic_preactivation)
8+
9+
assert output_shape == output.shape
10+
11+
12+
def test_invalid_shape_for_ordered_quantiles(ordered_quantiles, batch_size, num_quantiles, num_variables):
13+
with pytest.raises(AssertionError) as excinfo:
14+
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables))
15+
16+
assert "resolve which axis should be ordered automatically" in str(excinfo)
17+
18+
19+
@pytest.mark.parametrize("axis", [1, 2])
20+
def test_invalid_shape_for_ordered_quantiles_with_specified_axis(
21+
ordered_quantiles, axis, batch_size, num_quantiles, num_variables
22+
):
23+
ordered_quantiles.axis = axis
24+
ordered_quantiles.build((batch_size, batch_size, num_quantiles, num_variables))
25+
26+
27+
def check_ordering(output, axis):
28+
assert np.all(np.diff(output, axis=axis) > 0), f"is not ordered along specified axis: {axis}."
29+
for i in range(output.ndim):
30+
if i != axis % output.ndim:
31+
assert not np.all(
32+
np.diff(output, axis=i) > 0
33+
), f"is ordered along axis which is not meant to be ordered: {i}."
34+
35+
36+
@pytest.mark.parametrize("axis", [0, 1, 2])
37+
def test_ordering(axis, unordered):
38+
from bayesflow.links import Ordered
39+
40+
activation = Ordered(axis=axis, anchor_index=5)
41+
42+
output = activation(unordered)
43+
44+
check_ordering(output, axis)
45+
46+
47+
def test_quantile_ordering(quantiles, unordered):
48+
from bayesflow.links import OrderedQuantiles
49+
50+
activation = OrderedQuantiles(q=quantiles)
51+
52+
activation.build(unordered.shape)
53+
axis = activation.axis
54+
55+
output = activation(unordered)
56+
57+
check_ordering(output, axis)
58+
59+
60+
def test_positive_semi_definite(random_matrix_batch):
61+
from bayesflow.links import PositiveSemiDefinite
62+
63+
activation = PositiveSemiDefinite()
64+
65+
output = activation(random_matrix_batch)
66+
67+
eigenvalues = np.linalg.eig(output).eigenvalues
68+
69+
assert np.all(eigenvalues.real > 0) and np.all(
70+
np.isclose(eigenvalues.imag, 0)
71+
), "output is not positive semi-definite."

tests/test_scores/conftest.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ def batch_size():
99

1010
@pytest.fixture()
1111
def num_variables():
12-
return 4
12+
return 10
1313

1414

1515
@pytest.fixture()

0 commit comments

Comments
 (0)