Commit

black
AntonioCarta committed Feb 27, 2024
1 parent 2b62537 commit 2d4414d
Showing 24 changed files with 114 additions and 189 deletions.
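Note: the commit message says only "black", and the diff is consistent with an upgrade to Black's 2024 stable style (Black >= 24.1, matching the February 2024 date). Two changes dominate the hunks below: stub bodies consisting only of `...` are collapsed onto the signature line, and annotated assignments that overflow the line limit are wrapped by parenthesizing the right-hand side instead of splitting the subscripted type annotation. A minimal sketch of both changes follows; `StreamView` and all other names in it are hypothetical illustrations, not code from the Avalanche repository:

    from typing import Dict, Sequence, Union, overload


    class StreamView:
        """Hypothetical container, used only to illustrate the new style."""

        def __init__(self, streams: Dict[str, Sequence[int]]) -> None:
            # Old Black split the subscripted annotation across lines:
            #     self.streams_by_name: Dict[
            #         str, Sequence[int]
            #     ] = self._normalize_stream_definitions(streams)
            # Black >= 24.1 keeps the annotation on one line and wraps the
            # overlong right-hand side in parentheses instead:
            self.streams_by_name: Dict[str, Sequence[int]] = (
                self._normalize_stream_definitions(streams)
            )
            self._data: Sequence[int] = self.streams_by_name.get("train", [])

        @staticmethod
        def _normalize_stream_definitions(
            streams: Dict[str, Sequence[int]],
        ) -> Dict[str, Sequence[int]]:
            # Stand-in for a call whose assignment overflows the line limit.
            return {name: list(values) for name, values in streams.items()}

        # Old Black put the `...` stub of an @overload on its own line:
        #     @overload
        #     def __getitem__(self, idx: int) -> int:
        #         ...
        # The 2024 style collapses the stub onto the signature line:
        @overload
        def __getitem__(self, idx: int) -> int: ...

        @overload
        def __getitem__(self, idx: slice) -> Sequence[int]: ...

        def __getitem__(self, idx: Union[int, slice]) -> Union[int, Sequence[int]]:
            # Single real implementation backing both overloads.
            return self._data[idx]

Reformatting with an older 23.x release would revert every hunk below, so the Black version pinned for CI determines which of the two shapes is treated as canonical.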
@@ -257,12 +257,10 @@ def __len__(self) -> int:
         return len(self._benchmark.streams[self._stream])
 
     @overload
-    def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
-        ...
+    def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ...
 
     @overload
-    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
-        ...
+    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ...
 
     def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
         indexing_collate = _LazyClassesInClassificationExps._slice_collate

18 changes: 9 additions & 9 deletions avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py

@@ -184,17 +184,17 @@ def __init__(
         invoking the super constructor) to specialize the experience class.
         """
 
-        self.experience_factory: Callable[
-            [TCLStream, int], TDatasetExperience
-        ] = experience_factory
+        self.experience_factory: Callable[[TCLStream, int], TDatasetExperience] = (
+            experience_factory
+        )
 
-        self.stream_factory: Callable[
-            [str, TDatasetScenario], TCLStream
-        ] = stream_factory
+        self.stream_factory: Callable[[str, TDatasetScenario], TCLStream] = (
+            stream_factory
+        )
 
-        self.stream_definitions: Dict[
-            str, StreamDef[TCLDataset]
-        ] = DatasetScenario._check_stream_definitions(stream_definitions)
+        self.stream_definitions: Dict[str, StreamDef[TCLDataset]] = (
+            DatasetScenario._check_stream_definitions(stream_definitions)
+        )
         """
         A structure containing the definition of the streams.
         """

@@ -143,9 +143,9 @@ def create_multi_dataset_generic_benchmark(
"complete_test_set_only is True"
)

stream_definitions: Dict[
str, Tuple[Iterable[TaskAwareClassificationDataset]]
] = dict()
stream_definitions: Dict[str, Tuple[Iterable[TaskAwareClassificationDataset]]] = (
dict()
)

for stream_name, dataset_list in input_streams.items():
initial_transform_group = "train"

18 changes: 8 additions & 10 deletions avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py

@@ -99,9 +99,9 @@ def __init__(
         now, including the ones of dropped experiences.
         """
 
-        self.task_labels_field_sequence: Dict[
-            int, Optional[Sequence[int]]
-        ] = defaultdict(lambda: None)
+        self.task_labels_field_sequence: Dict[int, Optional[Sequence[int]]] = (
+            defaultdict(lambda: None)
+        )
         """
         A dictionary mapping each experience to its `targets_task_labels` field.
@@ -118,12 +118,10 @@ def __len__(self) -> int:
         return self._stream_length
 
     @overload
-    def __getitem__(self, exp_idx: int) -> TCLDataset:
-        ...
+    def __getitem__(self, exp_idx: int) -> TCLDataset: ...
 
     @overload
-    def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]:
-        ...
+    def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]: ...
 
     def __getitem__(
         self, exp_idx: Union[int, slice]
@@ -135,9 +133,9 @@ def __getitem__(
         :return: The dataset associated to the experience.
         """
         # A lot of unuseful lines needed for MyPy -_-
-        indexing_collate: Callable[
-            [Iterable[TCLDataset]], Sequence[TCLDataset]
-        ] = lambda x: list(x)
+        indexing_collate: Callable[[Iterable[TCLDataset]], Sequence[TCLDataset]] = (
+            lambda x: list(x)
+        )
         result = manage_advanced_indexing(
             exp_idx,
             self._get_experience_and_load_if_needed,

@@ -327,9 +327,9 @@ class "34" will be mapped to "1", class "11" to "2" and so on.
             # used, the user may have defined an amount of classes less than
             # the overall amount of classes in the dataset.
             if class_id in self.classes_order_original_ids:
-                self.class_mapping[
-                    class_id
-                ] = self.classes_order_original_ids.index(class_id)
+                self.class_mapping[class_id] = (
+                    self.classes_order_original_ids.index(class_id)
+                )
             elif self.class_ids_from_zero_in_each_exp:
                 # Method 2: remap class IDs so that they appear in range [0, N] in
                 # each experience

6 changes: 2 additions & 4 deletions avalanche/benchmarks/scenarios/detection_scenario.py

@@ -254,12 +254,10 @@ def __len__(self):
         return len(self._benchmark.streams[self._stream])
 
     @overload
-    def __getitem__(self, exp_id: int) -> Optional[Set[int]]:
-        ...
+    def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ...
 
     @overload
-    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]:
-        ...
+    def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ...
 
     def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet:
         indexing_collate = _LazyClassesInDetectionExps._slice_collate

6 changes: 2 additions & 4 deletions avalanche/benchmarks/scenarios/generic_scenario.py

@@ -427,12 +427,10 @@ def __iter__(self) -> Iterator[TCLExperience]:
         yield exp
 
     @overload
-    def __getitem__(self, item: int) -> TCLExperience:
-        ...
+    def __getitem__(self, item: int) -> TCLExperience: ...
 
     @overload
-    def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream:
-        ...
+    def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream: ...
 
     @final
     def __getitem__(

35 changes: 13 additions & 22 deletions avalanche/benchmarks/utils/classification_dataset.py

@@ -175,8 +175,7 @@ def _make_taskaware_classification_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -190,8 +189,7 @@ def _make_taskaware_classification_dataset(
     task_labels: Union[int, Sequence[int]],
     targets: Sequence[TTargetType],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -205,8 +203,7 @@ def _make_taskaware_classification_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareClassificationDataset:
-    ...
+) -> TaskAwareClassificationDataset: ...
 
 
 def _make_taskaware_classification_dataset(
@@ -386,8 +383,7 @@ def _taskaware_classification_subset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -403,8 +399,7 @@ def _taskaware_classification_subset(
     task_labels: Union[int, Sequence[int]],
     targets: Sequence[TTargetType],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -420,8 +415,7 @@ def _taskaware_classification_subset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareClassificationDataset:
-    ...
+) -> TaskAwareClassificationDataset: ...
 
 
 def _taskaware_classification_subset(
@@ -619,8 +613,7 @@ def _make_taskaware_tensor_classification_dataset(
     task_labels: Union[int, Sequence[int]],
     targets: Union[Sequence[TTargetType], int],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -633,8 +626,9 @@ def _make_taskaware_tensor_classification_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Union[Sequence[TTargetType], int]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]:
-    ...
+) -> Union[
+    TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset
+]: ...
 
 
 def _make_taskaware_tensor_classification_dataset(
@@ -759,8 +753,7 @@ def _concat_taskaware_classification_datasets(
         Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
     ] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -774,8 +767,7 @@ def _concat_taskaware_classification_datasets(
     task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]],
     targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareSupervisedClassificationDataset:
-    ...
+) -> TaskAwareSupervisedClassificationDataset: ...
 
 
 @overload
@@ -791,8 +783,7 @@ def _concat_taskaware_classification_datasets(
         Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
     ] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> TaskAwareClassificationDataset:
-    ...
+) -> TaskAwareClassificationDataset: ...
 
 
 def _concat_taskaware_classification_datasets(

6 changes: 2 additions & 4 deletions avalanche/benchmarks/utils/data.py

@@ -344,12 +344,10 @@ def __eq__(self, other: object):
         )
 
     @overload
-    def __getitem__(self, exp_id: int) -> T_co:
-        ...
+    def __getitem__(self, exp_id: int) -> T_co: ...
 
     @overload
-    def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset:
-        ...
+    def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset: ...
 
     def __getitem__(
         self: TAvalancheDataset, idx: Union[int, slice]

6 changes: 2 additions & 4 deletions avalanche/benchmarks/utils/data_attribute.py

@@ -62,12 +62,10 @@ def __iter__(self):
         yield self[i]
 
     @overload
-    def __getitem__(self, item: int) -> T_co:
-        ...
+    def __getitem__(self, item: int) -> T_co: ...
 
     @overload
-    def __getitem__(self, item: slice) -> Sequence[T_co]:
-        ...
+    def __getitem__(self, item: slice) -> Sequence[T_co]: ...
 
     def __getitem__(self, item: Union[int, slice]) -> Union[T_co, Sequence[T_co]]:
         return self.data[item]

6 changes: 2 additions & 4 deletions avalanche/benchmarks/utils/dataset_definitions.py

@@ -38,11 +38,9 @@ class IDataset(Protocol[T_co]):
     Note: no __add__ method is defined.
     """
 
-    def __getitem__(self, index: int) -> T_co:
-        ...
+    def __getitem__(self, index: int) -> T_co: ...
 
-    def __len__(self) -> int:
-        ...
+    def __len__(self) -> int: ...
 
 
 class IDatasetWithTargets(IDataset[T_co], Protocol[T_co, TTargetType_co]):

6 changes: 2 additions & 4 deletions avalanche/benchmarks/utils/dataset_utils.py

@@ -64,12 +64,10 @@ def __iter__(self) -> Iterator[TData]:
         yield el
 
     @overload
-    def __getitem__(self, item: int) -> TData:
-        ...
+    def __getitem__(self, item: int) -> TData: ...
 
     @overload
-    def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence:
-        ...
+    def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence: ...
 
     @final
     def __getitem__(

27 changes: 9 additions & 18 deletions avalanche/benchmarks/utils/detection_dataset.py

@@ -151,8 +151,7 @@ def make_detection_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> SupervisedDetectionDataset:
-    ...
+) -> SupervisedDetectionDataset: ...
 
 
 @overload
@@ -166,8 +165,7 @@ def make_detection_dataset(
     task_labels: Union[int, Sequence[int]],
     targets: Sequence[TTargetType],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> SupervisedDetectionDataset:
-    ...
+) -> SupervisedDetectionDataset: ...
 
 
 @overload
@@ -181,8 +179,7 @@ def make_detection_dataset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> DetectionDataset:
-    ...
+) -> DetectionDataset: ...
 
 
 def make_detection_dataset(
@@ -373,8 +370,7 @@ def detection_subset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> SupervisedDetectionDataset:
-    ...
+) -> SupervisedDetectionDataset: ...
 
 
 @overload
@@ -390,8 +386,7 @@ def detection_subset(
     task_labels: Union[int, Sequence[int]],
     targets: Sequence[TTargetType],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> SupervisedDetectionDataset:
-    ...
+) -> SupervisedDetectionDataset: ...
 
 
 @overload
@@ -407,8 +402,7 @@ def detection_subset(
     task_labels: Optional[Union[int, Sequence[int]]] = None,
     targets: Optional[Sequence[TTargetType]] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> DetectionDataset:
-    ...
+) -> DetectionDataset: ...
 
 
 def detection_subset(
@@ -595,8 +589,7 @@ def concat_detection_datasets(
         Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
     ] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> SupervisedDetectionDataset:
-    ...
+) -> SupervisedDetectionDataset: ...
 
 
 @overload
@@ -610,8 +603,7 @@ def concat_detection_datasets(
     task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]],
     targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]],
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> SupervisedDetectionDataset:
-    ...
+) -> SupervisedDetectionDataset: ...
 
 
 @overload
@@ -627,8 +619,7 @@ def concat_detection_datasets(
         Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]]
     ] = None,
     collate_fn: Optional[Callable[[List], Any]] = None
-) -> DetectionDataset:
-    ...
+) -> DetectionDataset: ...
 
 
 def concat_detection_datasets(
0 comments on commit 2d4414d
