From ec3f3c6b27884065f25212aa79da46a13fd986fe Mon Sep 17 00:00:00 2001 From: Antonio Carta Date: Wed, 14 Feb 2024 11:34:34 +0100 Subject: [PATCH] revert classes timeline values to sets --- avalanche/benchmarks/scenarios/supervised.py | 11 ++++------- avalanche/models/dynamic_modules.py | 2 +- 2 files changed, 5 insertions(+), 8 deletions(-) diff --git a/avalanche/benchmarks/scenarios/supervised.py b/avalanche/benchmarks/scenarios/supervised.py index 830b37c02..8e46446db 100644 --- a/avalanche/benchmarks/scenarios/supervised.py +++ b/avalanche/benchmarks/scenarios/supervised.py @@ -26,7 +26,6 @@ from avalanche.benchmarks.utils.classification_dataset import ( ClassificationDataset, _as_taskaware_supervised_classification_dataset, - TaskAwareSupervisedClassificationDataset, ) from avalanche.benchmarks.utils.data import AvalancheDataset from avalanche.benchmarks.utils.data_attribute import DataAttribute @@ -399,14 +398,12 @@ def _decorate_stream(obj: CLStream): new_exp = copy(exp) curr_cls = exp.dataset.targets.uniques - new_exp.classes_in_this_experience = list(curr_cls) - new_exp.previous_classes = list(set(prev_cls)) - new_exp.classes_seen_so_far = list(curr_cls.union(prev_cls)) + new_exp.classes_in_this_experience = curr_cls + new_exp.previous_classes = set(prev_cls) + new_exp.classes_seen_so_far = curr_cls.union(prev_cls) # TODO: future_classes ignores repetitions right now... # implement and test scenario with repetitions - new_exp.future_classes = list( - all_cls.difference(new_exp.classes_seen_so_far) - ) + new_exp.future_classes = all_cls.difference(new_exp.classes_seen_so_far) new_stream.append(new_exp) prev_cls = prev_cls.union(curr_cls) diff --git a/avalanche/models/dynamic_modules.py b/avalanche/models/dynamic_modules.py index 1cf580983..3086681fe 100644 --- a/avalanche/models/dynamic_modules.py +++ b/avalanche/models/dynamic_modules.py @@ -246,7 +246,7 @@ def adaptation(self, experience: CLExperience): self.active_units[: old_act_units.shape[0]] = old_act_units # update with new active classes if self.training: - self.active_units[curr_classes] = 1 + self.active_units[list(curr_classes)] = 1 # update classifier weights if old_nclasses == new_nclasses: