Conversation

@cbcase (Contributor) commented Sep 22, 2023

As mentioned in #1728, the FusedAdam optimizer ignores master_weights=True for bfloat16 parameters. This PR fixes that oversight. I have confirmed that the behavior now matches a by-hand implementation of master weights (manually copying the updated fp32 master back into the bf16 parameters) combined with vanilla torch.optim.AdamW run on the fp32 copy.
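For reference, the by-hand verification described above can be sketched roughly as follows. This is an illustrative reconstruction, not the actual test from the PR: keep an fp32 "master" copy of each bf16 parameter, step a plain torch.optim.AdamW on the fp32 copy, then downcast the result back into the bf16 parameter.

```python
import torch

# Sketch of a hand-rolled master-weights loop (not apex's FusedAdam):
# the optimizer state and update live entirely in fp32, and the bf16
# parameter is just a low-precision view refreshed after each step.
param_bf16 = torch.nn.Parameter(torch.randn(4, 4).bfloat16())
master = param_bf16.detach().clone().float().requires_grad_(True)

# Vanilla AdamW owns only the fp32 master copy.
opt = torch.optim.AdamW([master], lr=1e-3)

# One training step: grad is computed on the bf16 parameter,
# upcast and handed to the fp32 master, then copied back down.
loss = (param_bf16 ** 2).sum()
loss.backward()
master.grad = param_bf16.grad.float()
opt.step()
opt.zero_grad()
with torch.no_grad():
    param_bf16.copy_(master)  # downcast fp32 -> bf16
```

FusedAdam with master_weights=True should produce the same bf16 parameters (up to bf16 rounding) as this loop, which is what the PR verifies.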

@cbcase (Contributor, Author) commented Oct 16, 2023

Ping @minitu: it looks like you added this support originally. Could you take a look? Thanks!

@minitu (Contributor) commented Oct 17, 2023

LGTM. We only looked at adding master weights for FP16 AMP at the time of the original PR.
@crcrpar Could you review this as well?

@crcrpar (Collaborator) left a comment


Looks good, but could you add a test case of a bfloat16 model with fp32 master weights to `def testGradScalerCapturableMaster(self):`?
