Skip to content

Commit ce07855

Browse files
committed
Add activation function for quantile estimation
1 parent 98437ae commit ce07855

File tree

3 files changed

+79
-11
lines changed

3 files changed

+79
-11
lines changed

bayesflow/link_functions/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .ordered_quantiles import OrderedQuantiles
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
import keras
2+
3+
from bayesflow.utils import keras_kwargs
4+
5+
from collections.abc import Sequence
6+
7+
8+
class OrderedQuantiles(keras.Layer):
9+
def __init__(self, quantile_levels: Sequence[float] = None, axis: int = None, **kwargs):
10+
super().__init__(**keras_kwargs(kwargs))
11+
self.quantile_levels = quantile_levels
12+
self.axis = axis
13+
14+
def build(self, input_shape):
15+
super().build(input_shape)
16+
if 1 < len(input_shape) <= 3:
17+
self.axis = -2
18+
if self.quantile_levels is not None:
19+
num_quantile_levels = len(self.quantile_levels)
20+
# choose quantile level closest to median as anchor
21+
self.anchor_quantile_index = keras.ops.argmin(
22+
keras.ops.abs(keras.ops.convert_to_tensor(self.quantile_levels) - 0.5)
23+
)
24+
else:
25+
num_quantile_levels = input_shape[self.axis]
26+
self.anchor_quantile_index = num_quantile_levels // 2
27+
28+
self.group_indeces = dict(
29+
below=list(range(0, self.anchor_quantile_index)),
30+
above=list(range(self.anchor_quantile_index + 1, num_quantile_levels)),
31+
)
32+
else:
33+
raise AssertionError(
34+
"Cannot resolve which axis should be ordered automatically from input shape " + str(input_shape)
35+
)
36+
37+
def call(self, inputs):
38+
# Divide in anchor, below and above
39+
below_inputs = keras.ops.take(inputs, self.group_indeces["below"], axis=self.axis)
40+
anchor_input = keras.ops.take(inputs, self.anchor_quantile_index, axis=self.axis)
41+
above_inputs = keras.ops.take(inputs, self.group_indeces["above"], axis=self.axis)
42+
43+
# prepare a reshape target to aid broadcasting correctly
44+
broadcast_shape = list(below_inputs.shape) # convert to list to allow item assignment
45+
broadcast_shape[self.axis] = 1
46+
broadcast_shape = tuple(broadcast_shape)
47+
48+
anchor_input = keras.ops.reshape(anchor_input, broadcast_shape)
49+
50+
# Apply softplus for positivity and cumulate to ensure ordered quantiles
51+
below = keras.activations.softplus(below_inputs)
52+
above = keras.activations.softplus(above_inputs)
53+
54+
below = anchor_input - keras.ops.flip(keras.ops.cumsum(below, axis=self.axis), self.axis)
55+
above = anchor_input + keras.ops.cumsum(above, axis=self.axis)
56+
57+
# Concatenate and reshape back
58+
x = keras.ops.concatenate([below, anchor_input, above], self.axis)
59+
return x
60+
61+
def compute_output_shape(self, input_shape):
62+
return input_shape

bayesflow/networks/point_inference_network.py

Lines changed: 16 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def __init__(
2121
scoring_rules: dict[str, ScoringRule],
2222
body_subnet: str | type = "mlp", # naming: shared_subnet / body / subnet ?
2323
heads_subnet: dict[str, str | keras.Layer] = None, # TODO: `type` instead of `keras.Layer` ? Too specific ?
24-
activations: dict[str, keras.layers.Activation | Callable | str] = None,
24+
activations: dict[str, keras.Layer | Callable | str] = None,
2525
**kwargs,
2626
):
2727
super().__init__(
@@ -36,17 +36,17 @@ def __init__(
3636

3737
self.body_subnet = find_network(body_subnet, **kwargs.get("body_subnet_kwargs", {}))
3838

39-
if heads_subnet:
39+
if heads_subnet is not None:
4040
self.heads = {
4141
key: [find_network(value, **kwargs.get("heads_subnet_kwargs", {}).get(key, {}))]
4242
for key, value in heads_subnet.items()
4343
}
4444
else:
4545
self.heads = {key: [] for key in self.scoring_rules.keys()}
4646

47-
if activations:
47+
if activations is not None:
4848
self.activations = {
49-
key: (value if isinstance(value, keras.layers.Activation) else keras.layers.Activation(value))
49+
key: (value if isinstance(value, keras.Layer) else keras.layers.Activation(value))
5050
for key, value in activations.items()
5151
} # make sure that each value is an Activation object
5252
else:
@@ -64,16 +64,16 @@ def __init__(
6464

6565
assert set(self.scoring_rules.keys()) == set(self.heads.keys()) == set(self.activations.keys())
6666

67-
def build(self, xz_shape: Shape, conditions_shape: Shape = None) -> None:
67+
def build(self, xz_shape: Shape, conditions_shape: Shape) -> None:
6868
# build the shared body network
6969
input_shape = conditions_shape
7070
self.body_subnet.build(input_shape)
7171
body_output_shape = self.body_subnet.compute_output_shape(input_shape)
7272

7373
for key in self.heads.keys():
74-
# head_output_shape (excluding batch_size) convention is (*prediction_shape, *parameter_block_shape)
75-
prediction_shape = self.scoring_rules[key].prediction_shape
76-
head_output_shape = prediction_shape + xz_shape[1:]
74+
# head_output_shape (excluding batch_size) convention is (*target_shape, *parameter_block_shape)
75+
target_shape = self.scoring_rules[key].target_shape
76+
head_output_shape = target_shape + xz_shape[1:]
7777

7878
# set correct head shape
7979
self.heads[key][-3].units = prod(head_output_shape)
@@ -91,13 +91,18 @@ def call(
9191
conditions: Tensor = None,
9292
training: bool = False,
9393
**kwargs,
94-
) -> Tensor | tuple[Tensor, Tensor]:
94+
) -> dict[str, Tensor]:
9595
# TODO: remove unnecessary simularity with InferenceNetwork
9696
return self._forward(xz, conditions=conditions, training=training, **kwargs)
9797

9898
def _forward(
99-
self, x: Tensor, conditions: Tensor = None, training: bool = False, **kwargs
100-
) -> Tensor | tuple[Tensor, Tensor]:
99+
self,
100+
x: Tensor,
101+
conditions: Tensor = None,
102+
training: bool = False,
103+
**kwargs,
104+
# TODO: propagate training flag
105+
) -> dict[str, Tensor]:
101106
body_output = self.body_subnet(conditions)
102107

103108
output = dict()

0 commit comments

Comments
 (0)