Description
Thank you for your awesome project. I would like to ask how to solve the following issue:
I have implemented the `logcumsumexp` operator, where the input placement is `Shard(-1)` and the output placement is `Replicate()`. To obtain the final result, I need a custom all-reduce-style collective that combines partial results with something other than the conventional sum. How should I go about implementing this?
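For reference, below is a minimal pure-Python simulation of what the cross-shard combination has to compute. It assumes 1-D input, represents each rank's `Shard(-1)` slice as a plain list, and does not use `torch.distributed` at all. All function names are hypothetical. The output at position `i` depends only on elements up to `i`, so the cross-rank step is an exclusive prefix scan under log-add-exp rather than a plain all-reduce. Rank `r` needs the combined logsumexp of shards `0..r-1` as an offset. In a real DTensor setting, one possible (unverified) approach is to operate on `to_local()` tensors, exchange each rank's per-shard total with a collective such as an all-gather, and rebuild the result with `DTensor.from_local`. That may not be the recommended extension point.

```python
import math
from typing import List


def logaddexp(a: float, b: float) -> float:
    """Numerically stable log(exp(a) + exp(b)); this replaces '+' as the combine op."""
    if a == -math.inf:
        return b
    if b == -math.inf:
        return a
    m = max(a, b)
    return m + math.log(math.exp(a - m) + math.exp(b - m))


def local_logcumsumexp(xs: List[float]) -> List[float]:
    """What each rank computes on its own Shard(-1) slice, with no communication."""
    out: List[float] = []
    acc = -math.inf
    for x in xs:
        acc = logaddexp(acc, x)
        out.append(acc)
    return out


def sharded_logcumsumexp(shards: List[List[float]]) -> List[float]:
    """Simulate ranks 0..N-1 (one list per rank) and return the Replicate() result."""
    # Step 1: every rank scans its own shard locally.
    local = [local_logcumsumexp(s) for s in shards]
    # Step 2: every rank contributes one scalar, the logsumexp of its whole shard
    # (the last element of its local scan). In a real run these scalars would be
    # exchanged between ranks with a collective (e.g. an all-gather).
    totals = [vals[-1] if vals else -math.inf for vals in local]
    # Step 3: exclusive prefix over ranks, combining with logaddexp instead of '+'.
    result: List[float] = []
    offset = -math.inf  # logsumexp of all shards on lower ranks
    for rank, vals in enumerate(local):
        result.extend(logaddexp(offset, v) for v in vals)
        offset = logaddexp(offset, totals[rank])
    # Step 4: concatenating the corrected shards corresponds to gathering into Replicate().
    return result


if __name__ == "__main__":
    x = [0.5, -1.0, 2.0, 0.0, 3.0]
    print(sharded_logcumsumexp([x[:2], x[2:4], x[4:]]))
    print(local_logcumsumexp(x))  # unsharded reference; should match the line above
```

Only one scalar per rank has to be communicated. Because log-add-exp is associative, the same pattern also works with uneven or empty shards.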
More generally, for an operator function `f`, given an input placement `placement1` and an output placement `placement2`, where should I implement the various custom communication operations? I would greatly appreciate it if you could provide some examples.