Skip to content

Commit

Permalink
ran linter
Browse files Browse the repository at this point in the history
  • Loading branch information
eodole committed Dec 17, 2024
1 parent 71374ef commit 4fce744
Showing 1 changed file with 35 additions and 40 deletions.
75 changes: 35 additions & 40 deletions bayesflow/adapters/adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
ConvertDType,
Drop,
ExpandDims,
ElementwiseTransform, # why wasn't this added before?
FilterTransform,
Keep,
LambdaTransform,
Expand Down Expand Up @@ -80,78 +79,74 @@ def __repr__(self):
return f"Adapter([{' -> '.join(map(repr, self.transforms))}])"

def __getitem__(self, index):

if isinstance(index, slice):
if index.start > index.stop:
if isinstance(index, slice):
if index.start > index.stop:
raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]")
if index.stop < len(self.transforms):
if index.stop < len(self.transforms):
# print("What is the slice?")
# print(index)
# print(type(index))
# check that the slice is in range
# check that the slice is in range
sliced_transforms = self.transforms[index]
# print("Are the sliced transforms a sequence")
# print(isinstance(sliced_transforms, Sequence))
# print("What is in the slice?")
# print(sliced_transforms)
new_adapter = Adapter(transforms = sliced_transforms)
new_adapter = Adapter(transforms=sliced_transforms)
return new_adapter
else:
else:
raise IndexError("Index slice out of range")
elif isinstance(index, int):

elif isinstance(index, int):
if index < 0:
index = index + len(self.transforms) # negative indexing
if index < 0 or index >= len(self.transforms):
index = index + len(self.transforms) # negative indexing
if index < 0 or index >= len(self.transforms):
raise IndexError("Adapter index out of range.")
sliced_transforms = self.transforms[index]
new_adapter = Adapter(transforms = sliced_transforms)
new_adapter = Adapter(transforms=sliced_transforms)
return new_adapter
else:
raise TypeError("Invalid index type. Must be int or slice.")


def __setitem__(self, index, new_value):

if not isinstance(new_value, Adapter):
def __setitem__(self, index, new_value):
if not isinstance(new_value, Adapter):
raise TypeError("new_value must be an Adapter instance")


new_transform = new_value.transforms

if len(new_transform) == 0:
raise ValueError("new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform.")

new_transform = new_value.transforms

if isinstance(index, slice):
if index.start > index.stop:
if len(new_transform) == 0:
raise ValueError(
"new_value is an Adapter instance without any specified transforms, new_value Adapter must contain at least one transform."
)

if isinstance(index, slice):
if index.start > index.stop:
raise IndexError("Index slice must be positive integers such that a < b for adapter[a:b]")

if index.stop < len(self.transforms):
self.transforms[index] = new_transform
else:

else:
raise IndexError("Index slice out of range")


elif isinstance(index, int):
if index < 0: # negative indexing
elif isinstance(index, int):
if index < 0: # negative indexing
index = index + len(self.transforms)
if index < 0 or index >= len(self.transforms):

if index < 0 or index >= len(self.transforms):
raise IndexError("Index out of range.")
# could add that if the index is out of range, like index == len
# then we just add the transform
# could add that if the index is out of range, like index == len
# then we just add the transform
print("what is self.transforms[index]?")
print(self.transforms[index])
print("what is the value of the newvalue")
print(new_transform)
print(type(new_transform))

self.transforms[index] = new_transform
else:
raise TypeError("Invalid index type. Must be int or slice.")
else:
raise TypeError("Invalid index type. Must be int or slice.")

def add_transform(self, transform: Transform):
self.transforms.append(transform)
return self
Expand All @@ -178,7 +173,7 @@ def apply(
self.transforms.append(transform)
return self

# Begin of transformed derived from transform classes
# Begin of transformed derived from transform classes
def as_set(self, keys: str | Sequence[str]):
if isinstance(keys, str):
keys = [keys]
Expand Down

0 comments on commit 4fce744

Please sign in to comment.