Skip to content

Commit 8d0de9a

Browse files
committed
add gradient clipping
1 parent 69b1664 commit 8d0de9a

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

rin_pytorch/rin_pytorch.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -899,6 +899,7 @@ def __init__(
899899
augment_horizontal_flip = True,
900900
train_lr = 1e-4,
901901
train_num_steps = 100000,
902+
max_grad_norm = 1.,
902903
ema_update_every = 10,
903904
ema_decay = 0.995,
904905
betas = (0.9, 0.99),
@@ -926,6 +927,7 @@ def __init__(
926927

927928
self.batch_size = train_batch_size
928929
self.gradient_accumulate_every = gradient_accumulate_every
930+
self.max_grad_norm = max_grad_norm
929931

930932
self.train_num_steps = train_num_steps
931933
self.image_size = diffusion_model.image_size
@@ -1013,6 +1015,7 @@ def train(self):
10131015
pbar.set_description(f'loss: {total_loss:.4f}')
10141016

10151017
accelerator.wait_for_everyone()
1018+
accelerator.clip_grad_norm_(self.model.parameters(), self.max_grad_norm)
10161019

10171020
self.opt.step()
10181021
self.opt.zero_grad()

setup.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'RIN-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.7.6',
6+
version = '0.7.7',
77
license='MIT',
88
description = 'RIN - Recurrent Interface Network - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)