File tree Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Expand file tree Collapse file tree 2 files changed +4
-1
lines changed Original file line number Diff line number Diff line change @@ -899,6 +899,7 @@ def __init__(
899
899
augment_horizontal_flip = True ,
900
900
train_lr = 1e-4 ,
901
901
train_num_steps = 100000 ,
902
+ max_grad_norm = 1. ,
902
903
ema_update_every = 10 ,
903
904
ema_decay = 0.995 ,
904
905
betas = (0.9 , 0.99 ),
@@ -926,6 +927,7 @@ def __init__(
926
927
927
928
self .batch_size = train_batch_size
928
929
self .gradient_accumulate_every = gradient_accumulate_every
930
+ self .max_grad_norm = max_grad_norm
929
931
930
932
self .train_num_steps = train_num_steps
931
933
self .image_size = diffusion_model .image_size
@@ -1013,6 +1015,7 @@ def train(self):
1013
1015
pbar .set_description (f'loss: { total_loss :.4f} ' )
1014
1016
1015
1017
accelerator .wait_for_everyone ()
1018
+ accelerator .clip_grad_norm_ (self .model .parameters (), self .max_grad_norm )
1016
1019
1017
1020
self .opt .step ()
1018
1021
self .opt .zero_grad ()
Original file line number Diff line number Diff line change 3
3
setup (
4
4
name = 'RIN-pytorch' ,
5
5
packages = find_packages (exclude = []),
6
- version = '0.7.6 ' ,
6
+ version = '0.7.7 ' ,
7
7
license = 'MIT' ,
8
8
description = 'RIN - Recurrent Interface Network - Pytorch' ,
9
9
author = 'Phil Wang' ,
You can’t perform that action at this time.
0 commit comments