Skip to content

Commit

Permalink
add compute_output_shape to mlp
Browse files Browse the repository at this point in the history
  • Loading branch information
LarsKue committed Jul 11, 2024
1 parent d8428c8 commit 4b629e4
Showing 1 changed file with 6 additions and 0 deletions.
6 changes: 6 additions & 0 deletions bayesflow/networks/mlp/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,3 +81,9 @@ def call(self, x: Tensor, **kwargs) -> Tensor:
for layer in self.res_blocks:
x = layer(x, training=kwargs.get("training", False))
return x

def compute_output_shape(self, input_shape):
for layer in self.res_blocks:
input_shape = layer.compute_output_shape(input_shape)

return input_shape

0 comments on commit 4b629e4

Please sign in to comment.