
[question] can't disable CP for a specific (unsupported) SDPA op #757

@FindDefinition

Description


Problem

Currently, the context parallel (CP) API has five problems.

  1. CP can only be applied to the whole model. If the model has cross attention with an unsupported shape in its preprocessing part, it is impossible to apply CP at all, because _context_parallel always overrides every SDPA call and has to wrap the whole backward pass.
  2. There is no shard/unshard with gradient support. When I try to apply CP to the transformer blocks only and keep the other SDPA calls replicated, the context_parallel_unshard in PyTorch has a no_grad decorator, so gradients cannot flow through it (see the first sketch after this list).
  3. Weight gradients inside the CP region are divided by the size of the CP mesh, because we reduce them over DP+CP. This may be fine for an optimizer with gradient-norm support, but it makes unit tests harder to write: we have to scale the gradients back to get the same values as a model without CP (see the second sketch after this list).
  4. The sequence length must be divisible by the CP size (CP size * 2 for round-robin load balancing).
  5. A replicated input of the CP region may end up with a wrong gradient, because its gradient may actually be Partial; we have to check every replicated input and use to_local(grad_placements=[Partial()]).
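
For item 2, here is a minimal sketch of an unshard with gradient support. It assumes plain contiguous sequence sharding (no round-robin reordering) and a 1-D CP process group; the helper name unshard_with_grad is made up for illustration:

```python
import torch
import torch.distributed as dist


class _GatherSeq(torch.autograd.Function):
    """All-gather a sequence-sharded tensor along `dim`; the backward returns
    this rank's slice of the incoming full-sequence gradient."""

    @staticmethod
    def forward(ctx, x, dim, group):
        ctx.dim = dim
        ctx.group = group
        world = dist.get_world_size(group)
        parts = [torch.empty_like(x) for _ in range(world)]
        dist.all_gather(parts, x.contiguous(), group=group)
        return torch.cat(parts, dim=dim)

    @staticmethod
    def backward(ctx, grad):
        world = dist.get_world_size(ctx.group)
        rank = dist.get_rank(ctx.group)
        # Keep only the chunk of the gradient that belongs to this rank's shard.
        local_grad = grad.chunk(world, dim=ctx.dim)[rank].contiguous()
        return local_grad, None, None


def unshard_with_grad(x, dim, cp_group):
    return _GatherSeq.apply(x, dim, cp_group)
```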

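For item 3, the workaround in unit tests is just rescaling: multiply the CP-region weight gradients back by the CP mesh size before comparing against a reference run without CP (model_cp_region and cp_mesh below are placeholders for the wrapped module and the 1-D CP device mesh):

```python
# Scale the weight gradients of the CP region back by the CP mesh size so
# they match a reference run without CP.
cp_size = cp_mesh.size()
for p in model_cp_region.parameters():
    if p.grad is not None:
        p.grad.mul_(cp_size)
```
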
To resolve problem 1 above, I remove the context_parallel context manager to disable the global SDPA override and only enable the _enable_cp_dispatcher context; CP SDPA is then used iff all of its inputs are converted to DTensor (a sketch of this follows below). Problem 2 is easy to resolve: just write a few autograd functions like the first sketch above.
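
Roughly, the CP region then looks like the sketch below. It assumes a 1-D CP sub-mesh, q/k/v already sharded locally on the sequence dim, and it uses the private _enable_cp_dispatcher from torch.distributed.tensor.experimental._attention mentioned above; the import paths are for recent PyTorch and may change between versions:

```python
import torch.nn.functional as F
from torch.distributed.tensor import DTensor, Shard
from torch.distributed.tensor.experimental._attention import _enable_cp_dispatcher


def cp_sdpa(q, k, v, cp_mesh):
    # Wrap the already-sharded local q/k/v (sequence dim = 2) as DTensors so
    # that only this call is routed to the CP implementation of SDPA; every
    # other SDPA call in the model keeps running on plain tensors, untouched.
    q_d, k_d, v_d = (
        DTensor.from_local(t, cp_mesh, [Shard(2)], run_check=False)
        for t in (q, k, v)
    )
    with _enable_cp_dispatcher():
        out = F.scaled_dot_product_attention(q_d, k_d, v_d, is_causal=True)
    # The output stays sharded on the sequence dim; hand back the local shard.
    return out.to_local()
```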

Here are my questions:

  1. Is there a better way to support a CP region?
  2. Do you have any plan to officially support CP regions and resolve the issues above?
