[simplefsdp] fix simplefsdp gradient_divide_factor #1793
Conversation
> After enabling reduction_divide_factor, we will see FSDP(=2) + EP(=4) have identical loss:

From the pictures, it doesn't look like they are identical. Did you fix the seed when comparing FSDP2 vs. SimpleFSDP?
    torch.float32,
    torch.bfloat16,
), "only support reduce_dtype to be fp32/bf16"
pre_factor, post_factor = self.reduction_divide_factor, None
we should do post-multiply according to https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L738-L739
Hmmm, according to the PREMUL_SUM docs, "PREMUL_SUM multiplies inputs by a given scalar locally before reduction." I think we should do pre_factor instead of post_factor. Besides, this diff also does the division locally before calling reduce_scatter: https://www.internalfb.com/diff/D76546536
hmm why do we care about PREMUL_SUM and MTIA?
If you have a non-overflow dtype and a non-MTIA device, _get_gradient_divide_factors would use PREMUL_SUM (the pre_factor/post_factor it returns are both None).

I didn't use PREMUL_SUM for the reason I mentioned in the PR description. Thus, I simulate what PREMUL_SUM does, following the MTIA path, by dividing by self.reduction_divide_factor and then calling reduce_scatter("sum").

Using either pre_factor or post_factor would give the same loss results. But I still have the concern that if we first do a SUM and then do the division, the gathered float number might overflow...
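To make the pre/post distinction concrete, here is a minimal numerical sketch (plain tensors, no actual collectives; `grads`, `factor`, and `world_size` are illustrative values, not the PR's code) of the three variants being discussed:

```python
# Numerical sketch of the variants discussed above. Per-rank gradient shards
# are stand-ins; the "reduction" is just a sum over the list.
import torch

world_size, factor = 4, 8.0
grads = [torch.randn(16) for _ in range(world_size)]  # per-rank gradients

# (a) PREMUL_SUM semantics: each rank multiplies by a scalar *locally*,
#     then the reduction sums the scaled values.
premul_sum = sum(g * (1.0 / factor) for g in grads)

# (b) What this PR does: divide locally by reduction_divide_factor,
#     then reduce with a plain SUM.
pre_divide_sum = sum(g / factor for g in grads)

# (c) The post-divide alternative: sum first, divide after. Mathematically
#     equivalent, but the un-scaled sum is what could overflow in
#     low-precision dtypes.
post_divide = sum(grads) / factor

torch.testing.assert_close(premul_sum, pre_divide_sum)
torch.testing.assert_close(premul_sum, post_divide)
```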
ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
shard_dim=experts_shard_dim,
reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
In the future we probably should deprecate this logic anyway. See related PR #1803 (comment).
I will add a TODO for it.

Turns out I didn't enable deterministic mode; updated the fig.
> Using either pre_factor or post_factor would give the same loss results. But I still have the concern that if we first do a SUM and then do division, the gathered float number might overflow....

I see your concern. I'm not sure how realistic it is, but OK to stick with this for now.
This PR is to unblock SimpleFSDP+`gradient_divide_factor` [here](pytorch/torchtitan#1793). We need to create a subclass of the DTensor `Partial` placement. When tracing `SimpleFSDPPartial`, I hit an assertion error that `SimpleFSDPPartial` is not in `ok_types`. I'm updating the code to check the placement type via `isinstance` instead of `type(val)`.

Pull Request resolved: #164985
Approved by: https://github.com/ezyang, https://github.com/eellison
This PR is a followup of the SimpleFSDP+EP [PR](pytorch#1529). Here, we add a `gradient_divide_factor` following FSDP2 to ensure modules wrapped by (FSDP+EP) have the correct gradient reduction value.

- The original FSDP2 implementation is in this [PR](pytorch#1551).
- The `gradient_divide_factor` logic is [here](https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688).

We have two ways of handling `gradient_divide_factor` in `reduce_scatter`:

1. Use `ReduceOp.PREMUL_SUM` to handle the `gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only accepts `reduce_op` as a str input ([here](https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210)). To make `_reduce_shard_value` work correctly with `ReduceOp.PREMUL_SUM`, we would need to update DTensor's `_reduce_shard_tensor` and `torch.distributed._functional_collectives.reduce_scatter_tensor` so that they can pass the factor associated with `ReduceOp.PREMUL_SUM` as an input.
2. Simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`. The logic is in this [Diff](https://www.internalfb.com/diff/D76546536). It does a `div_` over the gradient before performing `ReduceOp.SUM`.

Currently I'm following 2 since it requires fewer changes to `_functional_collectives` (see the sketch after this description).

After enabling `reduction_divide_factor`, we will see FSDP(=2) + EP(=4) have identical loss:

<img width="1194" height="780" alt="Screenshot 2025-10-08 at 5 27 24 PM" src="https://github.com/user-attachments/assets/aaf83109-8db8-4051-973d-c7b6950513de" />
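For concreteness, here is a minimal sketch of approach 2, assuming a plain `torch.distributed` process group; the helper name `reduce_scatter_grad` and its signature are illustrative, not the actual SimpleFSDP implementation:

```python
# Hedged sketch of approach 2: divide the local gradient by the
# reduction_divide_factor first, then reduce_scatter with a plain SUM,
# instead of relying on ReduceOp.PREMUL_SUM.
import torch
import torch.distributed as dist


def reduce_scatter_grad(
    local_grad: torch.Tensor,
    group: dist.ProcessGroup,
    reduction_divide_factor: float,
) -> torch.Tensor:
    world_size = dist.get_world_size(group)
    # Pre-divide on each rank so the reduced values are already scaled;
    # the un-scaled sum never materializes, which addresses the overflow
    # concern discussed in the review thread.
    scaled = local_grad / reduction_divide_factor
    out = torch.empty(
        scaled.shape[0] // world_size,
        *scaled.shape[1:],
        dtype=scaled.dtype,
        device=scaled.device,
    )
    dist.reduce_scatter_tensor(out, scaled, op=dist.ReduceOp.SUM, group=group)
    return out
```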