Commit

added remove_params as an option to make_optimizer that defaults to False
AlbinSou committed Mar 21, 2024
1 parent dc6c5a3 commit 971fe65
Showing 3 changed files with 92 additions and 33 deletions.
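To illustrate the change, here is a rough usage sketch of the new flag. It is not part of the commit: it assumes Avalanche is installed with the update_optimizer signature introduced here, and the names old_param/new_param are illustrative, mirroring the updated test further below.

import torch
from torch.optim import SGD

from avalanche.models.dynamic_optimizers import update_optimizer

# Illustrative setup: the optimizer currently holds the model parameters
# plus one parameter (p_old) that the model no longer exposes.
model = torch.nn.Linear(10, 2)
p_old = torch.nn.Parameter(torch.randn(5))
p_new = torch.nn.Parameter(torch.randn(5))
optimizer = SGD([p_old] + list(model.parameters()), lr=0.1)

old_named = {"old_param": p_old, **dict(model.named_parameters())}
new_named = {"new_param": p_new, **dict(model.named_parameters())}

# With the default remove_params=False, p_old would be left in the optimizer;
# remove_params=True drops it (and any optimizer state attached to it).
optimized = update_optimizer(optimizer, new_named, old_named, remove_params=True)
assert "new_param" in optimized
assert "old_param" not in optimized
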
41 changes: 26 additions & 15 deletions avalanche/models/dynamic_optimizers.py
@@ -38,10 +38,11 @@ def map_optimized_params(optimizer, parameters, old_params=None):
returns the lists of:
returns:
new_parameters: Names of new parameters in the provided "parameters" argument
new_parameters: Names of new parameters in the provided "parameters" argument,
that are not in the old parameters
changed_parameters: Names and indexes of parameters that have changed (grown, shrink)
removed_parameters: List of indexes of optimizer parameters that are not found in the new parameters
not_found_in_parameters: List of indexes of optimizer parameters
that are not found in the provided parameters
"""

if old_params is None:
@@ -228,7 +229,7 @@ def single_group(self):
return list(self.groups)[0]


@deprecated(0.6, "reset_optimizer is deprecated in favor of update_optimizer")
@deprecated(0.6, "update_optimizer with optimized_params=None is now used instead")
def reset_optimizer(optimizer, model):
"""Reset the optimizer to update the list of learnable parameters.
@@ -259,7 +260,12 @@ def reset_optimizer(optimizer, model):


def update_optimizer(
optimizer, new_params, optimized_params, reset_state=False, verbose=False
optimizer,
new_params,
optimized_params=None,
reset_state=False,
remove_params=False,
verbose=False,
):
"""Update the optimizer by adding new parameters,
removing removed parameters, and adding new parameters
@@ -271,11 +277,15 @@
:param new_params: Dict (name, param) of new parameters
:param optimized_params: Dict (name, param) of
currently optimized parameters (returned by reset_optimizer)
:param reset_state: Wheter to reset the optimizer's state (i.e momentum).
Defaults to False.
currently optimized parameters
:param reset_state: Whether to reset the optimizer's state (i.e momentum).
Defaults to False.
:param remove_params: Whether to remove parameters that were in the optimizer
but are not found in new parameters. For safety reasons,
defaults to False.
:param verbose: If True, prints information about inferred
parameter groups for new params
:return: Dict (name, param) of optimized parameters
"""
(
@@ -299,13 +309,14 @@

# Remove parameters that are not here anymore
# This should not happend in most use case
for group_idx, idx_list in enumerate(not_found_in_parameters):
for j in sorted(idx_list, key=lambda x: x, reverse=True):
p = optimizer.param_groups[group_idx]["params"][j]
optimizer.param_groups[group_idx]["params"].pop(j)
if p in optimizer.state:
optimizer.state.pop(p)
del p
if remove_params:
for group_idx, idx_list in enumerate(not_found_in_parameters):
for j in sorted(idx_list, key=lambda x: x, reverse=True):
p = optimizer.param_groups[group_idx]["params"][j]
optimizer.param_groups[group_idx]["params"].pop(j)
if p in optimizer.state:
optimizer.state.pop(p)
del p

# Add newly added parameters (i.e Multitask, PNN)

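For reference, a small self-contained sketch (not taken from the repository) of the removal pattern guarded by remove_params in the hunk above: indices are popped in descending order so the remaining positions stay valid, and any optimizer state attached to a removed tensor is dropped as well. The not_found_in_parameters value below stands in for a hypothetical output of map_optimized_params.

import torch
from torch.optim import SGD

params = [torch.nn.Parameter(torch.randn(3)) for _ in range(3)]
optimizer = SGD(params, lr=0.1, momentum=0.9)

# Create some optimizer state (momentum buffers) to illustrate the cleanup.
loss = sum((p ** 2).sum() for p in params)
loss.backward()
optimizer.step()

# Hypothetical result: in group 0, the parameters at indices 0 and 2
# were not found among the new parameters.
not_found_in_parameters = [[0, 2]]

for group_idx, idx_list in enumerate(not_found_in_parameters):
    # Pop from the back so earlier indices are not shifted by earlier pops.
    for j in sorted(idx_list, reverse=True):
        p = optimizer.param_groups[group_idx]["params"].pop(j)
        # Drop any state (e.g. momentum buffers) tied to the removed tensor.
        optimizer.state.pop(p, None)

assert len(optimizer.param_groups[0]["params"]) == 1
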
@@ -36,7 +36,9 @@ def model_adaptation(self, model=None):

return model.to(self.device)

def make_optimizer(self, reset_optimizer_state=False, **kwargs):
def make_optimizer(
self, reset_optimizer_state=False, remove_params=False, **kwargs
):
"""Optimizer initialization.
Called before each training experience to configure the optimizer.
@@ -49,7 +51,7 @@ def make_optimizer(self, reset_optimizer_state=False, **kwargs):
for a given strategy it will reset the
optimizer to gather the (name, param)
correspondance of the optimized parameters
all the model parameters will be put in the
all of the model parameters will be put in the
optimizer, regardless of what parameters are
initially put in the optimizer.
"""
@@ -58,6 +60,7 @@ def make_optimizer(self, reset_optimizer_state=False, **kwargs):
dict(self.model.named_parameters()),
self.optimized_param_id,
reset_state=reset_optimizer_state,
remove_params=remove_params,
)

def check_model_and_optimizer(self, reset_optimizer_state=False, **kwargs):
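A hedged usage sketch of the updated make_optimizer hook (not from the commit), assuming an Avalanche Naive strategy with default constructor arguments and a toy model; the extra module name new_head is illustrative only.

import torch
from torch.optim import SGD

from avalanche.training import Naive

model = torch.nn.Linear(10, 2)
strategy = Naive(
    model=model,
    optimizer=SGD(model.parameters(), lr=0.01),
    criterion=torch.nn.CrossEntropyLoss(),
)

# Grow the model, then refresh the optimizer's (name, param) correspondence.
model.new_head = torch.nn.Linear(10, 2)  # illustrative extra module
strategy.make_optimizer()                # new parameters are added

# Shrink the model: stale parameters are kept by default (safer),
# and only dropped when remove_params=True is passed explicitly.
del model.new_head
strategy.make_optimizer(remove_params=False)  # new default: keep them
strategy.make_optimizer(remove_params=True)   # opt in to removal
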
77 changes: 61 additions & 16 deletions tests/test_optimizer.py
@@ -116,7 +116,9 @@ def test_optimizer_update(self):
# Here we cannot know what parameter group but there is only one so it should work
new_parameters = {"new_param": p_new}
new_parameters.update(dict(model.named_parameters()))
optimized = update_optimizer(optimizer, new_parameters, {"old_param": p})
optimized = update_optimizer(
optimizer, new_parameters, {"old_param": p}, remove_params=True
)
self.assertTrue("new_param" in optimized)
self.assertFalse("old_param" in optimized)
self.assertTrue(self._is_param_in_optimizer(p_new, strategy.optimizer))
@@ -140,7 +142,7 @@ def test_optimizers(self):
self._test_optimizer(strategy)
self._test_optimizer_state(strategy)

def test_optimizer_groups_clf(self):
def test_optimizer_groups_clf_til(self):
model, criterion, benchmark = self.init_scenario(multi_task=True)

g1 = []
@@ -177,7 +179,7 @@ def test_optimizer_groups_clf(self):
self._is_param_in_optimizer_group(p, strategy.optimizer), 1
)

def test_optimizer_groups_rename(self):
def test_optimizer_groups_clf_cil(self):
model, criterion, benchmark = self.init_scenario(multi_task=False)

g1 = []
@@ -200,22 +202,46 @@
train_epochs=2,
)

experience = benchmark.train_stream[0]
for experience in benchmark.train_stream:
strategy.train(experience)

print(experience.classes_in_this_experience)
for n, p in model.named_parameters():
assert self._is_param_in_optimizer(p, strategy.optimizer)
if "classifier" in n:
self.assertEqual(
self._is_param_in_optimizer_group(p, strategy.optimizer), 0
)
else:
self.assertEqual(
self._is_param_in_optimizer_group(p, strategy.optimizer), 1
)

strategy.train(experience)
def test_optimizer_groups_rename(self):
model, criterion, benchmark = self.init_scenario(multi_task=False)

experience = benchmark.train_stream[1]
g1 = []
g2 = []
for n, p in model.named_parameters():
if "classifier" in n:
g1.append(p)
else:
g2.append(p)

print(experience.classes_in_this_experience)
optimizer = SGD([{"params": g1, "lr": 0.1}, {"params": g2, "lr": 0.05}])

# Here I do not get an error but all groups switch to 1 for some unknown reason
strategy.model.new_module = torch.nn.Linear(10, 10)
strategy.model = TorchWrapper(strategy.model)
strategy = Naive(
model=model,
optimizer=optimizer,
criterion=criterion,
train_mb_size=64,
device=self.device,
eval_mb_size=50,
train_epochs=2,
)

strategy.train(experience)
strategy.make_optimizer()

# Check parameter groups
for n, p in model.named_parameters():
assert self._is_param_in_optimizer(p, strategy.optimizer)
if "classifier" in n:
Expand All @@ -227,6 +253,22 @@ def test_optimizer_groups_rename(self):
self._is_param_in_optimizer_group(p, strategy.optimizer), 1
)

# Rename parameters
strategy.model = TorchWrapper(strategy.model)

strategy.make_optimizer()

# Check parameter groups are still the same
for n, p in model.named_parameters():
assert self._is_param_in_optimizer(p, strategy.optimizer)
if "classifier" in n:
self.assertEqual(
self._is_param_in_optimizer_group(p, strategy.optimizer), 0
)
else:
self.assertEqual(
self._is_param_in_optimizer_group(p, strategy.optimizer), 1
)

# Needs torch 2.0 ?
def test_checkpointing(self):
@@ -295,7 +337,10 @@ def _test_optimizer(self, strategy):
# Remove a parameter
del strategy.model.new_module

strategy.make_optimizer()
strategy.make_optimizer(remove_params=False)
self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer))

strategy.make_optimizer(remove_params=True)
self.assertFalse(self._is_param_in_optimizer(param1, strategy.optimizer))

def _test_optimizer_state(self, strategy):
@@ -307,7 +352,7 @@ def _test_optimizer_state(self, strategy):
strategy.model.add_module("new_module1", module1)
strategy.model.add_module("new_module2", module2)

strategy.make_optimizer()
strategy.make_optimizer(remove_params=True)

self.assertTrue(self._is_param_in_optimizer(param1, strategy.optimizer))
self.assertTrue(self._is_param_in_optimizer(param2, strategy.optimizer))
@@ -322,7 +367,7 @@ def _test_optimizer_state(self, strategy):
# Remove one module
del strategy.model.new_module1

strategy.make_optimizer()
strategy.make_optimizer(remove_params=True)

# Make an operation
self._optimizer_op(strategy.optimizer, module1.weight + module2.weight)
Expand All @@ -333,7 +378,7 @@ def _test_optimizer_state(self, strategy):

# Change one module size
strategy.model.new_module2 = torch.nn.Linear(10, 5)
strategy.make_optimizer()
strategy.make_optimizer(remove_params=True)

# Make an operation
self._optimizer_op(strategy.optimizer, module1.weight + module2.weight)
