diff --git a/bayesflow/adapters/adapter.py b/bayesflow/adapters/adapter.py index d6eb8ce52..f60973e34 100644 --- a/bayesflow/adapters/adapter.py +++ b/bayesflow/adapters/adapter.py @@ -15,7 +15,6 @@ ConvertDType, Drop, ExpandDims, - ElementwiseTransform, # why wasn't this added before? FilterTransform, Keep, LambdaTransform, @@ -80,78 +79,74 @@ def __repr__(self): return f"Adapter([{' -> '.join(map(repr, self.transforms))}])" def __getitem__(self, index): - - if isinstance(index, slice): - if index.start > index.stop: + if isinstance(index, slice): + if index.start > index.stop: raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") - if index.stop < len(self.transforms): + if index.stop < len(self.transforms): # print("What is the slice?") # print(index) # print(type(index)) - # check that the slice is in range + # check that the slice is in range sliced_transforms = self.transforms[index] # print("Are the sliced transforms a sequence") # print(isinstance(sliced_transforms, Sequence)) # print("What is in the slice?") # print(sliced_transforms) - new_adapter = Adapter(transforms = sliced_transforms) + new_adapter = Adapter(transforms=sliced_transforms) return new_adapter - else: + else: raise IndexError("Index slice out of range") - - elif isinstance(index, int): + + elif isinstance(index, int): if index < 0: - index = index + len(self.transforms) # negative indexing - if index < 0 or index >= len(self.transforms): + index = index + len(self.transforms) # negative indexing + if index < 0 or index >= len(self.transforms): raise IndexError("Adapter index out of range.") sliced_transforms = self.transforms[index] - new_adapter = Adapter(transforms = sliced_transforms) + new_adapter = Adapter(transforms=sliced_transforms) return new_adapter else: raise TypeError("Invalid index type. Must be int or slice.") - - - def __setitem__(self, index, new_value): - if not isinstance(new_value, Adapter): + def __setitem__(self, index, new_value): + if not isinstance(new_value, Adapter): raise TypeError("new_value must be an Adapter instance") - - - new_transform = new_value.transforms - - if len(new_transform) == 0: - raise ValueError("new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform.") + new_transform = new_value.transforms - if isinstance(index, slice): - if index.start > index.stop: + if len(new_transform) == 0: + raise ValueError( + "new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform." + ) + + if isinstance(index, slice): + if index.start > index.stop: raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]") - + if index.stop < len(self.transforms): self.transforms[index] = new_transform - - else: + + else: raise IndexError("Index slice out of range") - - elif isinstance(index, int): - if index < 0: # negative indexing + elif isinstance(index, int): + if index < 0: # negative indexing index = index + len(self.transforms) - - if index < 0 or index >= len(self.transforms): + + if index < 0 or index >= len(self.transforms): raise IndexError("Index out of range.") - # could add that if the index is out of range, like index == len - # then we just add the transform + # could add that if the index is out of range, like index == len + # then we just add the transform print("what is self.transforms[index]?") print(self.transforms[index]) print("what is the value of the newvalue") print(new_transform) print(type(new_transform)) - + self.transforms[index] = new_transform - else: - raise TypeError("Invalid index type. Must be int or slice.") - + else: + raise TypeError("Invalid index type. Must be int or slice.") + def add_transform(self, transform: Transform): self.transforms.append(transform) return self @@ -178,7 +173,7 @@ def apply( self.transforms.append(transform) return self - # Begin of transformed derived from transform classes + # Begin of transformed derived from transform classes def as_set(self, keys: str | Sequence[str]): if isinstance(keys, str): keys = [keys]