Description
📚 Documentation
Hi,
I am switching a project from automatic to manual optimization, and with bfloat16 the manual version fails to converge. The documentation does not explain how to combine the two either: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
Here is my training_step function with manual optimization:
```python
def training_step(self, batch, batch_idx) -> torch.Tensor:
    xa, xt, *_ = batch

    with torch.no_grad():
        ya_ema, za_ema, qa_ema = self.audio_ema(xa)  # q is after predictor
        yt_ema, zt_ema, qt_ema = self.text_ema(xt)

    with torch.autocast(
        device_type="cuda",
        dtype=torch.bfloat16,
        enabled=self.trainer.precision in ("16-mixed", "bf16-mixed"),
    ):
        # encode text and audio
        ya, za, qa = self.audio_encoder(xa)
        yt, zt, qt = self.text_encoder(xt)

        # compute MSE between q and z_ema for both modalities
        loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)
        total_loss = loss_dict["total_loss"]

    self.manual_backward(total_loss)

    # log metrics
    self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})

    # manual optimization
    optimizer = self.optimizers()
    scheduler = self.lr_schedulers()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()

    return total_loss
```
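For comparison, here is a minimal plain-PyTorch sketch of one manual bf16 step. It follows the PyTorch AMP guidance that autocast should wrap only the forward pass and the loss computation, and that backward should be called after leaving the autocast region. No GradScaler is used, since bfloat16 has the same exponent range as float32. It runs on CPU (`device_type="cpu"`) so that it is self-contained. The function name `manual_bf16_step`, the toy `nn.Linear` model and the MSE loss are illustrative assumptions, not the project's code. In a LightningModule, `self.manual_backward(loss)` would take the place of `loss.backward()`.

```python
import torch
from torch import nn


def manual_bf16_step(model, optimizer, scheduler, x, target):
    """One manual optimization step with bf16 autocast (hypothetical helper)."""
    # Forward pass and loss under autocast: eligible ops such as linear run in bfloat16.
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        pred = model(x)
        # Cast the prediction back to float32 so the loss itself is computed in float32.
        loss = nn.functional.mse_loss(pred.float(), target)
    # Backward is called outside the autocast region, as the PyTorch AMP docs recommend.
    loss.backward()
    optimizer.step()
    optimizer.zero_grad(set_to_none=True)
    scheduler.step()
    return loss.detach()
```

The parameters themselves stay in float32 throughout. Autocast only changes the dtype of the intermediate computations.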
Here is the training_step I used with automatic optimization:
```python
def training_step(self, batch, batch_idx) -> torch.Tensor:
    xa, xt, *_ = batch

    # encode text and audio
    ya, za, qa = self.audio_encoder(xa)
    yt, zt, qt = self.text_encoder(xt)

    with torch.no_grad():
        ya_ema, za_ema, qa_ema = self.audio_ema(xa)  # q is after predictor
        yt_ema, zt_ema, qt_ema = self.text_ema(xt)

    # compute MSE between q and z_ema for both modalities
    loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)

    # log metrics
    self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})

    return loss_dict["total_loss"]
```
With precision=32, both versions produce loss curves that overlap exactly. With bf16-mixed, however, the manual-optimization version diverges.
I have seen reports of problems when combining autocast and torch.no_grad. In my case, however, the text_ema and audio_ema networks are updated as a moving average of text_encoder and audio_encoder, so they never require gradients.
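The EMA update itself is not shown above. The following minimal sketch assumes a standard parameter-wise moving average. The helper names `make_ema` and `ema_update` and the `decay` value are hypothetical, not the project's code. In this setup the target network is a frozen copy whose parameters have `requires_grad=False` and are only modified in place under `torch.no_grad()`:

```python
import copy

import torch
from torch import nn


def make_ema(model: nn.Module) -> nn.Module:
    """Create a frozen copy of `model` to act as the EMA (target) network."""
    ema = copy.deepcopy(model)
    ema.requires_grad_(False)  # EMA weights are never trained by backprop
    return ema


@torch.no_grad()
def ema_update(ema: nn.Module, model: nn.Module, decay: float = 0.99) -> None:
    """In place, for every parameter: ema_p <- decay * ema_p + (1 - decay) * p."""
    for ema_p, p in zip(ema.parameters(), model.parameters()):
        ema_p.mul_(decay).add_(p.detach(), alpha=1.0 - decay)
```

Buffers (for example BatchNorm running statistics) are not averaged in this sketch. They would need separate handling if the encoders contain any.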
Do you know how to properly use bf16-mixed with Lightning's manual optimization?