added adapt function, auto_adapt argument and recursive adapt across pytorch modules
AlbinSou committed Feb 16, 2024
1 parent 071d813 commit 3a13735
Showing 1 changed file with 71 additions and 15 deletions.
avalanche/models/dynamic_modules.py (71 additions, 15 deletions)
@@ -12,12 +12,39 @@
 to allow architectural modifications (multi-head classifiers, progressive
 networks, ...).
 """
+from typing import List, Optional
+
 import torch
 from torch.nn import Module
-from typing import Optional
 
-from avalanche.benchmarks.utils.flat_data import ConstantSequence
 from avalanche.benchmarks.scenarios import CLExperience
+from avalanche.benchmarks.utils.flat_data import ConstantSequence
+
+
+def adapt_recursive(
+    module: Module,
+    experience: CLExperience,
+    _visited: Optional[List] = None,
+    _initial_call: bool = True,
+):
+    if _visited is None:
+        _visited = []
+
+    if module in _visited:
+        return
+
+    _visited.append(module)
+
+    if isinstance(module, DynamicModule):
+        if (not _initial_call) and (not module._auto_adapt):
+            # Some modules don't want to be auto-adapted
+            return
+        else:
+            module.adaptation(experience)
+
+    # Iterate over children
+    for name, submodule in module.named_children():
+        adapt_recursive(submodule, experience, _visited=_visited, _initial_call=False)
+
+
 class DynamicModule(Module):
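
A minimal usage sketch of the new traversal (not part of this commit; the toy model below and the `experience` object, a `CLExperience` taken from an Avalanche benchmark, are assumptions for illustration):

import torch.nn as nn

# Toy model: a plain torch backbone followed by a dynamic head.
model = nn.Sequential(
    nn.Linear(32, 64),                      # plain module: only recursed into
    IncrementalClassifier(in_features=64),  # DynamicModule: .adaptation() is called
)

# Visits every submodule exactly once (the _visited list guards against
# shared or cyclic references) and adapts each DynamicModule it reaches,
# unless that module opted out with auto_adapt=False.
adapt_recursive(model, experience)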
@@ -29,6 +56,22 @@ class DynamicModule(Module):
     `model_adaptation`, which adapts the model given the current experience.
     """
 
+    def __init__(self, auto_adapt=True):
+        """
+        :param auto_adapt: If True, the module will be adapted in the
+            recursive adaptation loop; otherwise, it is adapted by a module
+            in charge (e.g. IncrementalClassifier inside MultiHeadClassifier).
+        """
+        super().__init__()
+        self._auto_adapt = auto_adapt
+
+    def adapt(self, experience):
+        """
+        Calls self.adaptation recursively across
+        the hierarchy of module children.
+        """
+        adapt_recursive(self, experience)
+
     def adaptation(self, experience: CLExperience):
         """Adapt the module (freeze units, add units...) using the current
         data. Optimizers must be updated after the model adaptation.
@@ -43,6 +86,10 @@ def adaptation(self, experience: CLExperience):
         require the model's adaptation, such as the discovery of new
         classes or tasks.
 
+        .. warning::
+            This function only adapts the current module; to recursively
+            adapt all submodules, use self.adapt() instead.
+
         :param experience: the current experience.
         :return:
         """
@@ -97,13 +144,13 @@ class MultiTaskModule(DynamicModule):
     the output is computed in parallel for each task.
     """
 
-    def __init__(self):
-        super().__init__()
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
         self.max_class_label = 0
         self.known_train_tasks_labels = set()
         """ Set of task labels encountered up to now. """
 
-    def adaptation(self, experience: CLExperience):
+    def adaptation(self, experience: CLExperience, adapt_submodules=True):
         """Adapt the module (freeze units, add units...) using the current
         data. Optimizers must be updated after the model adaptation.
@@ -122,10 +169,7 @@ def adaptation(self, experience: CLExperience):
         """
         curr_classes = experience.classes_in_this_experience
         self.max_class_label = max(self.max_class_label, max(curr_classes) + 1)
-        if self.training:
-            self.train_adaptation(experience)
-        else:
-            self.eval_adaptation(experience)
+        super().adaptation(experience)
 
     def eval_adaptation(self, experience: CLExperience):
         pass
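
The train/eval dispatch deleted here is now expected to happen in `DynamicModule.adaptation`, whose body this diff does not show; from the `super().adaptation(experience)` call above it presumably looks roughly like the sketch below (an inference, not a line of this commit). Note also that the new `adapt_submodules` argument is accepted but never used in the hunk shown.

# Presumed shape of DynamicModule.adaptation after this change:
def adaptation(self, experience: CLExperience):
    if self.training:
        self.train_adaptation(experience)
    else:
        self.eval_adaptation(experience)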
@@ -207,6 +251,7 @@ def __init__(
         initial_out_features=2,
         masking=True,
         mask_value=-1000,
+        **kwargs,
     ):
         """
         :param in_features: number of input features.
@@ -215,7 +260,7 @@
         :param masking: whether unused units should be masked (default=True).
         :param mask_value: the value used for masked units (default=-1000).
         """
-        super().__init__()
+        super().__init__(**kwargs)
         self.masking = masking
         self.mask_value = mask_value
 
@@ -224,7 +269,7 @@ def __init__(
         self.register_buffer("active_units", au_init)
 
     @torch.no_grad()
-    def adaptation(self, experience: CLExperience):
+    def train_adaptation(self, experience: CLExperience):
         """If `dataset` contains unseen classes the classifier is expanded.
 
         :param experience: data from the current experience.
@@ -256,6 +301,9 @@ def adaptation(self, experience: CLExperience):
         self.classifier.weight[:old_nclasses] = old_w
         self.classifier.bias[:old_nclasses] = old_b
 
+    def eval_adaptation(self, experience):
+        self.train_adaptation(experience)
+
     def forward(self, x, **kwargs):
         """compute the output given the input `x`. This module does not use
         the task label.
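
A practical consequence of the new `eval_adaptation` delegating to `train_adaptation`: the classifier now also grows when adapted in eval mode. A sketch (the two experience objects are assumed to come from an Avalanche class-incremental benchmark):

clf = IncrementalClassifier(in_features=64, initial_out_features=2)

clf.train()
clf.adapt(train_exp_with_classes_0_4)  # head expands to 5 output units

clf.eval()
clf.adapt(eval_exp_with_classes_5_9)   # eval_adaptation also expands the head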
@@ -321,7 +369,10 @@ def __init__(
         # masking in IncrementalClassifier is unaware of task labels
         # so we do masking here instead.
         first_head = IncrementalClassifier(
-            self.in_features, self.starting_out_features, masking=False
+            self.in_features,
+            self.starting_out_features,
+            masking=False,
+            auto_adapt=False,
         )
         self.classifiers["0"] = first_head
         self.max_class_label = max(self.max_class_label, initial_out_features)
@@ -345,13 +396,12 @@ def task_masks(self):
             res[tid] = getattr(self, f"active_units_T{tid}").to(torch.bool)
         return res
 
-    def adaptation(self, experience: CLExperience):
+    def train_adaptation(self, experience: CLExperience):
         """If `dataset` contains new tasks, a new head is initialized.
 
         :param experience: data from the current experience.
         :return:
         """
-        super().adaptation(experience)
         device = self._adaptation_device
         curr_classes = experience.classes_in_this_experience
         task_labels = experience.task_labels
@@ -364,7 +414,10 @@ def adaptation(self, experience: CLExperience):
             # head adaptation
             if tid not in self.classifiers:  # create new head
                 new_head = IncrementalClassifier(
-                    self.in_features, self.starting_out_features, masking=False
+                    self.in_features,
+                    self.starting_out_features,
+                    masking=False,
+                    auto_adapt=False,
                 ).to(device)
                 self.classifiers[tid] = new_head
 
@@ -404,6 +457,9 @@ def adaptation(self, experience: CLExperience):
             if self.training:
                 self._buffers[au_name][curr_classes] = 1
 
+    def eval_adaptation(self, experience):
+        self.train_adaptation(experience)
+
     def forward_single_task(self, x, task_label):
         """compute the output given the input `x`. This module uses the task
         label to activate the correct head.
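
Taken together with the `auto_adapt=False` heads above, the adaptation flow for a multi-head model looks like the following sketch (the backbone, the composition, and the `experience` object are assumptions; the point is that only the MultiHeadClassifier is in charge of adapting its per-task heads):

import torch.nn as nn

model = nn.Sequential(
    nn.Linear(32, 64),
    MultiHeadClassifier(in_features=64),
)

model.train()
adapt_recursive(model, experience)  # experience: a CLExperience, assumed in scope
# - nn.Linear is traversed but not adapted (it is not a DynamicModule)
# - MultiHeadClassifier.adaptation() runs and creates or grows a head for
#   each task label found in `experience`
# - the per-task IncrementalClassifier heads are skipped by the recursion,
#   because they were built with auto_adapt=False and are adapted by their
#   parent instead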