From 2d4414dea000178b6b42beac56a60320a9c704c2 Mon Sep 17 00:00:00 2001 From: AntonioCarta Date: Tue, 27 Feb 2024 16:32:31 +0100 Subject: [PATCH] black --- .../deprecated/classification_scenario.py | 6 +- .../scenarios/deprecated/dataset_scenario.py | 18 ++--- .../deprecated/generic_benchmark_creation.py | 6 +- .../deprecated/lazy_dataset_sequence.py | 18 +++-- .../deprecated/new_classes/nc_scenario.py | 6 +- .../scenarios/detection_scenario.py | 6 +- .../benchmarks/scenarios/generic_scenario.py | 6 +- .../utils/classification_dataset.py | 35 ++++------ avalanche/benchmarks/utils/data.py | 6 +- avalanche/benchmarks/utils/data_attribute.py | 6 +- .../benchmarks/utils/dataset_definitions.py | 6 +- avalanche/benchmarks/utils/dataset_utils.py | 6 +- .../benchmarks/utils/detection_dataset.py | 27 +++----- avalanche/benchmarks/utils/flat_data.py | 18 ++--- avalanche/evaluation/metric_definitions.py | 6 +- .../evaluation/metrics/forgetting_bwt.py | 9 +-- .../evaluation/metrics/labels_repartition.py | 6 +- avalanche/training/plugins/bic.py | 6 +- avalanche/training/supervised/expert_gate.py | 6 +- .../training/supervised/joint_training.py | 6 +- avalanche/training/templates/base.py | 12 ++-- .../templates/strategy_mixin_protocol.py | 66 +++++++------------ examples/detection.py | 8 +-- examples/detection_lvis.py | 8 +-- 24 files changed, 114 insertions(+), 189 deletions(-) diff --git a/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py b/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py index ecdabacaa..b195a27f6 100644 --- a/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py +++ b/avalanche/benchmarks/scenarios/deprecated/classification_scenario.py @@ -257,12 +257,10 @@ def __len__(self) -> int: return len(self._benchmark.streams[self._stream]) @overload - def __getitem__(self, exp_id: int) -> Optional[Set[int]]: - ... + def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ... @overload - def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: - ... + def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ... def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet: indexing_collate = _LazyClassesInClassificationExps._slice_collate diff --git a/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py b/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py index 5547a23b5..5af73eb9b 100644 --- a/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py +++ b/avalanche/benchmarks/scenarios/deprecated/dataset_scenario.py @@ -184,17 +184,17 @@ def __init__( invoking the super constructor) to specialize the experience class. """ - self.experience_factory: Callable[ - [TCLStream, int], TDatasetExperience - ] = experience_factory + self.experience_factory: Callable[[TCLStream, int], TDatasetExperience] = ( + experience_factory + ) - self.stream_factory: Callable[ - [str, TDatasetScenario], TCLStream - ] = stream_factory + self.stream_factory: Callable[[str, TDatasetScenario], TCLStream] = ( + stream_factory + ) - self.stream_definitions: Dict[ - str, StreamDef[TCLDataset] - ] = DatasetScenario._check_stream_definitions(stream_definitions) + self.stream_definitions: Dict[str, StreamDef[TCLDataset]] = ( + DatasetScenario._check_stream_definitions(stream_definitions) + ) """ A structure containing the definition of the streams. """ diff --git a/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py b/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py index 9bf1e9b6e..351d7b6f3 100644 --- a/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py +++ b/avalanche/benchmarks/scenarios/deprecated/generic_benchmark_creation.py @@ -143,9 +143,9 @@ def create_multi_dataset_generic_benchmark( "complete_test_set_only is True" ) - stream_definitions: Dict[ - str, Tuple[Iterable[TaskAwareClassificationDataset]] - ] = dict() + stream_definitions: Dict[str, Tuple[Iterable[TaskAwareClassificationDataset]]] = ( + dict() + ) for stream_name, dataset_list in input_streams.items(): initial_transform_group = "train" diff --git a/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py b/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py index 1d20fb774..82251d0da 100644 --- a/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py +++ b/avalanche/benchmarks/scenarios/deprecated/lazy_dataset_sequence.py @@ -99,9 +99,9 @@ def __init__( now, including the ones of dropped experiences. """ - self.task_labels_field_sequence: Dict[ - int, Optional[Sequence[int]] - ] = defaultdict(lambda: None) + self.task_labels_field_sequence: Dict[int, Optional[Sequence[int]]] = ( + defaultdict(lambda: None) + ) """ A dictionary mapping each experience to its `targets_task_labels` field. @@ -118,12 +118,10 @@ def __len__(self) -> int: return self._stream_length @overload - def __getitem__(self, exp_idx: int) -> TCLDataset: - ... + def __getitem__(self, exp_idx: int) -> TCLDataset: ... @overload - def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]: - ... + def __getitem__(self, exp_idx: slice) -> Sequence[TCLDataset]: ... def __getitem__( self, exp_idx: Union[int, slice] @@ -135,9 +133,9 @@ def __getitem__( :return: The dataset associated to the experience. """ # A lot of unuseful lines needed for MyPy -_- - indexing_collate: Callable[ - [Iterable[TCLDataset]], Sequence[TCLDataset] - ] = lambda x: list(x) + indexing_collate: Callable[[Iterable[TCLDataset]], Sequence[TCLDataset]] = ( + lambda x: list(x) + ) result = manage_advanced_indexing( exp_idx, self._get_experience_and_load_if_needed, diff --git a/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py b/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py index fcbe8610b..a5509b18d 100644 --- a/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py +++ b/avalanche/benchmarks/scenarios/deprecated/new_classes/nc_scenario.py @@ -327,9 +327,9 @@ class "34" will be mapped to "1", class "11" to "2" and so on. # used, the user may have defined an amount of classes less than # the overall amount of classes in the dataset. if class_id in self.classes_order_original_ids: - self.class_mapping[ - class_id - ] = self.classes_order_original_ids.index(class_id) + self.class_mapping[class_id] = ( + self.classes_order_original_ids.index(class_id) + ) elif self.class_ids_from_zero_in_each_exp: # Method 2: remap class IDs so that they appear in range [0, N] in # each experience diff --git a/avalanche/benchmarks/scenarios/detection_scenario.py b/avalanche/benchmarks/scenarios/detection_scenario.py index 8b16be027..111c5ac11 100644 --- a/avalanche/benchmarks/scenarios/detection_scenario.py +++ b/avalanche/benchmarks/scenarios/detection_scenario.py @@ -254,12 +254,10 @@ def __len__(self): return len(self._benchmark.streams[self._stream]) @overload - def __getitem__(self, exp_id: int) -> Optional[Set[int]]: - ... + def __getitem__(self, exp_id: int) -> Optional[Set[int]]: ... @overload - def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: - ... + def __getitem__(self, exp_id: slice) -> Tuple[Optional[Set[int]], ...]: ... def __getitem__(self, exp_id: Union[int, slice]) -> LazyClassesInExpsRet: indexing_collate = _LazyClassesInDetectionExps._slice_collate diff --git a/avalanche/benchmarks/scenarios/generic_scenario.py b/avalanche/benchmarks/scenarios/generic_scenario.py index d7b5ba09c..34da0d249 100644 --- a/avalanche/benchmarks/scenarios/generic_scenario.py +++ b/avalanche/benchmarks/scenarios/generic_scenario.py @@ -427,12 +427,10 @@ def __iter__(self) -> Iterator[TCLExperience]: yield exp @overload - def __getitem__(self, item: int) -> TCLExperience: - ... + def __getitem__(self, item: int) -> TCLExperience: ... @overload - def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream: - ... + def __getitem__(self: TSequenceCLStream, item: slice) -> TSequenceCLStream: ... @final def __getitem__( diff --git a/avalanche/benchmarks/utils/classification_dataset.py b/avalanche/benchmarks/utils/classification_dataset.py index acb4f880b..ad6db47f6 100644 --- a/avalanche/benchmarks/utils/classification_dataset.py +++ b/avalanche/benchmarks/utils/classification_dataset.py @@ -175,8 +175,7 @@ def _make_taskaware_classification_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -190,8 +189,7 @@ def _make_taskaware_classification_dataset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -205,8 +203,7 @@ def _make_taskaware_classification_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareClassificationDataset: - ... +) -> TaskAwareClassificationDataset: ... def _make_taskaware_classification_dataset( @@ -386,8 +383,7 @@ def _taskaware_classification_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -403,8 +399,7 @@ def _taskaware_classification_subset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -420,8 +415,7 @@ def _taskaware_classification_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareClassificationDataset: - ... +) -> TaskAwareClassificationDataset: ... def _taskaware_classification_subset( @@ -619,8 +613,7 @@ def _make_taskaware_tensor_classification_dataset( task_labels: Union[int, Sequence[int]], targets: Union[Sequence[TTargetType], int], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -633,8 +626,9 @@ def _make_taskaware_tensor_classification_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Union[Sequence[TTargetType], int]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> Union[TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset]: - ... +) -> Union[ + TaskAwareClassificationDataset, TaskAwareSupervisedClassificationDataset +]: ... def _make_taskaware_tensor_classification_dataset( @@ -759,8 +753,7 @@ def _concat_taskaware_classification_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -774,8 +767,7 @@ def _concat_taskaware_classification_datasets( task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]], targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]], collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareSupervisedClassificationDataset: - ... +) -> TaskAwareSupervisedClassificationDataset: ... @overload @@ -791,8 +783,7 @@ def _concat_taskaware_classification_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> TaskAwareClassificationDataset: - ... +) -> TaskAwareClassificationDataset: ... def _concat_taskaware_classification_datasets( diff --git a/avalanche/benchmarks/utils/data.py b/avalanche/benchmarks/utils/data.py index 76908345e..9985e19fa 100644 --- a/avalanche/benchmarks/utils/data.py +++ b/avalanche/benchmarks/utils/data.py @@ -344,12 +344,10 @@ def __eq__(self, other: object): ) @overload - def __getitem__(self, exp_id: int) -> T_co: - ... + def __getitem__(self, exp_id: int) -> T_co: ... @overload - def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset: - ... + def __getitem__(self: TAvalancheDataset, exp_id: slice) -> TAvalancheDataset: ... def __getitem__( self: TAvalancheDataset, idx: Union[int, slice] diff --git a/avalanche/benchmarks/utils/data_attribute.py b/avalanche/benchmarks/utils/data_attribute.py index 3cb88e58c..d3ee4b0d9 100644 --- a/avalanche/benchmarks/utils/data_attribute.py +++ b/avalanche/benchmarks/utils/data_attribute.py @@ -62,12 +62,10 @@ def __iter__(self): yield self[i] @overload - def __getitem__(self, item: int) -> T_co: - ... + def __getitem__(self, item: int) -> T_co: ... @overload - def __getitem__(self, item: slice) -> Sequence[T_co]: - ... + def __getitem__(self, item: slice) -> Sequence[T_co]: ... def __getitem__(self, item: Union[int, slice]) -> Union[T_co, Sequence[T_co]]: return self.data[item] diff --git a/avalanche/benchmarks/utils/dataset_definitions.py b/avalanche/benchmarks/utils/dataset_definitions.py index ca8de5b54..de9fa05da 100644 --- a/avalanche/benchmarks/utils/dataset_definitions.py +++ b/avalanche/benchmarks/utils/dataset_definitions.py @@ -38,11 +38,9 @@ class IDataset(Protocol[T_co]): Note: no __add__ method is defined. """ - def __getitem__(self, index: int) -> T_co: - ... + def __getitem__(self, index: int) -> T_co: ... - def __len__(self) -> int: - ... + def __len__(self) -> int: ... class IDatasetWithTargets(IDataset[T_co], Protocol[T_co, TTargetType_co]): diff --git a/avalanche/benchmarks/utils/dataset_utils.py b/avalanche/benchmarks/utils/dataset_utils.py index 717d67e75..867ba9cc6 100644 --- a/avalanche/benchmarks/utils/dataset_utils.py +++ b/avalanche/benchmarks/utils/dataset_utils.py @@ -64,12 +64,10 @@ def __iter__(self) -> Iterator[TData]: yield el @overload - def __getitem__(self, item: int) -> TData: - ... + def __getitem__(self, item: int) -> TData: ... @overload - def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence: - ... + def __getitem__(self: TSliceSequence, item: slice) -> TSliceSequence: ... @final def __getitem__( diff --git a/avalanche/benchmarks/utils/detection_dataset.py b/avalanche/benchmarks/utils/detection_dataset.py index 7f3b3b632..6c5efb43f 100644 --- a/avalanche/benchmarks/utils/detection_dataset.py +++ b/avalanche/benchmarks/utils/detection_dataset.py @@ -151,8 +151,7 @@ def make_detection_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -166,8 +165,7 @@ def make_detection_dataset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -181,8 +179,7 @@ def make_detection_dataset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> DetectionDataset: - ... +) -> DetectionDataset: ... def make_detection_dataset( @@ -373,8 +370,7 @@ def detection_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -390,8 +386,7 @@ def detection_subset( task_labels: Union[int, Sequence[int]], targets: Sequence[TTargetType], collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -407,8 +402,7 @@ def detection_subset( task_labels: Optional[Union[int, Sequence[int]]] = None, targets: Optional[Sequence[TTargetType]] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> DetectionDataset: - ... +) -> DetectionDataset: ... def detection_subset( @@ -595,8 +589,7 @@ def concat_detection_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -610,8 +603,7 @@ def concat_detection_datasets( task_labels: Union[int, Sequence[int], Sequence[Sequence[int]]], targets: Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]], collate_fn: Optional[Callable[[List], Any]] = None -) -> SupervisedDetectionDataset: - ... +) -> SupervisedDetectionDataset: ... @overload @@ -627,8 +619,7 @@ def concat_detection_datasets( Union[Sequence[TTargetType], Sequence[Sequence[TTargetType]]] ] = None, collate_fn: Optional[Callable[[List], Any]] = None -) -> DetectionDataset: - ... +) -> DetectionDataset: ... def concat_detection_datasets( diff --git a/avalanche/benchmarks/utils/flat_data.py b/avalanche/benchmarks/utils/flat_data.py index d9fd73ee0..02d3681cd 100644 --- a/avalanche/benchmarks/utils/flat_data.py +++ b/avalanche/benchmarks/utils/flat_data.py @@ -66,9 +66,9 @@ def __init__( else: new_lists.append(ll) - self._lists: Optional[ - List[Sequence[int]] - ] = new_lists # freed after eagerification + self._lists: Optional[List[Sequence[int]]] = ( + new_lists # freed after eagerification + ) self._offset: int = int(offset) self._eager_list: Optional[np.ndarray] = None """This is the list where we save indices @@ -408,12 +408,10 @@ def _get_idx(self, idx) -> Tuple[int, int]: return dataset_idx, int(idx) @overload - def __getitem__(self, item: int) -> T_co: - ... + def __getitem__(self, item: int) -> T_co: ... @overload - def __getitem__(self: TFlatData, item: slice) -> TFlatData: - ... + def __getitem__(self: TFlatData, item: slice) -> TFlatData: ... def __getitem__(self: TFlatData, item: Union[int, slice]) -> Union[T_co, TFlatData]: if isinstance(item, (int, np.integer)): @@ -470,12 +468,10 @@ def __len__(self): return self._size @overload - def __getitem__(self, index: int) -> DataT: - ... + def __getitem__(self, index: int) -> DataT: ... @overload - def __getitem__(self, index: slice) -> "ConstantSequence[DataT]": - ... + def __getitem__(self, index: slice) -> "ConstantSequence[DataT]": ... def __getitem__( self, index: Union[int, slice] diff --git a/avalanche/evaluation/metric_definitions.py b/avalanche/evaluation/metric_definitions.py index 1934d5392..c91a7122c 100644 --- a/avalanche/evaluation/metric_definitions.py +++ b/avalanche/evaluation/metric_definitions.py @@ -215,8 +215,7 @@ def __init__( ] = "experience", emit_at: Literal["iteration", "epoch", "experience", "stream"] = "experience", mode: Literal["train"] = "train", - ): - ... + ): ... @overload def __init__( @@ -225,8 +224,7 @@ def __init__( reset_at: Literal["iteration", "experience", "stream", "never"] = "experience", emit_at: Literal["iteration", "experience", "stream"] = "experience", mode: Literal["eval"] = "eval", - ): - ... + ): ... def __init__( self, metric: TMetric, reset_at="experience", emit_at="experience", mode="eval" diff --git a/avalanche/evaluation/metrics/forgetting_bwt.py b/avalanche/evaluation/metrics/forgetting_bwt.py index 9e5f79094..29d9ffd28 100644 --- a/avalanche/evaluation/metrics/forgetting_bwt.py +++ b/avalanche/evaluation/metrics/forgetting_bwt.py @@ -526,18 +526,15 @@ def forgetting_metrics(*, experience=False, stream=False) -> List[PluginMetric]: @overload -def forgetting_to_bwt(f: float) -> float: - ... +def forgetting_to_bwt(f: float) -> float: ... @overload -def forgetting_to_bwt(f: Dict[int, float]) -> Dict[int, float]: - ... +def forgetting_to_bwt(f: Dict[int, float]) -> Dict[int, float]: ... @overload -def forgetting_to_bwt(f: None) -> None: - ... +def forgetting_to_bwt(f: None) -> None: ... def forgetting_to_bwt(f: Optional[Union[float, Dict[int, float]]]): diff --git a/avalanche/evaluation/metrics/labels_repartition.py b/avalanche/evaluation/metrics/labels_repartition.py index 71d6fb0e9..7a5304696 100644 --- a/avalanche/evaluation/metrics/labels_repartition.py +++ b/avalanche/evaluation/metrics/labels_repartition.py @@ -85,8 +85,7 @@ def __init__( ] = default_history_repartition_image_creator, mode: Literal["train"] = "train", emit_reset_at: Literal["stream", "experience", "epoch"] = "epoch", - ): - ... + ): ... @overload def __init__( @@ -97,8 +96,7 @@ def __init__( ] = default_history_repartition_image_creator, mode: Literal["eval"] = "eval", emit_reset_at: Literal["stream", "experience"], - ): - ... + ): ... def __init__( self, diff --git a/avalanche/training/plugins/bic.py b/avalanche/training/plugins/bic.py index 61535fd37..caef605b9 100644 --- a/avalanche/training/plugins/bic.py +++ b/avalanche/training/plugins/bic.py @@ -454,9 +454,9 @@ def _classes_groups(self, strategy: SupervisedTemplate): # - "current" classes: seen in current_experience # "initial" classes - initial_classes: Set[ - int - ] = set() # pre_initial_cl in the original implementation + initial_classes: Set[int] = ( + set() + ) # pre_initial_cl in the original implementation previous_classes: Set[int] = set() # pre_new_cl in the original implementation current_classes: Set[int] = set() # new_cl in the original implementation # Note: pre_initial_cl + pre_new_cl is "initial_cl" in the original implementation diff --git a/avalanche/training/supervised/expert_gate.py b/avalanche/training/supervised/expert_gate.py index a76637514..abb249216 100644 --- a/avalanche/training/supervised/expert_gate.py +++ b/avalanche/training/supervised/expert_gate.py @@ -239,9 +239,9 @@ def _select_expert(self, strategy: "SupervisedTemplate", task_label): # Iterate through all autoencoders to get error values for autoencoder_id in strategy.model.autoencoder_dict: - error_dict[ - str(autoencoder_id) - ] = self._get_average_reconstruction_error(strategy, autoencoder_id) + error_dict[str(autoencoder_id)] = ( + self._get_average_reconstruction_error(strategy, autoencoder_id) + ) # Send error dictionary to get most relevant autoencoder relatedness_dict = self._task_relatedness(strategy, error_dict, task_label) diff --git a/avalanche/training/supervised/joint_training.py b/avalanche/training/supervised/joint_training.py index 4e145532e..08294cfc0 100644 --- a/avalanche/training/supervised/joint_training.py +++ b/avalanche/training/supervised/joint_training.py @@ -139,9 +139,9 @@ def train( ) # Normalize training and eval data. - experiences_list: Iterable[ - TDatasetExperience - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TDatasetExperience] = ( + _experiences_parameter_as_iterable(experiences) + ) if eval_streams is None: eval_streams = [experiences_list] diff --git a/avalanche/training/templates/base.py b/avalanche/training/templates/base.py index 8dd5c8c03..2ec228291 100644 --- a/avalanche/training/templates/base.py +++ b/avalanche/training/templates/base.py @@ -147,9 +147,9 @@ def train( self.model.to(self.device) # Normalize training and eval data. - experiences_list: Iterable[ - TExperienceType - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TExperienceType] = ( + _experiences_parameter_as_iterable(experiences) + ) if eval_streams is None: eval_streams = [experiences_list] @@ -201,9 +201,9 @@ def eval( self.is_training = False self.model.eval() - experiences_list: Iterable[ - TExperienceType - ] = _experiences_parameter_as_iterable(experiences) + experiences_list: Iterable[TExperienceType] = ( + _experiences_parameter_as_iterable(experiences) + ) self.current_eval_stream = experiences_list self._before_eval(**kwargs) diff --git a/avalanche/training/templates/strategy_mixin_protocol.py b/avalanche/training/templates/strategy_mixin_protocol.py index 847d6ef78..4596c9d39 100644 --- a/avalanche/training/templates/strategy_mixin_protocol.py +++ b/avalanche/training/templates/strategy_mixin_protocol.py @@ -58,53 +58,37 @@ class SGDStrategyProtocol( _criterion: CriterionType - def forward(self) -> TMBOutput: - ... + def forward(self) -> TMBOutput: ... - def criterion(self) -> Tensor: - ... + def criterion(self) -> Tensor: ... - def backward(self) -> None: - ... + def backward(self) -> None: ... - def _make_empty_loss(self) -> Tensor: - ... + def _make_empty_loss(self) -> Tensor: ... - def make_optimizer(self, **kwargs): - ... + def make_optimizer(self, **kwargs): ... - def optimizer_step(self) -> None: - ... + def optimizer_step(self) -> None: ... - def model_adaptation(self, model: Optional[Module] = None) -> Module: - ... + def model_adaptation(self, model: Optional[Module] = None) -> Module: ... - def _unpack_minibatch(self): - ... + def _unpack_minibatch(self): ... - def _before_training_iteration(self, **kwargs): - ... + def _before_training_iteration(self, **kwargs): ... - def _before_forward(self, **kwargs): - ... + def _before_forward(self, **kwargs): ... - def _after_forward(self, **kwargs): - ... + def _after_forward(self, **kwargs): ... - def _before_backward(self, **kwargs): - ... + def _before_backward(self, **kwargs): ... - def _after_backward(self, **kwargs): - ... + def _after_backward(self, **kwargs): ... - def _before_update(self, **kwargs): - ... + def _before_update(self, **kwargs): ... - def _after_update(self, **kwargs): - ... + def _after_update(self, **kwargs): ... - def _after_training_iteration(self, **kwargs): - ... + def _after_training_iteration(self, **kwargs): ... class SupervisedStrategyProtocol( @@ -122,23 +106,17 @@ class MetaLearningStrategyProtocol( SGDStrategyProtocol[TSGDExperienceType, TMBInput, TMBOutput], Protocol[TSGDExperienceType, TMBInput, TMBOutput], ): - def _before_inner_updates(self, **kwargs): - ... + def _before_inner_updates(self, **kwargs): ... - def _inner_updates(self, **kwargs): - ... + def _inner_updates(self, **kwargs): ... - def _after_inner_updates(self, **kwargs): - ... + def _after_inner_updates(self, **kwargs): ... - def _before_outer_update(self, **kwargs): - ... + def _before_outer_update(self, **kwargs): ... - def _outer_update(self, **kwargs): - ... + def _outer_update(self, **kwargs): ... - def _after_outer_update(self, **kwargs): - ... + def _after_outer_update(self, **kwargs): ... __all__ = [ diff --git a/examples/detection.py b/examples/detection.py index 3e7f7d94e..ec2e92119 100644 --- a/examples/detection.py +++ b/examples/detection.py @@ -160,15 +160,11 @@ def obtain_base_model(segmentation: bool): pretrain_argument["pretrained"] = True else: if segmentation: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT ) else: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT ) diff --git a/examples/detection_lvis.py b/examples/detection_lvis.py index ed86e5b2c..330b7a9ec 100644 --- a/examples/detection_lvis.py +++ b/examples/detection_lvis.py @@ -142,15 +142,11 @@ def obtain_base_model(segmentation: bool): pretrain_argument["pretrained"] = True else: if segmentation: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.mask_rcnn.MaskRCNN_ResNet50_FPN_Weights.DEFAULT ) else: - pretrain_argument[ - "weights" - ] = ( + pretrain_argument["weights"] = ( torchvision.models.detection.faster_rcnn.FasterRCNN_ResNet50_FPN_Weights.DEFAULT )