
Proper way to use mixed precision with manual optimization #20970

@aRI0U

📚 Documentation

Hi,
I am switching a project from automatic to manual optimization, and the manual version fails to converge with bfloat16. The docs don't say anything about how mixed precision should be handled in this case either: https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html

Here is my training_step function with manual optimization:

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        xa, xt, *_ = batch

        with torch.no_grad():
            ya_ema, za_ema, qa_ema = self.audio_ema(xa)  # q is after predictor
            yt_ema, zt_ema, qt_ema = self.text_ema(xt)

        with torch.autocast(device_type="cuda", dtype=torch.bfloat16, enabled=self.trainer.precision in ("16-mixed", "bf16-mixed")):
            # encode text and audio
            ya, za, qa = self.audio_encoder(xa)
            yt, zt, qt = self.text_encoder(xt)

            # compute MSE between q and z_ema for both modalities
            loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)
            total_loss = loss_dict["total_loss"]

        self.manual_backward(total_loss)

        # log metrics
        self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})
        
        # manual optimization
        optimizer = self.optimizers()
        scheduler = self.lr_schedulers()

        optimizer.step()
        optimizer.zero_grad(set_to_none=True)
            
        scheduler.step()
        
        return total_loss
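
My understanding is that the Trainer's precision plugin already wraps training_step in an autocast context when precision="bf16-mixed", which would make the explicit torch.autocast above redundant. Under that assumption, I would expect a bare version like the sketch below to be the intended pattern (just a sketch; please correct me if the explicit context is actually required):

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        # sketch: rely entirely on Lightning's precision handling,
        # assuming training_step already runs under autocast for bf16-mixed
        xa, xt, *_ = batch

        with torch.no_grad():
            ya_ema, za_ema, qa_ema = self.audio_ema(xa)
            yt_ema, zt_ema, qt_ema = self.text_ema(xt)

        ya, za, qa = self.audio_encoder(xa)
        yt, zt, qt = self.text_encoder(xt)

        loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)
        total_loss = loss_dict["total_loss"]

        optimizer = self.optimizers()
        optimizer.zero_grad(set_to_none=True)
        self.manual_backward(total_loss)  # let the precision plugin handle backward
        optimizer.step()
        self.lr_schedulers().step()

        self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})
        return total_loss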

For comparison, here is the training_step I had when using automatic optimization:

    def training_step(self, batch, batch_idx) -> torch.Tensor:
        xa, xt, *_ = batch

        # encode text and audio
        ya, za, qa = self.audio_encoder(xa)
        yt, zt, qt = self.text_encoder(xt)

        with torch.no_grad():
            ya_ema, za_ema, qa_ema = self.audio_ema(xa)  # q is after predictor
            yt_ema, zt_ema, qt_ema = self.text_ema(xt)
        
        # compute MSE between q and z_ema for both modalities
        loss_dict = self.loss_fn(qa, qt, za_ema, zt_ema)

        # log metrics
        self.log_dict({f"loss/train/{k}": v for k, v in loss_dict.items()})

        return loss_dict["total_loss"]

Note that with precision=32 the two versions give strictly overlapping loss curves. With bf16-mixed, however, the manual-optimization one diverges.

I saw that autocast and torch.no_grad can sometimes interact badly, but in my case the networks text_ema and audio_ema are updated as a moving average of text_encoder and audio_encoder, so they never require gradients.
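
(For context, the EMA update is essentially the following sketch; the _ema_update name and the 0.999 momentum are placeholders, not my exact code:)

    @torch.no_grad()
    def _ema_update(self, ema: torch.nn.Module, online: torch.nn.Module, momentum: float = 0.999) -> None:
        # placeholder sketch of the moving-average update; the EMA parameters
        # are frozen (requires_grad=False) and never enter the autograd graph
        for p_ema, p_online in zip(ema.parameters(), online.parameters()):
            p_ema.mul_(momentum).add_(p_online, alpha=1.0 - momentum)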

Do you know how to properly use bf16-mixed with Lightning's manual optimization?

cc @lantiga @Borda
