
Conversation

rpsilva-aws (Collaborator) commented May 19, 2025

This PR adds a new binding API annotate_custom_sharding that allows annotating an existing tensor with a custom sharding IR node without modifying its data layout. This is useful for cases where a tensor has already been sharded with mark_sharding but needs additional sharding annotations for compiler optimizations.

Unlike the existing mark_sharding function, annotate_custom_sharding only adds the annotation to the XLA IR without changing the underlying data distribution, enabling more flexible sharding strategies to be provided to XLA. This is particularly useful for introducing resharding annotations on already-sharded tensors.

Use Case

There are cases where we want to provide an explicit sharding annotation hint around a kernel that uses manual sharding. Currently, we cannot introduce such custom sharding hints to XLA prior to the manual resharding. For instance, if we have FSDP + TP and we wish to gather all weights across the FSDP dimension prior to the kernel, this is not possible. This PR adds that functionality and flexibility by redefining the sharding spec associated with the IR prior to the manual sharding (see the sketch below).
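For context, a minimal sketch of the intended usage for the FSDP + TP case. It assumes annotate_custom_sharding mirrors mark_sharding's (tensor, mesh, partition_spec) signature; the mesh shape, axis names, and tensor sizes are illustrative only, not part of this PR.

```python
# Hypothetical sketch: mesh shape, axis names, and the annotate_custom_sharding
# signature are assumptions, not confirmed API details.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
# 2D logical mesh: outer axis for FSDP, inner axis for TP (illustrative shape).
mesh = xs.Mesh(np.arange(num_devices), (num_devices // 2, 2), ('fsdp', 'tp'))

weight = torch.randn(1024, 1024).to(xm.xla_device())

# Physically shard the weight across both mesh axes (FSDP + TP).
xs.mark_sharding(weight, mesh, ('fsdp', 'tp'))

# Before the manually sharded kernel, annotate the weight as replicated on the
# 'fsdp' axis (still sharded on 'tp'). Only the XLA IR annotation changes; no
# data is moved eagerly, so XLA can insert the gather at compile time.
xs.annotate_custom_sharding(weight, mesh, (None, 'tp'))
```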

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from 32344f3 to a16b53e on May 20, 2025 05:01
@rpsilva-aws rpsilva-aws changed the title Part 1: Disambiguate custom sharding op for DeviceData IR nodes Part 1: Add annotate_custom_sharding API May 20, 2025
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from a16b53e to 93e8632 on May 20, 2025 05:06
@rpsilva-aws rpsilva-aws changed the title Part 1: Add annotate_custom_sharding API Part 1: Introduce annotate_custom_sharding binding May 20, 2025
@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from 93e8632 to da5885d on May 21, 2025 05:51
@rpsilva-aws rpsilva-aws requested review from tengyifei and bhavya01 May 21, 2025 17:50
@rpsilva-aws rpsilva-aws marked this pull request as ready for review May 21, 2025 18:20
@rpsilva-aws rpsilva-aws changed the title Part 1: Introduce annotate_custom_sharding binding Introduce annotate_custom_sharding binding May 21, 2025
rpsilva-aws (Collaborator, Author) commented May 21, 2025

Hey @tengyifei! Let me know if you're able to review this, or feel free to add anyone else. We need this to unblock our use case above, and we might consider cherry-picking it to 2.7.1, since it is a blocker.

@rpsilva-aws rpsilva-aws requested a review from zhanyong-wan May 22, 2025 20:50
rpsilva-aws (Collaborator, Author) commented:
@bhavya01 @zhanyong-wan Do you have cycles on this? We need this to unblock some of our use cases.

@rpsilva-aws rpsilva-aws requested a review from ysiraichi May 23, 2025 17:19
rpsilva-aws (Collaborator, Author) commented:
@bhavya01 @ysiraichi Hey folks, this PR has gone stale - let me know if you can review.

@rpsilva-aws rpsilva-aws removed the request for review from tengyifei May 28, 2025 00:58
bhavya01 (Collaborator) commented:
I will take a look at this later today!

ysiraichi (Collaborator) left a comment:

I am not really familiar with sharding, but I do have one question: wouldn't it be better to allow mark_sharding to be called multiple times, instead of creating a new API? I'm asking this question because, looking at the test, it's almost like mark_sharding and annotate_custom_sharding did the same thing (i.e. would a user know when to use which?).

rpsilva-aws (Collaborator, Author) commented May 28, 2025

I am not really familiar with sharding, but I do have one question: wouldn't it be better to allow mark_sharding to be called multiple times, instead of creating a new API? I'm asking this question because, looking at the test, it's almost like mark_sharding and annotate_custom_sharding did the same thing (i.e. would a user know when to use which?).

At the moment, mark_sharding does one of two things: 1) if the tensor is an IR value (and not a clone of a device data node), it adds an IR custom sharding node; or 2) if the tensor is a device data node on the host/device, it actually invokes the sharded data transfer to the device.

This means that it's not possible to provide a custom sharding annotation (on XLA) over an intentionally replicated tensor on the device. Similarly, as the motivation above mentions, it also prevents us from providing a custom sharding annotation over a device data node (weights, inputs, etc.) that has already been sharded to the device - since it can't disambiguate whether the user meant to reshard the tensor or simply add a custom sharding op.

This API is intended to fill these gaps - it intentionally targets more familiar users who need to provide extra sharding annotations to XLA on their tensors, working around the limitations above. I think it's a well-defined API that can serve different use cases. If you have a new tensor on the host, mark_sharding and annotate_custom_sharding will have different, but reasonable, behaviors.

I am happy to consider relaxing the mark_sharding API, e.g. removing: https://github.com/pytorch/xla/pull/9203/files#diff-76bd84e4abe22701ee8697bf77e9fc97e19b6d6ff05175f2dc87f938f3a88837R774-R775, but that basically means that if a tensor has already been sharded on the device, any subsequent mark_sharding will entail an XLA custom sharding annotation - and NOT an attempt to reshard the tensor (which fails today if the sharding spec differs). Perhaps we introduce an argument that makes it clearer?
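For illustration, a minimal sketch of the contrast described above, again assuming annotate_custom_sharding shares mark_sharding's (tensor, mesh, partition_spec) signature; the 1D mesh and partition specs are illustrative only.

```python
# Hypothetical sketch: the mesh and partition specs are illustrative, and the
# annotate_custom_sharding signature is assumed to match mark_sharding's.
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
num_devices = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

t = torch.randn(16, 16).to(xm.xla_device())

# Case 2) above: the tensor is a device data node, so mark_sharding performs
# the sharded data transfer to the device.
xs.mark_sharding(t, mesh, ('data', None))

# A second mark_sharding with a different spec is treated as an attempt to
# reshard the physical data, and fails today when the specs differ:
# xs.mark_sharding(t, mesh, (None, 'data'))  # error

# annotate_custom_sharding only attaches a custom sharding op to the XLA IR,
# leaving the on-device layout untouched, so XLA can reshard at compile time.
xs.annotate_custom_sharding(t, mesh, (None, 'data'))
```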

bhavya01 (Collaborator) commented:
@ysiraichi Do you want to take another look before merging?

rpsilva-aws (Collaborator, Author) commented:
@ysiraichi Let me know if the justification above works for you.

ysiraichi (Collaborator) commented:
Sorry for taking too long.

Perhaps we introduce an argument that makes it clearer?

That's what I was thinking. That said, again, I'm not familiar with sharding, so I trust your judgement, here. Also, thank you for the thorough explanation. It was very clarifying.

ysiraichi (Collaborator) left a comment:

LGTM.

rpsilva-aws (Collaborator, Author) commented Jun 5, 2025

Sorry for taking too long.

Perhaps we introduce an argument that makes it clearer?

That's what I was thinking. That said, again, I'm not familiar with sharding, so I trust your judgement, here. Also, thank you for the thorough explanation. It was very clarifying.

No worries, thanks for taking the time. I think, for now, it's preferable to introduce the new API given the different behavior, rather than modify/extend an essential, standard one. I am happy to eventually converge the two as we see fit.

I'll rebase and re-run the CI, thanks folks!

@rpsilva-aws rpsilva-aws force-pushed the rpsilva_resharding_v2 branch from da5885d to 89dbf80 on June 5, 2025 00:53
@zhanyong-wan zhanyong-wan removed their request for review June 5, 2025 01:31
@rpsilva-aws rpsilva-aws enabled auto-merge (squash) June 5, 2025 06:44
@rpsilva-aws rpsilva-aws merged commit 402612d into pytorch:master Jun 5, 2025
22 checks passed
@rpsilva-aws rpsilva-aws deleted the rpsilva_resharding_v2 branch June 5, 2025 06:44