Skip to content

Commit

Permalink
hotfix: serialize indices in concatenate transform
Browse files Browse the repository at this point in the history
  • Loading branch information
vpratz committed Feb 13, 2025
1 parent 8210610 commit b681034
Showing 1 changed file with 4 additions and 2 deletions.
6 changes: 4 additions & 2 deletions bayesflow/adapters/transforms/concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,26 +30,28 @@ class Concatenate(Transform):
)
"""

def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1):
def __init__(self, keys: Sequence[str], *, into: str, axis: int = -1, _indices: list | None = None):
self.keys = keys
self.into = into
self.axis = axis

self.indices = None
self.indices = _indices

@classmethod
def from_config(cls, config: dict, custom_objects=None) -> "Concatenate":
return cls(
keys=deserialize(config["keys"], custom_objects),
into=deserialize(config["into"], custom_objects),
axis=deserialize(config["axis"], custom_objects),
_indices=deserialize(config["indices"], custom_objects),
)

def get_config(self) -> dict:
return {
"keys": serialize(self.keys),
"into": serialize(self.into),
"axis": serialize(self.axis),
"indices": serialize(self.indices),
}

def forward(self, data: dict[str, any], *, strict: bool = True, **kwargs) -> dict[str, any]:
Expand Down

0 comments on commit b681034

Please sign in to comment.