forked from ContinualAI/avalanche
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
change to private funcs, add manual optimizer change test
- Loading branch information
Showing
3 changed files
with
105 additions
and
28 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,16 +3,19 @@ | |
# Copyrights licensed under the MIT License. # | ||
# See the accompanying LICENSE file for terms. # | ||
# # | ||
# Date: 01-12-2020 # | ||
# Author(s): Andrea Cossu # | ||
# Date: 25-03-2024 # | ||
# Author(s): Albin Soutif # | ||
# E-mail: [email protected] # | ||
# Website: avalanche.continualai.org # | ||
################################################################################ | ||
|
||
""" | ||
This example trains a Multi-head model on Split MNIST with Elastich Weight | ||
This example trains a Multi-head model on Split MNIST with Elastic Weight | ||
Consolidation. Each experience has a different task label, which is used at test | ||
time to select the appropriate head. | ||
time to select the appropriate head. Additionally, it assigns different parameter groups | ||
to the classifier and the backbone, assigning lower learning rate to | ||
the backbone than to the classifier. When the multihead classifier grows, | ||
new parameters are automatically assigned to the corresponding parameter group | ||
""" | ||
|
||
import argparse | ||
|
@@ -78,7 +81,7 @@ def main(args): | |
|
||
# train and test loop | ||
for train_task in train_stream: | ||
strategy.train(train_task, num_workers=4) | ||
strategy.train(train_task, num_workers=4, verbose=True) | ||
strategy.eval(test_stream, num_workers=4) | ||
|
||
|
||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters