Skip to content

Commit

Permalink
added comment, some more tests, renamed method
Browse files Browse the repository at this point in the history
  • Loading branch information
AlbinSou committed Feb 20, 2024
1 parent 1b049f0 commit 0a837d1
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 6 deletions.
16 changes: 11 additions & 5 deletions avalanche/models/dynamic_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,13 +27,19 @@ def avalanche_model_adaptation(
_visited=None,
_initial_call: bool = True,
):
# _initial_call is set to true in the first iteration of the adaptation
# If initial_call is not true anymore, it means that the depth of the call is
# more than 1 and the adaptation is considered as "automatic" <=> done inside the
# recursive loop, Automatic adaptation calls will not adapt modules that
# have the _auto_adapt set to False

if _visited is None:
_visited = []
_visited = set()

if module in _visited:
return

_visited.append(module)
_visited.add(module)

if isinstance(module, DynamicModule):
if (not _initial_call) and (not module._auto_adapt):
Expand Down Expand Up @@ -67,10 +73,10 @@ def __init__(self, auto_adapt=True):
super().__init__()
self._auto_adapt = auto_adapt

def adapt(self, experience):
def recursive_adaptation(self, experience):
"""
Calls self.adaptation recursively accross
the hierarchy of module children
the hierarchy of pytorch module childrens
"""
avalanche_model_adaptation(self, experience)

Expand All @@ -90,7 +96,7 @@ def adaptation(self, experience: CLExperience):
.. warning::
This function only adapts the current module, to recursively adapt all
submodules use self.adapt() instead
submodules use self.recursive_adaptation() instead
:param experience: the current experience.
:return:
Expand Down
12 changes: 11 additions & 1 deletion tests/models/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,17 @@ def test_multihead_sizes(self):
for t, s in sizes.items():
self.assertEqual(s, model.classifiers[str(t)].classifier.out_features)

def test_avalanche_adaptation(self):
model1 = torch.nn.Sequential(MultiHeadClassifier(in_features=6))
benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)
avalanche_model_adaptation(model1, benchmark.train_stream[0])

def test_recursive_adaptation(self):
model1 = MultiHeadClassifier(in_features=6)
benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)
model1.recursive_adaptation(benchmark.train_stream[0])

def test_recursive_loop(self):
model1 = MultiHeadClassifier(in_features=6)
model2 = MultiHeadClassifier(in_features=6)

Expand All @@ -546,7 +556,7 @@ def test_recursive_adaptation(self):
model2.layer2 = model1

benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)
avalanche_model_adaptation(model1, benchmark.train_stream[0])
model1.recursive_adaptation(benchmark.train_stream[0])

def test_multi_head_classifier_masking(self):
benchmark = get_fast_benchmark(use_task_labels=True, shuffle=True)
Expand Down

0 comments on commit 0a837d1

Please sign in to comment.