forked from NVIDIA/TransformerEngine
Generation support in TE for Gemma model #2
Open
sudhakarsingh27 wants to merge 276 commits into main from te_gemma_generation_support
Conversation
…ax.jit (NVIDIA#785) * fixed static argnums for jax.jit in single gpu encoder test, changed warning filtering for pytest Signed-off-by: Alp Dener <[email protected]> * propagating the fix to the JAX mnist example Signed-off-by: Alp Dener <[email protected]> * fixed missing space between flags in QAA scripts Signed-off-by: Alp Dener <[email protected]> * added TE warnings into the ignore list Signed-off-by: Alp Dener <[email protected]> --------- Signed-off-by: Alp Dener <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
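For context, the jax.jit fix above concerns arguments that must be treated as compile-time constants. A minimal sketch of the pattern (hypothetical function and argument names, not the test code from the commit):

```python
import functools
import jax
import jax.numpy as jnp

# Without static_argnums, `use_bias` would be traced as a JAX array and the
# Python `if` below would raise a concretization error under jit.  Marking it
# static makes jax.jit specialize (and re-trace) per concrete value instead.
@functools.partial(jax.jit, static_argnums=(1,))
def maybe_bias(x, use_bias):
    return x + 1.0 if use_bias else x

x = jnp.ones((4,))
print(maybe_bias(x, True))   # traced once for use_bias=True
print(maybe_bias(x, False))  # traced again for use_bias=False
```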
* Add NVRTC kernels for cast-transpose Signed-off-by: Tim Moon <[email protected]> * Update copyright year Signed-off-by: Tim Moon <[email protected]> * Add noop flag to NVRTC cast-transpose kernel Signed-off-by: Tim Moon <[email protected]> * Apply suggestions from code review Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
) * Support noop concat without providing full tensor Stop storing fused buffers in linear modules. Signed-off-by: Tim Moon <[email protected]> * Debug noop cat func Signed-off-by: Tim Moon <[email protected]> * Construct TE modules in tests with correct dtypes Signed-off-by: Tim Moon <[email protected]> * Add tolerances to numerical tests Signed-off-by: Tim Moon <[email protected]> * Use plain PyTorch concat when exporting to ONNX Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
…#780) * Allow multi-dims for dgamma and dbeta in LN descriptor. Signed-off-by: Ming Huang <[email protected]> * Fix the jit error in examples/jax Signed-off-by: Ming Huang <[email protected]> --------- Signed-off-by: Ming Huang <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Remove unnecessary Pylint overrides Signed-off-by: Tim Moon <[email protected]> * Fixes to lint Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* combined layernorm_geglu with layernorm_gelu into fused_layernorm Signed-off-by: Phuong Nguyen <[email protected]> * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py Signed-off-by: Phuong Nguyen <[email protected]> * cleaning and formatting Signed-off-by: Phuong Nguyen <[email protected]> * renaming based on reviewers suggestions Signed-off-by: Phuong Nguyen <[email protected]> * implemented partial fused layernorm Signed-off-by: Phuong Nguyen <[email protected]> * geglu + bias passed tests Signed-off-by: Phuong Nguyen <[email protected]> * added partial fused calculation for dbias_1 Signed-off-by: Phuong Nguyen <[email protected]> * clean up Co-authored-by: Alp Dener <[email protected]> Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]> Signed-off-by: Phuong Nguyen <[email protected]> Co-authored-by: Alp Dener <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Try using global buffer for cu_seqlens Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Avoid using functools.lru_cache Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * fixes Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Co-authored-by: Vasudevan Rengasamy <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Added HF Nanotron to integrations and updated GTC 24 video to ondemand link Signed-off-by: Santosh Bhavani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Implemented swiglu and silu Signed-off-by: Phuong Nguyen <[email protected]> * Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions Signed-off-by: Phuong Nguyen <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
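As a reference for the activations named in that commit, here is a minimal PyTorch-style sketch of SiLU and the gated SwiGLU variant (plain reference math only, not the fused TE/JAX kernels the commit adds):

```python
import torch

def silu(x: torch.Tensor) -> torch.Tensor:
    # SiLU / "swish": x * sigmoid(x)
    return x * torch.sigmoid(x)

def swiglu(x: torch.Tensor) -> torch.Tensor:
    # Gated variant: split the projection in half along the last dim,
    # apply SiLU to the gate half, and multiply with the linear half.
    gate, value = x.chunk(2, dim=-1)
    return silu(gate) * value

x = torch.randn(2, 8)
print(swiglu(x).shape)  # torch.Size([2, 4])
```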
* make FusedAttn with CP support bias Signed-off-by: Xiaowei Ren <[email protected]> * assert Alibi cannot work with CP Signed-off-by: Xiaowei Ren <[email protected]> * syntax fix Signed-off-by: Xiaowei Ren <[email protected]> * fix variable name Signed-off-by: Xiaowei Ren <[email protected]> * fix tensor shapes Signed-off-by: Xiaowei Ren <[email protected]> * a typo fix Signed-off-by: Xiaowei Ren <[email protected]> * fix bias indexing for CP Signed-off-by: Xiaowei Ren <[email protected]> * bug fix Signed-off-by: Xiaowei Ren <[email protected]> * add attn bias tests Signed-off-by: Xiaowei Ren <[email protected]> * change dbias update location Signed-off-by: Xiaowei Ren <[email protected]> * fix CP test model configs Signed-off-by: Xiaowei Ren <[email protected]> * change CP test sequence length Signed-off-by: Xiaowei Ren <[email protected]> * make AttnFuncWithCP support qkv format of sbhd Signed-off-by: Xiaowei Ren <[email protected]> * make sure qkv are contiguous for CP in cuDNN fused attn Signed-off-by: Xiaowei Ren <[email protected]> * change assert message Signed-off-by: Xiaowei Ren <[email protected]> * fix code format Signed-off-by: Xiaowei Ren <[email protected]> --------- Signed-off-by: Xiaowei Ren <[email protected]> Co-authored-by: cyanguwa <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Add support for MoE with FP8. Signed-off-by: Dennis Liu <[email protected]> * Fix unittest. Signed-off-by: Dennis Liu <[email protected]> * Fix error in linear backward. Signed-off-by: Dennis Liu <[email protected]> --------- Signed-off-by: Dennis Liu <[email protected]> Co-authored-by: Przemyslaw Tredak <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Add module level filter for deprecation warning in common Signed-off-by: Kirthi Shankar Sivamani <[email protected]> * Fix module Signed-off-by: Kirthi Shankar Sivamani <[email protected]> --------- Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
remove tp_size/tp_group as amax reduction is handled by fp8_group() Signed-off-by: Charlene Yang <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
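The change above relies on the amax reduction happening over the FP8 process group rather than a separate tensor-parallel group. A hedged sketch of what such a reduction looks like with plain torch.distributed (illustrative only; TE's actual bookkeeping lives in its fp8 recipe code):

```python
import torch
import torch.distributed as dist

def reduce_amax(amax: torch.Tensor, fp8_group: dist.ProcessGroup) -> torch.Tensor:
    # Every rank contributes its local absolute-max; taking the MAX across the
    # FP8 group gives all ranks an identical amax, so the derived FP8 scale
    # stays consistent without an extra reduction over tp_group.
    dist.all_reduce(amax, op=dist.ReduceOp.MAX, group=fp8_group)
    return amax
```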
…IDIA#799) restrict context parallel tests to sm80+ as fused/flash attn backends require sm80+ Signed-off-by: Charlene Yang <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Fix linter warnings from unused args Signed-off-by: Tim Moon <[email protected]> * Update .gitignore Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Added pull request template Signed-off-by: Przemek Tredak <[email protected]> * Changes from the review Signed-off-by: Przemek Tredak <[email protected]> --------- Signed-off-by: Przemek Tredak <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Vasudevan Rengasamy <[email protected]> Co-authored-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
…nite scale (NVIDIA#786) * Handle the scaling factor when amax is too tiny that leads to an infinite scale Signed-off-by: Jinze Xue <[email protected]> * revert formatting changes Signed-off-by: Jinze Xue <[email protected]> * fix comments Signed-off-by: Jinze Xue <[email protected]> * Apply review suggestion Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Jinze Xue <[email protected]> * Apply review suggestion Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Jinze Xue <[email protected]> * Apply review suggestion Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Jinze Xue <[email protected]> * apply review suggestion Signed-off-by: Jinze Xue <[email protected]> * add test_recipe.py to qa/L0_pytorch_unittest/test.sh; fix unittest for is_first_microbatch=False Signed-off-by: Jinze Xue <[email protected]> * revert changes to update_weight_scale_inv Signed-off-by: Jinze Xue <[email protected]> * Debug test failures Signed-off-by: Tim Moon <[email protected]> --------- Signed-off-by: Jinze Xue <[email protected]> Signed-off-by: Jinze Xue <[email protected]> Signed-off-by: Tim Moon <[email protected]> Co-authored-by: Jinze Xue <[email protected]> Co-authored-by: Tim Moon <[email protected]> Co-authored-by: Tim Moon <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
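The commit above guards the FP8 scale update against a near-zero amax. A minimal sketch of the idea (the formula shape follows the usual delayed-scaling recipe; variable names, defaults, and the exact guard are illustrative, not TE's implementation):

```python
import torch

def compute_scale(amax: torch.Tensor, scale: torch.Tensor,
                  fp8_max: float = 448.0, margin: int = 0) -> torch.Tensor:
    # Proposed new scale from the recorded amax.
    new_scale = (fp8_max / amax) / (2.0 ** margin)
    # If amax is zero/tiny or non-finite, the division yields inf/nan;
    # keep the previous scale in that case instead of poisoning later casts.
    keep_old = ~torch.isfinite(new_scale)
    return torch.where(keep_old, scale, new_scale)
```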
…> 1 on Paxml. (NVIDIA#774) * Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch. Signed-off-by: Ming Huang <[email protected]> * Modify with the feedback of code review Signed-off-by: Ming Huang <[email protected]> * Hiding FlaxFloatMeta32 inside fp8.py Signed-off-by: Ming Huang <[email protected]> * Make functions to be JAX traceable objects. Signed-off-by: Ming Huang <[email protected]> * Rebased with main. Signed-off-by: Ming Huang <[email protected]> * Update jax images for github workflow. Signed-off-by: Ming Huang <[email protected]> --------- Signed-off-by: Ming Huang <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* initialize tp_group for FP8 DPA Signed-off-by: Charlene Yang <[email protected]> * fix cuDNN version in unit tests for cuDNN v9 Signed-off-by: Charlene Yang <[email protected]> * add hook to ignore missing fused_attn._extra_states if training from old checkpoints Signed-off-by: Charlene Yang <[email protected]> * remove test and redundant implementation from last commit Signed-off-by: Charlene Yang <[email protected]> * remove warning message and replace with docstring Signed-off-by: Charlene Yang <[email protected]> * remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group Signed-off-by: Charlene Yang <[email protected]> * move core_attention.fused_attention._extra_state to core_attention._extra_state Signed-off-by: Charlene Yang <[email protected]> * simplify post_state_dict_hooks between FU and DPA Signed-off-by: Charlene Yang <[email protected]> * add temporary test Signed-off-by: Charlene Yang <[email protected]> * remove previous attempts to move core_attention.fused_attention to core_attention; keep the test Signed-off-by: Charlene Yang <[email protected]> * remove the test Signed-off-by: Charlene Yang <[email protected]> * disable pylint self arg for hook which is required by hook Signed-off-by: Charlene Yang <[email protected]> --------- Signed-off-by: Charlene Yang <[email protected]> Signed-off-by: cyanguwa <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* Add layernorm_fp8_dot unit test Signed-off-by: Reese Wang <[email protected]> * Update the softmax primitives support conditions Signed-off-by: Reese Wang <[email protected]> * Add tests for the softmax primitives Signed-off-by: Reese Wang <[email protected]> * Round1 refactor of test_layer Signed-off-by: Reese Wang <[email protected]> * Split dropout arguments of ref code and add hidden/intermediate dropout elementwise comparison Signed-off-by: Reese Wang <[email protected]> * Add dropout_braodcast_dim, self_attn_mask tests and clean a few code Signed-off-by: Reese Wang <[email protected]> * Abstract test layer and fix a rope reference code diff Signed-off-by: Reese Wang <[email protected]> * Add bias tests Signed-off-by: Reese Wang <[email protected]> * Add epsilon and float32 tests Signed-off-by: Reese Wang <[email protected]> * Add relpos_bias and attention dropout tests Signed-off-by: Reese Wang <[email protected]> * Loose the atol Signed-off-by: Reese Wang <[email protected]> * Move common fixtures to conftest.py Signed-off-by: Reese Wang <[email protected]> * Add doc string for test_layer Signed-off-by: Reese Wang <[email protected]> * Add doc string for test_layer Signed-off-by: Reese Wang <[email protected]> * Fix conflicts of test_layer Signed-off-by: Reese Wang <[email protected]> * Avoid to left bias parameters in graph when use_bias=False Signed-off-by: Reese Wang <[email protected]> --------- Signed-off-by: Reese Wang <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* templated primitives and respective C++ functions Signed-off-by: Phuong Nguyen <[email protected]> * fixes for LayerNormMLP, tests in test_custom_compute all passed Signed-off-by: Phuong Nguyen <[email protected]> * added default arg for pybind get_workspace_size funcs Signed-off-by: Phuong Nguyen <[email protected]> * fixes for TestTransFormer with non-gated act tests Signed-off-by: Phuong Nguyen <[email protected]> * renamed gelu to act Signed-off-by: Phuong Nguyen <[email protected]> * improved enum implementation, avoid using magic numbers Signed-off-by: Phuong Nguyen <[email protected]> * Exposed C++ ActivationEnum to python side Signed-off-by: Phuong Nguyen <[email protected]> * Changed error messages Signed-off-by: Phuong Nguyen <[email protected]> * changed conditional check on input shape for dbias_cast_transpose Signed-off-by: Phuong Nguyen <[email protected]> * changed dtype (tol) for bias grad tests Signed-off-by: Phuong Nguyen <[email protected]> * fixes so that layer_norm_fp8_mlp can take bias = None Signed-off-by: Phuong Nguyen <[email protected]> * Set bias = None in flax modules Signed-off-by: Phuong Nguyen <[email protected]> --------- Signed-off-by: Phuong Nguyen <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Update FP8 recipe test to handle recipe changes Signed-off-by: Tim Moon <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Bump FA version to 2.5.8 Signed-off-by: Kirthi Shankar Sivamani <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
* fixes for ActLuPrimitive in PAXML * changed indices for arg_infos in sharding func in dbias_cast_transpose primitive --------- Signed-off-by: Phuong Nguyen <[email protected]> Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
for more information, see https://pre-commit.ci
Description
Reviving NVIDIA#829, but without the tutorial code, which is kept for now in a separate branch: te_gemma_generation_tutorial.
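For context on what "generation support" targets, the baseline is the standard Hugging Face generate loop for Gemma. A minimal sketch using the standard transformers API (checkpoint name and sampling settings are placeholders; the TE-specific hooks this PR adds live in the branch itself):

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "google/gemma-7b"  # placeholder checkpoint name
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name, torch_dtype=torch.bfloat16, device_map="cuda"
)

inputs = tokenizer("The capital of France is", return_tensors="pt").to("cuda")
# The PR's goal is to serve this same generate() call with TE layers
# (and their fused/FP8-friendly kernels) swapped in under the hood.
out = model.generate(**inputs, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))
```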