-
Notifications
You must be signed in to change notification settings - Fork 336
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add seq parallelism for attention and MoE MLP #1328
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up to date status, view the checks section at the bottom of the pull request. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thanks Qinwen! Would be great if you could help run a few steps of training tests with profiles. Previously we met issues in optimization of inference, and some changes have degragation of training performance.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
address comments
-- 7b8f711 by ZhiyuLi-goog <[email protected]>: [exp] seq exp sharding -- 802313a by ZhiyuLi-goog <[email protected]>: update -- d8d4595 by ZhiyuLi-goog <[email protected]>: update -- b7f3225 by Qinwen Xu <[email protected]>: merge for sp -- 7ed5fd1 by Qinwen Xu <[email protected]>: fix merge parts -- a1c6973 by Qinwen Xu <[email protected]>: update merge confict base config -- 3f2d278 by Qinwen Xu <[email protected]>: update to fix sharding mismatch -- 3e06ebb by Qinwen Xu <[email protected]>: update sub_seq for masks -- d23d27b by Qinwen Xu <[email protected]>: update sharding axis -- 924ce77 by Qinwen Xu <[email protected]>: update with reshape -- b62812d by Qinwen Xu <[email protected]>: solve merge conflict -- 746f4a3 by Qinwen Xu <[email protected]>: update for generate sharding -- a6d345c by Qinwen Xu <[email protected]>: enable compute_axis configurable in mixtral model -- e06c3d6 by Qinwen Xu <[email protected]>: address output_logits sharding -- 65a64d4 by Qinwen Xu <[email protected]>: clean up -- 10a9d82 by Qinwen Xu <[email protected]>: update -- 0cca6df by Qinwen Xu <[email protected]>: update -- ebae8e0 by Qinwen Xu <[email protected]>: fix tests -- 2e0c459 by Qinwen Xu <[email protected]>: added contition for non-sharded kernel for cp during inference only -- 37c843e by Qinwen Xu <[email protected]>: update -- b63c63b by Qinwen Xu <[email protected]>: bug fix -- 4007e7c by Qinwen Xu <[email protected]>: fix tests -- 72f2a90 by Qinwen Xu <[email protected]>: adddress comment -- 8da48f5 by Qinwen Xu <[email protected]>: update -- 8a43dd5 by Qinwen Xu <[email protected]>: address comments -- 56deeda by Qinwen Xu <[email protected]>: address comments -- 1c6be59 by Qinwen Xu <[email protected]>: revert -- bd0e199 by Qinwen Xu <[email protected]>: address lint -- 44d646f by Qinwen Xu <[email protected]>: reformat for lint -- 5172068 by Qinwen Xu <[email protected]>: update MOE test -- d6787c3 by Qinwen Xu <[email protected]>: add comment to explain grouping in generate_mask for moe model -- f964acd by Qinwen Xu <[email protected]>: address the comments -- c5174de by Qinwen Xu <[email protected]>: update to fix tests -- b86e035 by Qinwen Xu <[email protected]>: seperate yml for inference -- e96340e by Qinwen Xu <[email protected]>: update to address training perf difference -- 7446563 by Qinwen Xu <[email protected]>: update -- 3b5346f by Qinwen Xu <[email protected]>: revert back mask_shape for tests -- b7dcb1e by Qinwen Xu <[email protected]>: added back reshape and clean up merge changes -- c859b4c by Qinwen Xu <[email protected]>: address comment to remove reshape -- 7d91629 by Qinwen Xu <[email protected]>: update with different softmaxt score for inference/training for mask_generate -- 1a0fdb3 by Qinwen Xu <[email protected]>: lint COPYBARA_INTEGRATE_REVIEW=#1328 from AI-Hypercomputer:qinwen/sharding_merge_main 4d94d95 PiperOrigin-RevId: 735916344
Description
with sp+ep, moe customer 2k seq inference improved by 20%
FIXES: b/374773995
Tests
tested on v6e/v5p:
SEQ=2048
python MaxText/inference_microbenchmark.py MaxText/configs/inference.yml max_prefill_predict_length=$SEQ max_target_length=6144 model_name=mixtral-8x7b ici_fsdp_parallelism=1 ici_autoregressive_parallelism=1 ici_expert_parallelism=1 ici_context_parallelism=4 ici_tensor_parallelism=1 scan_layers=false per_device_batch_size=1 attention=dot_product megablox=False quantization=int8 checkpoint_is_quantized=True quantize_kvcache=True capacity_factor=1 tokenizer_path=assets/tokenizer.mistral-v3 compute_axis_order=0,2,1,3 ar_cache_axis_order=0,2,1,3 enable_jax_profiler=True inference_microbenchmark_prefill_lengths="$SEQ" base_output_directory=$OUT_DIR run_name=$RUN_NAME profiler=xplane model_call_mode=inference inference_microbenchmark_stages=prefill
Checklist
Before submitting this PR, please make sure (put X in square brackets):