
Commit

modified update_optimizer with new logic
AlbinSou committed Mar 8, 2024
1 parent 181fa0e commit 068fae5
Showing 2 changed files with 81 additions and 531 deletions.
137 changes: 81 additions & 56 deletions avalanche/models/dynamic_optimizers.py
@@ -16,6 +16,8 @@
"""
from collections import defaultdict

import numpy as np

colors = {
"END": "\033[0m",
0: "\033[32m",
@@ -27,24 +29,65 @@
colors[None] = colors["END"]


def map_optimized_params(optimizer, new_params):
current_parameters_mapping = defaultdict(dict)
not_found = []
for n, p in new_params.items():
def map_optimized_params(optimizer, parameters, old_params=None):
"""
Establishes a mapping between a list of named parameters and the parameters
that are in the optimizer, additionally,
returns the list of
changed_parameters
new_parameters
removed_parameters: List of indexes of optimizer parameters that are not found in the new parameters
"""

group_mapping = defaultdict(dict)
new_parameters = []

found_indexes = []
changed_parameters = []
for group in optimizer.param_groups:
params = group["params"]
found_indexes.append(np.zeros(len(params)))

for n, p in parameters.items():
g = None
# Find param in optimizer
found = False

if old_params is not None and n in old_params:
search_id = id(old_params[n])
else:
search_id = id(p)

for group_idx, group in enumerate(optimizer.param_groups):
params = group["params"]
for po in params:
if id(po) == id(p):
for param_idx, po in enumerate(params):
if id(po) == search_id:
g = group_idx
found = True
# Update found indexes
assert found_indexes[group_idx][param_idx] == 0
found_indexes[group_idx][param_idx] = 1
break
# Stop scanning the remaining groups once the parameter is found,
# so that group_idx/param_idx still point at the match below
if found:
break

if not found:
not_found.append(n)
current_parameters_mapping[n] = g
return current_parameters_mapping, not_found
new_parameters.append(n)

if search_id != id(p):
if found:
changed_parameters.append((n, group_idx, param_idx))

group_mapping[n] = g

not_found_in_parameters = [np.where(arr == 0)[0] for arr in found_indexes]

return (
group_mapping,
changed_parameters,
new_parameters,
not_found_in_parameters,
)
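
As a quick illustration of the four return values, here is a minimal, hypothetical usage sketch; the toy nn.Linear model and SGD optimizer below are illustrative and not part of this commit:

# Hypothetical usage sketch (not part of this commit): how the return values
# look for a toy model whose weight tensor has been replaced.
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Snapshot of the parameters the optimizer currently tracks
optimized_params = dict(model.named_parameters())

# Simulate a growing layer: "weight" now points to a new tensor object
model.weight = nn.Parameter(torch.zeros(3, 4))
new_params = dict(model.named_parameters())

group_mapping, changed, new, missing = map_optimized_params(
    optimizer, new_params, old_params=optimized_params
)
# group_mapping -> {"weight": 0, "bias": 0}
# changed       -> [("weight", 0, 0)]   same name, different tensor id
# new           -> []                   no brand-new parameter names
# missing       -> [array([], dtype=int64)]  every optimizer param was matched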


def build_tree_from_name_groups(name_groups):
@@ -95,18 +138,17 @@ def print_group_information(node, prefix=""):
print_group_information(child_node, prefix + " ")


class OptimizedParameterStructure:
class ParameterGroupStructure:
"""
Structure holding a tree where each node is a pytorch module
the tree is linked to the current model parameters
Structure used for the resolution of unknown parameter groups.
It stores parameter names as a tree and propagates parameter groups
upwards from leaves at the same hierarchical level
"""

def __init__(self, new_params, optimizer, verbose=False):
def __init__(self, name_groups, verbose=False):
# Here we rebuild the tree
name_groups, not_found = map_optimized_params(optimizer, new_params)
self.root, self.node_mapping = build_tree_from_name_groups(name_groups)
if verbose:
print(f"Not found in optimizer: {not_found}")
print_group_information(self.root)

def __getitem__(self, name):
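
To illustrate the intended group resolution, a hypothetical sketch follows. The only behaviour assumed beyond this diff is that indexing the structure by a parameter name returns a node exposing single_group, as it is used in update_optimizer below:

# Hypothetical sketch (names illustrative): parameters already mapped to
# groups, plus one new parameter whose group is still unknown (None).
name_groups = {
    "backbone.conv1.weight": 0,
    "backbone.conv1.bias": 0,
    "classifier.fc.weight": 1,
    "classifier.fc.bias": 1,
    "classifier.new_head.weight": None,  # not in the optimizer yet
}
structure = ParameterGroupStructure(name_groups, verbose=True)

# The unknown leaf should inherit the group propagated from its neighbours
# under "classifier", so the new head can be appended to param group 1.
print(structure["classifier.new_head.weight"].single_group)  # expected: 1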
@@ -224,59 +266,42 @@ def update_optimizer(optimizer, new_params, optimized_params, reset_state=False)
Defaults to False.
:return: Dict (name, param) of optimized parameters
"""
not_in_new, in_both, not_in_old = compare_keys(optimized_params, new_params)
(
group_mapping,
changed_parameters,
new_parameters,
not_found_in_parameters,
) = map_optimized_params(optimizer, new_params, old_params=optimized_params)

# Change reference to already existing parameters
# e.g. a growing IncrementalClassifier
for key in in_both:
old_p_hash = optimized_params[key]
new_p = new_params[key]
for name, group_idx, param_idx in changed_parameters:
group = optimizer.param_groups[group_idx]
old_p = optimized_params[name]
new_p = new_params[name]
# Look for old parameter id in current optimizer
found = False
for group in optimizer.param_groups:
for i, curr_p in enumerate(group["params"]):
if id(curr_p) == id(old_p_hash):
found = True
if id(curr_p) != id(new_p):
group["params"][i] = new_p
optimized_params[key] = new_p
optimizer.state[new_p] = {}
break
if not found:
raise Exception(
f"Parameter {key} expected but " "not found in the optimizer"
)
group["params"][param_idx] = new_p
if old_p in optimizer.state:
optimizer.state.pop(old_p)
optimizer.state[new_p] = {}

# Remove parameters that are not here anymore
# This should not happen in most use cases
keys_to_remove = []
for key in not_in_new:
old_p_hash = optimized_params[key]
found = False
for i, group in enumerate(optimizer.param_groups):
keys_to_remove.append([])
for j, curr_p in enumerate(group["params"]):
if id(curr_p) == id(old_p_hash):
found = True
keys_to_remove[i].append((j, curr_p))
optimized_params.pop(key)
break
if not found:
raise Exception(
f"Parameter {key} expected but " "not found in the optimizer"
)

for i, idx_list in enumerate(keys_to_remove):
for j, p in sorted(idx_list, key=lambda x: x[0], reverse=True):
del optimizer.param_groups[i]["params"][j]
for group_idx, idx_list in enumerate(not_found_in_parameters):
for j in sorted(idx_list, key=lambda x: x, reverse=True):
p = optimizer.param_groups[group_idx]["params"][j]
optimizer.param_groups[group_idx]["params"].pop(j)
if p in optimizer.state:
optimizer.state.pop(p)
del p

# Add newly added parameters (e.g. Multitask, PNN),
# resolving their target param group from the parameter name tree

param_structure = OptimizedParameterStructure(new_params, optimizer, verbose=True)
param_structure = ParameterGroupStructure(group_mapping, verbose=True)

for key in not_in_old:
# New parameters
for key in new_parameters:
new_p = new_params[key]
group = param_structure[key].single_group
optimizer.param_groups[group]["params"].append(new_p)
@@ -286,7 +311,7 @@ def update_optimizer(optimizer, new_params, optimized_params, reset_state=False)
if reset_state:
optimizer.state = defaultdict(dict)

return optimized_params
return new_params
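
For context, a minimal, hypothetical driver loop showing how update_optimizer is meant to be called between experiences; the growing-head scenario below is illustrative only:

# Hypothetical driver (illustrative only): keep the optimizer in sync with a
# model whose classifier head is replaced by a larger one between experiences.
import torch
import torch.nn as nn

model = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 2))
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
optimized_params = dict(model.named_parameters())

for num_classes in (2, 5, 10):  # the head grows at each "experience"
    model[1] = nn.Linear(8, num_classes)
    optimized_params = update_optimizer(
        optimizer,
        new_params=dict(model.named_parameters()),
        optimized_params=optimized_params,
        reset_state=False,
    )
    # ... training on the current experience would go here ...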


def add_new_params_to_optimizer(optimizer, new_params):
