
EMA training for PEFT LoRAs #9998


Description

@bghira (Contributor)

Is your feature request related to a problem? Please describe.

EMAModel in Diffusers is not plumbed for interacting well with PEFT LoRAs, which leaves users to implement their own.

The idea has been thrown around that LoRA does not benefit from EMA, and that research papers had shown this. However, with my curiosity piqued, it took a bit of effort, but I managed to make it work.

Here is a pull request for SimpleTuner where I've updated my EMAModel implementation to behave more like an nn.Module, so that the EMAModel can be passed into more processes without "funny business".
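
Roughly, the idea looks like this (just a sketch, not the actual SimpleTuner code; the class name and update rule below are made up for illustration). Registering the shadow weights as buffers on an nn.Module gives you .to(), .state_dict() and device handling for free:

    import torch


    class EMAModuleWrapper(torch.nn.Module):
        """Illustrative EMA holder: shadow weights live as buffers so the EMA
        copy moves devices and serialises like any other nn.Module."""

        def __init__(self, parameters, decay: float = 0.999):
            super().__init__()
            self.decay = decay
            for i, p in enumerate(parameters):
                # buffers follow .to() moves and show up in state_dict()
                self.register_buffer(f"shadow_{i}", p.detach().clone())

        @torch.no_grad()
        def step(self, parameters):
            for i, p in enumerate(parameters):
                shadow = getattr(self, f"shadow_{i}")
                new_val = p.detach().to(device=shadow.device, dtype=shadow.dtype)
                # shadow = decay * shadow + (1 - decay) * new_val
                shadow.lerp_(new_val, 1.0 - self.decay)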

This spot in the save hooks was hardcoded to take the class name, following the Diffusers convention, but we could take a more dynamic approach, perhaps in a training_utils helper method.
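
For example, a hypothetical helper along these lines (illustrative only, not an existing Diffusers API; the function name is made up) could derive the save_lora_weights() keyword from the model class instead of hardcoding it:

    from diffusers.utils import convert_state_dict_to_diffusers
    from peft.utils import get_peft_model_state_dict


    def peft_lora_layers_kwargs(model) -> dict:
        # e.g. a UNet2DConditionModel maps to "unet_lora_layers", while
        # DiT-style models map to "transformer_lora_layers"
        prefix = "unet" if "unet" in model.__class__.__name__.lower() else "transformer"
        return {
            f"{prefix}_lora_layers": convert_state_dict_to_diffusers(
                get_peft_model_state_dict(model)
            )
        }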

Just a bit further down, at L208 in the save hooks, I did something I'm not really 100% happy with, though users were:

  • For my own trainer's convenience, I save a copy of the EMA model in a simple loadable state_dict format so that I can load it during resume (a sketch of this follows the list below).
  • Additionally, we save a 2nd copy of the EMA in the PEFT LoRA format so that it can be loaded by pipelines.
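
For the first copy, the "simple loadable state_dict format" can be as plain as a torch checkpoint of EMAModel's state_dict(); the helper and file names below are placeholders, just to sketch the shape of it:

    import os

    import torch


    def save_ema_for_resume(ema_model, output_dir: str) -> None:
        ema_dir = os.path.join(output_dir, "ema")
        os.makedirs(ema_dir, exist_ok=True)
        # EMAModel exposes state_dict(), so a plain torch file is enough for resume
        torch.save(ema_model.state_dict(), os.path.join(ema_dir, "ema_model.pt"))


    def load_ema_for_resume(ema_model, resume_dir: str) -> None:
        state = torch.load(
            os.path.join(resume_dir, "ema", "ema_model.pt"), map_location="cpu"
        )
        ema_model.load_state_dict(state)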

The tricky part is the 2nd copy of the EMA model that gets saved in the standard LoRA format:

        if self.args.use_ema:
            # we'll temporarily overwrite the LoRA parameters with the EMA parameters so we can save them.
            logger.info("Saving EMA model to disk.")
            trainable_parameters = [
                p
                for p in self._primary_model().parameters()
                if p.requires_grad
            ]
            self.ema_model.store(trainable_parameters)
            self.ema_model.copy_to(trainable_parameters)
            if self.transformer is not None:
                self.pipeline_class.save_lora_weights(
                    os.path.join(output_dir, "ema"),
                    transformer_lora_layers=convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(self._primary_model())
                    ),
                )
            elif self.unet is not None:
                self.pipeline_class.save_lora_weights(
                    os.path.join(output_dir, "ema"),
                    unet_lora_layers=convert_state_dict_to_diffusers(
                        get_peft_model_state_dict(self._primary_model())
                    ),
                )
            self.ema_model.restore(trainable_parameters)

This could probably be done more nicely with a trainable_parameters() method on the model classes where appropriate.
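
For example, a small mixin along these lines (purely illustrative, not an existing Diffusers API) would let the save hook ask the model for its own trainable parameters:

    class TrainableParametersMixin:
        """Hypothetical mixin, meant to be combined with an nn.Module subclass."""

        def trainable_parameters(self):
            # only the LoRA/adapter weights require grad during PEFT training
            return [p for p in self.parameters() if p.requires_grad]

The save hook above would then shrink to ema_model.store(model.trainable_parameters()) and so on.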

I guess the state-dict-conversion decorations are required for now, but it would be ideal if this could be simplified so that newcomers do not have to dig into and understand so many moving pieces.

For quantised training, we have to quantise the EMA model in the same way the trained model was quantised.
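
As an illustration only, assuming optimum-quanto is the quantisation backend (swap in whichever backend the run actually uses; the function name is made up), that might look roughly like:

    from optimum.quanto import freeze, qint8, quantize


    def quantise_ema_like_base(ema_module) -> None:
        # apply the same weight quantisation to the EMA copy that the trained
        # model received, so swapping EMA weights in at inference is consistent
        quantize(ema_module, weights=qint8)
        freeze(ema_module)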

The validations were a bit of a pain, but I wanted to make it possible to load and unload the EMA weights repeatedly during the run, so that each prompt can be validated against both the checkpoint and the EMA weights. Here is my method for enabling (and, just below, disabling) the EMA model at inference time.
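
A simplified sketch of that pattern (not the exact method referenced above) using EMAModel's store()/copy_to()/restore():

    def enable_ema_for_inference(model, ema_model):
        trainable = [p for p in model.parameters() if p.requires_grad]
        ema_model.store(trainable)    # stash the current trained LoRA weights
        ema_model.copy_to(trainable)  # overwrite them with the EMA shadow weights


    def disable_ema_after_inference(model, ema_model):
        trainable = [p for p in model.parameters() if p.requires_grad]
        ema_model.restore(trainable)  # put the trained LoRA weights back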

However, the effect is really nice; here you see the starting SD 3.5M on the left, the trained LoRA in the centre, and EMA on the right.

[five comparison images omitted: base SD 3.5 Medium (left), trained LoRA (centre), EMA (right)]

These samples are from 60,000 steps of training a rank-128 PEFT LoRA on all of the attention layers of SD 3.5 Medium, using ~120,000 high-quality photos.

While it's not a cure-all for training problems, the EMA model has outperformed the trained checkpoint throughout the entire duration of training.

It would be a good idea to someday consider including EMA for LoRA, along with related improvements for saving/loading EMA weights on adapters, so that users can get better results from the training examples. I don't think the validation changes are strictly needed, but they can be done in a non-intrusive way, more nicely than I have done here.

Activity

bghira (Contributor, Author) commented on Nov 22, 2024

cc @linoytsaban @sayakpaul for your interest perhaps

sayakpaul (Member) commented on Nov 23, 2024

Thanks for the interesting thread.

I think for now we can refer users to SimpleTuner for this. Also, it's perhaps subjective, but I don't necessarily find the EMA results to be better than the ones without.

bghira (Contributor, Author) commented on Nov 23, 2024

Yeah, the centre's outputs are actually entirely incoherent; I don't know why that would be preferred.

double8fun commented on Jul 8, 2025

The idea has been thrown around that LoRA did not benefit from EMA, and research papers had shown this

Hi, thanks for your great work. May I ask if there's a specific paper that mentions this idea? I did some searching but couldn't find one.

bghira (Contributor, Author) commented on Jul 8, 2025

I think the most extensive exploration into EMA is the post-hoc EMA work from Tero Karras et al. at NVIDIA.

double8fun commented on Jul 8, 2025

I think the most extensive exploration into EMA is the post-hoc EMA work from Tero Karras et al. at NVIDIA.

Thanks! I will check it out.
