Problem
Currently, the context parallel API has five problems:
1. Only applying CP to the whole model is supported. If the model has some cross-attention with an unsupported shape in a preprocessing part, it is impossible to apply CP, since `_context_parallel` always overrides every SDPA call and needs to wrap the whole backward pass.
2. No shard/unshard with gradient support. When I try to apply CP to the transformer blocks only and keep the other SDPA calls replicated, `context_parallel_unshard` in PyTorch has a `no_grad` decorator.
3. Weight gradients inside the CP region are divided by the size of the CP mesh, because we reduce them over DP+CP. This may work for optimizers with norm support, but it makes unit tests harder to write: we have to scale the gradients back to get the same gradients as the model without CP.
4. The sequence length must be divisible by the CP size (CP * 2 for round-robin load balancing).
5. A replicated input to the CP region may end up with a wrong gradient, because its gradient may actually be `Partial`; we have to check every replicated input and use `to_local(grad_placements=[Partial()])` (see the sketch after this list).
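For problem 5, here is a minimal sketch of the workaround, assuming a 2-rank CP mesh and an already-initialized process group (`DTensor.from_local` and `to_local(grad_placements=...)` are the actual DTensor APIs; the surrounding setup is illustrative):

```python
import torch
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Partial, Replicate

# Assumes torch.distributed is already initialized with 2 ranks.
cp_mesh = init_device_mesh("cuda", (2,), mesh_dim_names=("cp",))

x = torch.randn(4, 8, device="cuda", requires_grad=True)
x_d = DTensor.from_local(x, cp_mesh, [Replicate()], run_check=False)

# A plain to_local() assumes the gradient has the same placement as the
# tensor (Replicate), so each rank would silently keep only its partial
# sum. Declaring grad_placements=[Partial()] tells DTensor the incoming
# gradient is a partial sum, so it inserts the missing all-reduce.
x_local = x_d.to_local(grad_placements=[Partial()])

x_local.sum().backward()  # stand-in for the CP-region compute using x
# x.grad now holds the reduced gradient, not a per-rank partial sum.
```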
To resolve problem 1 above, I removed the `context_parallel` context to disable the global SDPA override and only enter the `_enable_cp_dispatcher` context; CP SDPA then kicks in exactly for the calls whose inputs are all converted to DTensors (first sketch below). Problem 2 is easy to resolve: just write some autograd functions (second sketch below).
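A minimal sketch of that problem 1 workaround, assuming the private `_enable_cp_dispatcher` context manager from `torch.distributed.tensor.experimental._attention` (its location and behavior may change between PyTorch versions) and `[B, H, S, D]` tensors already sharded on the sequence dim:

```python
import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.experimental._attention import _enable_cp_dispatcher

# Assumes the process group is already initialized.
cp_mesh = init_device_mesh(
    "cuda", (dist.get_world_size(),), mesh_dim_names=("cp",)
)

def cp_sdpa(q, k, v):
    """Run only this SDPA call under CP; every other SDPA in the model
    stays a plain local call because its inputs stay plain tensors."""
    # Wrap the local [B, H, S/cp, D] shards as DTensors sharded on the
    # sequence dim so the CP dispatcher picks this call up.
    q_d = DTensor.from_local(q, cp_mesh, [Shard(2)], run_check=False)
    k_d = DTensor.from_local(k, cp_mesh, [Shard(2)], run_check=False)
    v_d = DTensor.from_local(v, cp_mesh, [Shard(2)], run_check=False)
    with _enable_cp_dispatcher():
        out = F.scaled_dot_product_attention(q_d, k_d, v_d, is_causal=True)
    return out.to_local()
```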
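And a sketch of the problem 2 fix: a differentiable replacement for `context_parallel_unshard` written as a custom autograd function. It assumes plain contiguous sharding (round-robin load balancing would need chunk reordering) and that the unsharded tensor feeds replicated compute, so the gradient arriving in backward is identical on every rank; otherwise the backward would need a reduce-scatter:

```python
import torch
import torch.distributed as dist
from torch.distributed.device_mesh import DeviceMesh

class UnshardWithGrad(torch.autograd.Function):
    """All-gather the local shard along seq_dim in forward; slice the
    local chunk back out of the gradient in backward."""

    @staticmethod
    def forward(ctx, local: torch.Tensor, mesh: DeviceMesh, seq_dim: int):
        group = mesh.get_group()
        ctx.seq_dim = seq_dim
        ctx.rank = dist.get_rank(group)
        ctx.world_size = dist.get_world_size(group)
        chunks = [torch.empty_like(local) for _ in range(ctx.world_size)]
        dist.all_gather(chunks, local.contiguous(), group=group)
        return torch.cat(chunks, dim=seq_dim)

    @staticmethod
    def backward(ctx, grad_out):
        # Downstream compute is replicated, so grad_out is identical on
        # every rank; each rank keeps only the slice for its own shard.
        grad_local = grad_out.chunk(ctx.world_size, dim=ctx.seq_dim)[ctx.rank]
        return grad_local.contiguous(), None, None

def unshard_with_grad(local, mesh, seq_dim=2):
    return UnshardWithGrad.apply(local, mesh, seq_dim)
```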
Here are my questions:
- Is there a better way to support a CP region?
- Do you have any plan to support CP regions officially and resolve the issues above?