Proper way to use mixed precision with manual optimization #20970

@aRI0U

Description

📚 Documentation

Hi,
I am switching a project from automatic to manual optimization, and the manual version fails to converge with bfloat16. The docs don't say anything about how to handle precision with manual optimization either: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html
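
For reference, the only manual-optimization example I can find on that page is roughly this (paraphrased from the linked docs; it does not mention precision at all):

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()
        opt.zero_grad()
        loss = self.compute_loss(batch)  # placeholder loss from the docs example
        self.manual_backward(loss)
        opt.step()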

Here is my training_step function with manual optimization:

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        xa, xt, *_ = batch

        with torch.no_grad():
            ya_ema, za_ema, qa_ema = self.audio_ema(xa)  # q is after predictor
            yt_ema, zt_ema, qt_ema = self.text_ema(xt)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=self.trainer.precision in ("16-mixed", "bf16-mixed")):
            # encode text and audio
            ya, za, qa = self.audio_encoder(xa)
            yt, zt, qt = self.text_encoder(xt)

            # compute MSE between q and z_ema for both modalities
            loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)
            total_loss = loss_dict["total_loss"]

        self.manual_backward(total_loss)

        # log metrics
        self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})
        
        # manual optimization
        optimizer = self.optimizers()
        scheduler = self.lr_schedulers()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)

        scheduler.step()

        return total_loss

For comparison, here is the training_step I had when using automatic optimization:

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        xa, xt, *_ = batch

        # encode text and audio
        ya, za, qa = self.audio_encoder(xa)
        yt, zt, qt = self.text_encoder(xt)

        with torch.no_grad():
            ya_ema, za_ema, qa_ema = self.audio_ema(xa)  # q is after predictor
            yt_ema, zt_ema, qt_ema = self.text_ema(xt)
        
        # compute MSE between q and z_ema for both modalities
        loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)

        # log metrics
        self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})

        return loss_dict["total_loss"]
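
If I understand correctly, in automatic mode Lightning effectively does something like the following around training_step (a conceptual sketch of my understanding, not actual Lightning source; with bf16-mixed there is no GradScaler, only autocast):

    # conceptual equivalent of one automatic-optimization step with bf16-mixed
    optimizer.zero_grad()
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = model.training_step(batch, batch_idx)
    loss.backward()   # no loss scaling needed for bfloat16
    optimizer.step()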

Note that with precision=32, both versions give strictly overlapping loss curves. In bf16-mixed, however, the manual-optimization version diverges.
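
One thing I could not confirm is whether Lightning already wraps training_step in an autocast context in manual mode, which would make my explicit torch.autocast block redundant. A quick check like this (just a diagnostic sketch) should tell:

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        # if Lightning's bf16-mixed plugin already provides the autocast
        # context, this prints True even without my explicit block
        print(torch.is_autocast_enabled())
        ...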

I have seen that there are sometimes issues when combining autocast and torch.no_grad, but in my case the networks text_ema and audio_ema are updated as a moving average of text_encoder and audio_encoder, so they never require gradients.
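
For context, the EMA update is essentially a parameter-wise exponential moving average, along the lines of this simplified sketch (_update_ema and tau are my own names, not Lightning API):

    @torch.no_grad()
    def _update_ema(self, ema_module, online_module, tau=0.999):
        # in-place EMA of the online encoder's weights; the EMA networks
        # are never optimized directly, so they never require gradients
        for p_ema, p in zip(ema_module.parameters(), online_module.parameters()):
            p_ema.mul_(tau).add_(p, alpha=1 - tau)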

Do you know how to properly use bf16-mixed with Lightning's manual optimization?

cc @lantiga @Borda

Labels: docs (Documentation related)