Conversation

@ruisizhang123 ruisizhang123 commented Oct 3, 2025

This PR is a follow-up of the SimpleFSDP+EP PR (#1529). Here, we add a `gradient_divide_factor` following FSDP2 to ensure modules wrapped by (FSDP+EP) have the correct gradient reduction value.

  • The original FSDP2 implementation is in PR #1551.
  • The `gradient_divide_factor` logic is here: https://github.com/pytorch/pytorch/blob/main/torch/distributed/fsdp/_fully_shard/_fsdp_collectives.py#L688

We have two ways of handling `gradient_divide_factor` in `reduce_scatter`:

  1. The first one is to use `ReduceOp.PREMUL_SUM` to handle the `gradient_divide_factor`. However, DTensor's `_reduce_shard_value` only accepts `reduce_op` as a str input (https://github.com/pytorch/pytorch/blob/8f705d019a64b1ca882e043b3eb98559273a9e59/torch/distributed/tensor/placement_types.py#L177-L210).

     To make `_reduce_shard_value` work correctly with `ReduceOp.PREMUL_SUM`, we need to update DTensor's `_reduce_shard_tensor` and `torch.distributed._functional_collectives.reduce_scatter_tensor` so that the factor associated with `ReduceOp.PREMUL_SUM` can be passed in as an input.

  2. Another way is to simulate `ReduceOp.PREMUL_SUM` with `ReduceOp.SUM`. The logic is in this Diff (https://www.internalfb.com/diff/D76546536): it does a `div_` over the gradient before performing `ReduceOp.SUM` (see the sketch after this list).

Currently I'm following option 2 since it requires fewer changes to `_functional_collectives`.
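Below is a minimal sketch of option 2 under simplifying assumptions (the helper name, the standalone `torch.distributed` calls, and the dim-0 shard layout are illustrative, not the actual SimpleFSDP/DTensor code path): divide the local gradient by the factor, then reduce-scatter with a plain `SUM`.

```python
# Sketch only: simulate ReduceOp.PREMUL_SUM by pre-dividing, then reduce-scatter with SUM.
import torch
import torch.distributed as dist


def reduce_scatter_with_divide_factor(
    grad: torch.Tensor,
    divide_factor: float,
    group: dist.ProcessGroup,
) -> torch.Tensor:
    """Divide the local gradient shard, then reduce-scatter with SUM.

    Mathematically this matches PREMUL_SUM with scalar 1/divide_factor,
    but it only needs the plain "sum" reduce op that the str-based
    DTensor / functional-collectives path already accepts.
    """
    world_size = dist.get_world_size(group)
    assert grad.size(0) % world_size == 0, "dim 0 must be divisible by world size"

    grad.div_(divide_factor)  # the `div_` mentioned above, done before communication

    out = torch.empty(
        (grad.size(0) // world_size, *grad.shape[1:]),
        dtype=grad.dtype,
        device=grad.device,
    )
    dist.reduce_scatter_tensor(out, grad, op=dist.ReduceOp.SUM, group=group)
    return out
```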

After enabling `reduction_divide_factor`, we will see FSDP(=2) + EP (=4) have identical loss:

[Screenshot 2025-10-08 at 5:27:24 PM: loss curves]

@meta-cla bot added the CLA Signed label (this label is managed by the Meta Open Source bot) Oct 3, 2025
@ruisizhang123 ruisizhang123 marked this pull request as draft October 3, 2025 03:41
@ruisizhang123 ruisizhang123 force-pushed the ruisi/gradient_divisor branch 2 times, most recently from 8a414cf to 4e14a0f, October 7, 2025 07:25
@ruisizhang123 ruisizhang123 marked this pull request as ready for review October 8, 2025 16:50
@ruisizhang123 ruisizhang123 changed the title from "[wip][simplefsdp] fix simplefsdp gradient_divide_factor" to "[simplefsdp] fix simplefsdp gradient_divide_factor" Oct 8, 2025
@ruisizhang123 ruisizhang123 force-pushed the ruisi/gradient_divisor branch from 4e14a0f to dfccea8, October 8, 2025 19:47
@ruisizhang123 ruisizhang123 force-pushed the ruisi/gradient_divisor branch from dfccea8 to fc14199, October 8, 2025 20:00
@ruisizhang123 ruisizhang123 force-pushed the ruisi/gradient_divisor branch 2 times, most recently from 91c2e8e to 5f0be26, October 8, 2025 22:25
@ruisizhang123 ruisizhang123 requested a review from tianyu-l October 8, 2025 22:40
@tianyu-l tianyu-l left a comment

> After enabling `reduction_divide_factor`, we will see FSDP(=2) + EP (=4) have identical loss:

From the pictures, it doesn't look like they are identical. Did you fix the seed when comparing FSDP2 vs. SimpleFSDP?

torch.float32,
torch.bfloat16,
), "only support reduce_dtype to be fp32/bf16"
pre_factor, post_factor = self.reduction_divide_factor, None

Contributor Author (@ruisizhang123):

Hmmm, according to the docs, "PREMUL_SUM multiplies inputs by a given scalar locally before reduction" (link). I think we should do `pre_factor` instead of `post_factor`. Besides, this diff also does the division locally before calling reduce_scatter: https://www.internalfb.com/diff/D76546536
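As a toy, single-process check of the equivalence being discussed (illustrative only, not code from this PR): PREMUL_SUM's local scaling with scalar 1/factor matches pre-dividing each shard before a SUM, and, absent overflow, also matches dividing after the SUM.

```python
import torch

# Stand-ins for per-rank gradient shards (toy fp32 values, no overflow risk here).
xs = [torch.randn(4) for _ in range(8)]
factor = 8.0

premul_sum = sum((1.0 / factor) * x for x in xs)  # what PREMUL_SUM(scale=1/factor) computes
pre_divide = sum(x / factor for x in xs)          # option 2: divide locally, then SUM
post_divide = sum(xs) / factor                    # SUM first, then divide

torch.testing.assert_close(premul_sum, pre_divide)
torch.testing.assert_close(premul_sum, post_divide)  # equal in fp32; can diverge in bf16 if the sum overflows
```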

Contributor:

hmm why do we care about PREMUL_SUM and MTIA?

Contributor Author (@ruisizhang123), Oct 9, 2025:

If you have a non-overflow dtype and a non-MTIA device, the FSDP2 code path would use `PREMUL_SUM` (`_get_gradient_divide_factors` returns pre-factor/post-factor both None).

I didn't use `PREMUL_SUM` for the reason I mentioned in the PR description. Instead, I want to simulate what `PREMUL_SUM` does, following the MTIA path, by dividing by `self.reduction_divide_factor` and then doing `reduce_scatter("sum")`.

Using either `pre_factor` or `post_factor` would give the same loss results. But I still have the concern that if we first do a SUM and then do the division, the reduced float values might overflow...
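A toy illustration of that overflow concern (made-up magnitudes, not from this PR): with bf16 gradients near the dtype's max, summing first overflows to inf, while pre-dividing each shard keeps the running sum finite.

```python
import torch

world_size = 8
divide_factor = float(world_size)

# One large-but-representable bf16 gradient value per "rank" (bf16 max is ~3.39e38).
shards = [torch.tensor([1e38], dtype=torch.bfloat16) for _ in range(world_size)]

post = sum(shards) / divide_factor            # SUM first, then divide: the partial sums overflow
pre = sum(s / divide_factor for s in shards)  # divide first, then SUM: stays in range

print(post)  # tensor([inf], dtype=torch.bfloat16)
print(pre)   # finite, roughly 1e38 in bf16
```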

ac_mode=job_config.activation_checkpoint.mode,
mp_policy=mp_policy,
shard_dim=experts_shard_dim,
reduction_divide_factor=parallel_dims.fsdp_gradient_divide_factor,
Contributor:

In the future we probably should deprecate this logic anyway. See related PR #1803 (comment).

Contributor Author (@ruisizhang123):

I will add a TODO for it.

@ruisizhang123 ruisizhang123 commented:

> After enabling `reduction_divide_factor`, we will see FSDP(=2) + EP (=4) have identical loss:
>
> From the pictures, it doesn't look like they are identical. Did you fix the seed when comparing FSDP2 vs. SimpleFSDP?

Turns out I didn't enable deterministic mode; updated the figure.

@ruisizhang123 ruisizhang123 force-pushed the ruisi/gradient_divisor branch from 5f0be26 to f668434, October 9, 2025 00:31
@ruisizhang123 ruisizhang123 force-pushed the ruisi/gradient_divisor branch from f668434 to 7ce9911, October 9, 2025 00:32
@ruisizhang123 ruisizhang123 requested a review from tianyu-l October 9, 2025 00:56
@tianyu-l tianyu-l left a comment

> Using either `pre_factor` or `post_factor` would give the same loss results. But I still have the concern that if we first do a SUM and then do the division, the reduced float values might overflow...

I see your concern. I'm not sure how realistic it is, but OK to stick with this for now.

@ruisizhang123 ruisizhang123 merged commit f014f31 into main Oct 9, 2025
5 checks passed
@ruisizhang123 ruisizhang123 deleted the ruisi/gradient_divisor branch October 9, 2025 01:10
pytorchmergebot pushed a commit to pytorch/pytorch that referenced this pull request Oct 9, 2025
This PR is to unblock SimpleFSDP+`gradient_divide_factor` [here](pytorch/torchtitan#1793). We will need to create a subclass of the DTensor `Partial` placement. When tracing `SimpleFSDPPartial`, I hit the assertion error that `SimpleFSDPPartial` is not in `ok_types`. I'm updating the code to check the placement type via `isinstance` instead of `type(val)`.

Pull Request resolved: #164985
Approved by: https://github.com/ezyang, https://github.com/eellison
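A minimal illustration of the `isinstance` vs `type()` point above (the `SimpleFSDPPartial` below is a stand-in subclass written for this example, not the actual torchtitan class):

```python
from torch.distributed.tensor import Partial


class SimpleFSDPPartial(Partial):
    """Stand-in subclass of the Partial placement, for illustration only."""


p = SimpleFSDPPartial()
ok_types = (Partial,)

print(type(p) in ok_types)      # False: an exact-type check rejects the subclass
print(isinstance(p, ok_types))  # True: isinstance accepts subclasses as well
```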
githubsgi pushed a commit to githubsgi/torchtitan that referenced this pull request Oct 13, 2025
Chao1Han pushed a commit to Chao1Han/pytorch that referenced this pull request Oct 21, 2025