Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add seq parallelism for attention and MoE MLP #1328

Closed
wants to merge 53 commits into from

Conversation

suexu1025
Copy link
Collaborator

@suexu1025 suexu1025 commented Mar 1, 2025

Description

  1. Add seq_parallelism + exp_parallelism for attention + followed MLP module
    with sp+ep, moe customer 2k seq inference improved by 20%
  2. Fix prefill_KV_cache sharding mismatch during seq_parallelism
  3. decode improved by 10%
  4. Enable inference auto layout in mistral model

FIXES: b/374773995

Tests

tested on v6e/v5p:
SEQ=2048

python MaxText/inference_microbenchmark.py MaxText/configs/inference.yml max_prefill_predict_length=$SEQ max_target_length=6144 model_name=mixtral-8x7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_expert_parallelism=1 ici_context_parallelism=4 ici_tensor_parallelism=1 scan_layers=false per_device_batch_size=1 attention=dot_product megablox=False quantization=int8 checkpoint_is_quantized=True quantize_kvcache=True capacity_factor=1 tokenizer_path=assets/tokenizer.mistral-v3 compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3 enable_jax_profiler=True inference_microbenchmark_prefill_lengths="$SEQ" base_output_directory=$OUT_DIR run_name=$RUN_NAME profiler=xplane model_call_mode=inference inference_microbenchmark_stages=prefill

Checklist

Before submitting this PR, please make sure (put X in square brackets):

  • I have performed a self-review of my code.
  • I have necessary comments in my code, particularly in hard-to-understand areas.
  • I have run end-to-end tests tests and provided workload links above if applicable.
  • I have made or will make corresponding changes to the doc if needed.

Copy link

google-cla bot commented Mar 1, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@suexu1025 suexu1025 changed the title [Draft] Add seq parallelism for attention and MLP Add seq parallelism for attention and MoE MLP Mar 6, 2025
@suexu1025 suexu1025 requested a review from RissyRan March 7, 2025 22:59
Copy link
Collaborator

@RissyRan RissyRan left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks Qinwen! Would be great if you could help run a few steps of training tests with profiles. Previously we met issues in optimization of inference, and some changes have degragation of training performance.

@suexu1025 suexu1025 requested a review from RissyRan March 11, 2025 02:11
Copy link
Collaborator Author

@suexu1025 suexu1025 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

address comments

copybara-service bot pushed a commit that referenced this pull request Mar 11, 2025
--
7b8f711 by ZhiyuLi-goog <[email protected]>:

[exp] seq exp sharding

--
802313a by ZhiyuLi-goog <[email protected]>:

update

--
d8d4595 by ZhiyuLi-goog <[email protected]>:

update

--
b7f3225 by Qinwen Xu <[email protected]>:

merge for sp

--
7ed5fd1 by Qinwen Xu <[email protected]>:

fix merge parts

--
a1c6973 by Qinwen Xu <[email protected]>:

update merge confict base config

--
3f2d278 by Qinwen Xu <[email protected]>:

update to fix sharding mismatch

--
3e06ebb by Qinwen Xu <[email protected]>:

update sub_seq for masks

--
d23d27b by Qinwen Xu <[email protected]>:

update sharding axis

--
924ce77 by Qinwen Xu <[email protected]>:

update with reshape

--
b62812d by Qinwen Xu <[email protected]>:

solve merge conflict

--
746f4a3 by Qinwen Xu <[email protected]>:

update for generate sharding

--
a6d345c by Qinwen Xu <[email protected]>:

enable compute_axis configurable in mixtral model

--
e06c3d6 by Qinwen Xu <[email protected]>:

address output_logits sharding

--
65a64d4 by Qinwen Xu <[email protected]>:

clean up

--
10a9d82 by Qinwen Xu <[email protected]>:

update

--
0cca6df by Qinwen Xu <[email protected]>:

update

--
ebae8e0 by Qinwen Xu <[email protected]>:

fix tests

--
2e0c459 by Qinwen Xu <[email protected]>:

added contition for non-sharded kernel for cp during inference only

--
37c843e by Qinwen Xu <[email protected]>:

update

--
b63c63b by Qinwen Xu <[email protected]>:

bug fix

--
4007e7c by Qinwen Xu <[email protected]>:

fix tests

--
72f2a90 by Qinwen Xu <[email protected]>:

adddress comment

--
8da48f5 by Qinwen Xu <[email protected]>:

update

--
8a43dd5 by Qinwen Xu <[email protected]>:

address comments

--
56deeda by Qinwen Xu <[email protected]>:

address comments

--
1c6be59 by Qinwen Xu <[email protected]>:

revert

--
bd0e199 by Qinwen Xu <[email protected]>:

address lint

--
44d646f by Qinwen Xu <[email protected]>:

reformat for lint

--
5172068 by Qinwen Xu <[email protected]>:

update MOE test

--
d6787c3 by Qinwen Xu <[email protected]>:

add comment to explain grouping in generate_mask for moe model

--
f964acd by Qinwen Xu <[email protected]>:

address the comments

--
c5174de by Qinwen Xu <[email protected]>:

update to fix tests

--
b86e035 by Qinwen Xu <[email protected]>:

seperate yml for inference

--
e96340e by Qinwen Xu <[email protected]>:

update to address training perf difference

--
7446563 by Qinwen Xu <[email protected]>:

update

--
3b5346f by Qinwen Xu <[email protected]>:

revert back mask_shape for tests

--
b7dcb1e by Qinwen Xu <[email protected]>:

added back reshape and clean up merge changes

--
c859b4c by Qinwen Xu <[email protected]>:

address comment to remove reshape

--
7d91629 by Qinwen Xu <[email protected]>:

update with different softmaxt score for inference/training for mask_generate

--
1a0fdb3 by Qinwen Xu <[email protected]>:

lint

COPYBARA_INTEGRATE_REVIEW=#1328 from AI-Hypercomputer:qinwen/sharding_merge_main 4d94d95
PiperOrigin-RevId: 735916344
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

5 participants