Introduce annotate_custom_sharding binding #9203
Conversation
Force-pushed from 32344f3 to a16b53e
Force-pushed from a16b53e to 93e8632
Force-pushed from 93e8632 to da5885d
Hey @tengyifei! Let me know if you're able to review this, or feel free to add anyone else. We have this as a requirement to unblock our use case above, and we might consider cherry-picking it to 2.7.1, since it is a blocker.
@bhavya01 @zhanyong-wan Do you have cycles on this? We need this to unblock some of our use cases.
@bhavya01 @ysiraichi Hey folks, stale PR - let me know if you can review.
I will take a look at this later today!
I am not really familiar with sharding, but I do have one question: wouldn't it be better to allow `mark_sharding` to be called multiple times, instead of creating a new API? I'm asking this question because, looking at the test, it's almost like `mark_sharding` and `annotate_custom_sharding` do the same thing (i.e. would a user know when to use which?).
At the moment, this means that it's not possible to provide a custom sharding annotation (in XLA) over an intentionally replicated tensor on the device. Similarly, as the motivation above mentions, it also prevents us from providing a custom sharding annotation over a device data node (weights, inputs, etc.) that has already been sharded to the device, since there is no way to disambiguate whether the user meant to reshard the tensor or simply add a custom sharding op.

This API is intended to fill those gaps, and it intentionally targets more familiar users who need to provide extra sharding annotations to XLA around the limitations above on their tensors. I think it's a well-defined API that serves distinct use cases. If you have a new tensor on the host, I am happy to consider relaxing the …
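For concreteness, a minimal sketch of the first gap described above (mesh shape, axis names, and partition specs are illustrative, and this assumes `annotate_custom_sharding` takes the same `(tensor, mesh, partition_spec)` arguments as `mark_sharding`, as this PR's tests suggest):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
n = xr.global_runtime_device_count()
mesh = xs.Mesh(np.arange(n), (n,), ('x',))

# A freshly transferred tensor is replicated on the device. mark_sharding
# cannot attach a sharding hint here without also physically resharding it.
t = torch.randn(8, 8).to(xm.xla_device())

# annotate_custom_sharding adds the custom sharding op to the XLA IR
# while leaving the replicated device data untouched.
xs.annotate_custom_sharding(t, mesh, ('x', None))
```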
@ysiraichi Do you want to take another look before merging?
@ysiraichi Let me know if the above is justified.
Sorry for taking so long.
That's what I was thinking. That said, again, I'm not familiar with sharding, so I trust your judgement here. Also, thank you for the thorough explanation. It was very clarifying.
LGTM.
No worries, thanks for taking the time. I think, for now, it's better to introduce a new API given the different behavior than to modify/extend an essential, standard one. I am happy to eventually converge as we see fit. I'll rebase and re-run the CI. Thanks, folks!
Force-pushed from da5885d to 89dbf80
This PR adds a new binding API `annotate_custom_sharding` that allows annotating an existing tensor with a custom sharding IR node without modifying its data layout. This is useful for cases where a tensor has already been sharded with `mark_sharding` but needs additional sharding annotations for compiler optimizations.

Unlike the existing `mark_sharding` function, `annotate_custom_sharding` only adds the annotation to the XLA IR without changing the underlying data distribution, enabling more flexible sharding strategies to be provided to XLA. This is particularly useful for introducing resharding annotations on already-sharded tensors.
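A minimal usage sketch (hedged: this assumes `annotate_custom_sharding` mirrors `mark_sharding`'s `(tensor, mesh, partition_spec)` signature, as the tests in this PR suggest; mesh shape and partition specs are illustrative):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()

num_devices = xr.global_runtime_device_count()
# 1D mesh over all devices; the axis name is illustrative.
mesh = xs.Mesh(np.arange(num_devices), (num_devices,), ('data',))

t = torch.randn(16, 128).to(xm.xla_device())

# Physically shard the tensor across the 'data' axis.
xs.mark_sharding(t, mesh, ('data', None))

# Attach an extra custom sharding annotation to the XLA IR without
# changing the on-device layout, e.g. hint that XLA should treat the
# tensor as replicated from this point on.
xs.annotate_custom_sharding(t, mesh, (None, None))
```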
Use Case

There are instances where we want to provide an explicit annotation hint around a kernel that uses manual sharding. Today, we are limited in our ability to introduce custom sharding hints to XLA prior to the manual resharding. For instance, if we have FSDP + TP and we wish to gather all weights across the FSDP dimension prior to the kernel, this is not possible. This PR adds that functionality and flexibility by redefining the sharding spec associated with the IR prior to the manual sharding, as sketched below.
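A hedged sketch of the FSDP + TP case (mesh layout, axis names, and partition specs are illustrative, and the `(tensor, mesh, partition_spec)` signature is assumed as above):

```python
import numpy as np
import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.spmd as xs
import torch_xla.runtime as xr

xr.use_spmd()
n = xr.global_runtime_device_count()
# 2D mesh: an 'fsdp' axis and a 'tp' axis (assumes an even device count).
mesh = xs.Mesh(np.arange(n), (n // 2, 2), ('fsdp', 'tp'))

w = torch.randn(1024, 1024).to(xm.xla_device())
# Weights physically sharded across both FSDP and TP dimensions.
xs.mark_sharding(w, mesh, ('fsdp', 'tp'))

# Before the manually sharded kernel, hint XLA to gather w across the
# 'fsdp' axis (keeping only the 'tp' sharding) without eagerly
# resharding the underlying device data.
xs.annotate_custom_sharding(w, mesh, (None, 'tp'))
# ... invoke the manually sharded kernel on w here ...
```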