Skip to content

Commit

Permalink
Enable squeeze option by broadcast
Browse files Browse the repository at this point in the history
  • Loading branch information
stefanradev93 committed Dec 20, 2024
1 parent 2068b5e commit 5038806
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion bayesflow/adapters/transforms/broadcast.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,15 @@ class Broadcast(Transform):
It is recommended to precede this transform with a :class:`bayesflow.adapters.transforms.ToArray` transform.
"""

def __init__(self, keys: Sequence[str], *, to: str, expand: str | int | tuple = "left", exclude: int | tuple = -1):
def __init__(
self,
keys: Sequence[str],
*,
to: str,
expand: str | int | tuple = "left",
exclude: int | tuple = -1,
squeeze: int | tuple = None,
):
super().__init__()
self.keys = keys
self.to = to
Expand All @@ -67,6 +75,7 @@ def __init__(self, keys: Sequence[str], *, to: str, expand: str | int | tuple =
exclude = (exclude,)

self.exclude = exclude
self.squeeze = squeeze

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
Expand All @@ -75,6 +84,7 @@ def from_config(cls, config: dict, custom_objects=None) -> "Broadcast":
to=deserialize(config["to"], custom_objects),
expand=deserialize(config["expand"], custom_objects),
exclude=deserialize(config["exclude"], custom_objects),
squeeze=deserialize(config["squeeze"], custom_objects),
)

def get_config(self) -> dict:
Expand All @@ -83,6 +93,7 @@ def get_config(self) -> dict:
"to": serialize(self.to),
"expand": serialize(self.expand),
"exclude": serialize(self.exclude),
"squeeze": serialize(self.squeeze),
}

# noinspection PyMethodOverriding
Expand Down Expand Up @@ -115,6 +126,9 @@ def forward(self, data: dict[str, np.ndarray], **kwargs) -> dict[str, np.ndarray

data[k] = np.broadcast_to(data[k], new_shape)

if self.squeeze is not None:
data[k] = np.squeeze(data[k], axis=self.squeeze)

return data

# noinspection PyMethodOverriding
Expand Down

0 comments on commit 5038806

Please sign in to comment.