
Generation support in TE for Gemma model #2

Open
wants to merge 276 commits into base: main
Conversation

sudhakarsingh27
Owner

Description

Reviving NVIDIA#829, but without the tutorial code, which lives for now in a separate branch, te_gemma_generation_tutorial

denera and others added 30 commits May 22, 2024 17:05
…ax.jit (NVIDIA#785)

* fixed static argnums for jax.jit in single gpu encoder test, changed warning filtering for pytest

Signed-off-by: Alp Dener <[email protected]>

* propagating the fix to the JAX mnist example

Signed-off-by: Alp Dener <[email protected]>

* fixed missing space between flags in QAA scripts

Signed-off-by: Alp Dener <[email protected]>

* added TE warnings into the ignore list

Signed-off-by: Alp Dener <[email protected]>

---------

Signed-off-by: Alp Dener <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Add NVRTC kernels for cast-transpose

Signed-off-by: Tim Moon <[email protected]>

* Update copyright year

Signed-off-by: Tim Moon <[email protected]>

* Add noop flag to NVRTC cast-transpose kernel

Signed-off-by: Tim Moon <[email protected]>

* Apply suggestions from code review

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
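
For context, a plain-PyTorch sketch of what a fused cast-transpose computes (illustrative only; the commit above adds NVRTC-compiled CUDA kernels for this, with a device-side noop flag to skip the work entirely). The fp8 dtype requires a recent PyTorch build; casting before and after the transpose avoids fp8-typed transpose ops.

```python
import torch

# Reference semantics of cast-transpose: produce both the FP8 cast of a
# 2-D tensor and the FP8 cast of its transpose in one logical pass.
# Names are illustrative, not TE's API.
def cast_transpose_ref(x: torch.Tensor, scale: torch.Tensor):
    y = x.float() * scale                              # scale in higher precision
    casted = y.to(torch.float8_e4m3fn)                 # cast
    casted_t = y.t().contiguous().to(torch.float8_e4m3fn)  # transpose, then cast
    return casted, casted_t
```
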
)

* Support noop concat without providing full tensor

Stop storing fused buffers in linear modules.

Signed-off-by: Tim Moon <[email protected]>

* Debug noop cat func

Signed-off-by: Tim Moon <[email protected]>

* Construct TE modules in tests with correct dtypes

Signed-off-by: Tim Moon <[email protected]>

* Add tolerances to numerical tests

Signed-off-by: Tim Moon <[email protected]>

* Use plain PyTorch concat when exporting to ONNX

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
…#780)

* Allow multi-dims for dgamma and dbeta in LN descriptor.

Signed-off-by: Ming Huang <[email protected]>

* Fix the jit error in examples/jax

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Remove unnecessary Pylint overrides

Signed-off-by: Tim Moon <[email protected]>

* Fixes to lint

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* combined layernorm_geglu with layernorm_gelu into fused_layernorm

Signed-off-by: Phuong Nguyen <[email protected]>

* fixes to pass all unit tests in test_custom_call_compute.py,
test_layer.py, and test_praxis_layer.py

Signed-off-by: Phuong Nguyen <[email protected]>

* cleaning and formatting

Signed-off-by: Phuong Nguyen <[email protected]>

* renaming based on reviewers' suggestions

Signed-off-by: Phuong Nguyen <[email protected]>

* implemented partial fused layernorm

Signed-off-by: Phuong Nguyen <[email protected]>

* geglu + bias passed tests

Signed-off-by: Phuong Nguyen <[email protected]>

* added partial fused calculation for dbias_1

Signed-off-by: Phuong Nguyen <[email protected]>

* clean up

Co-authored-by: Alp Dener <[email protected]>
Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Phuong Nguyen <[email protected]>
Co-authored-by: Alp Dener <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
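
As a reference for what the fused LayerNorm+GEGLU path computes, a non-fused GEGLU in plain PyTorch (a sketch; the function name is illustrative). The input holds the gate and value projections concatenated along the last dimension.

```python
import torch
import torch.nn.functional as F

# Non-fused GEGLU reference: split into gate/value halves, apply GELU to
# the gate, multiply elementwise.
def geglu_ref(x: torch.Tensor) -> torch.Tensor:
    gate, value = x.chunk(2, dim=-1)
    return F.gelu(gate) * value
```
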
* Try using global buffer for cu_seqlens

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Avoid using functools.lru_cache

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* fixes

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Co-authored-by: Vasudevan Rengasamy <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
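
A minimal sketch of the global-buffer approach these commits describe: cache cu_seqlens tensors in a module-level dict keyed by shape and device, instead of functools.lru_cache, which holds tensor references hostage to its eviction policy. Names and key layout are illustrative, not TE's code.

```python
import torch

_cu_seqlens_cache: dict = {}

def get_cu_seqlens(batch_size: int, max_seqlen: int, device) -> torch.Tensor:
    key = (batch_size, max_seqlen, str(device))
    if key not in _cu_seqlens_cache:
        # Cumulative sequence lengths for fixed-length batches:
        # [0, L, 2L, ..., B*L].
        _cu_seqlens_cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cu_seqlens_cache[key]
```
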
Added HF Nanotron to integrations and updated GTC 24 video to on-demand link

Signed-off-by: Santosh Bhavani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Implemented swiglu and silu

Signed-off-by: Phuong Nguyen <[email protected]>

* Renamed nvte-*silu to nvte-*swish + generalized GetDBiasDact functions

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
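
A non-fused SwiGLU reference for comparison (sketch only; the fused kernels operate on the same split-gate layout). SiLU(x) = x * sigmoid(x) is the "swish" activation the rename refers to.

```python
import torch
import torch.nn.functional as F

# Non-fused SwiGLU reference: SiLU on the gate half, elementwise multiply
# with the value half.
def swiglu_ref(x: torch.Tensor) -> torch.Tensor:
    gate, value = x.chunk(2, dim=-1)
    return F.silu(gate) * value
```
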
* make FusedAttn with CP support bias

Signed-off-by: Xiaowei Ren <[email protected]>

* assert Alibi cannot work with CP

Signed-off-by: Xiaowei Ren <[email protected]>

* syntax fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix variable name

Signed-off-by: Xiaowei Ren <[email protected]>

* fix tensor shapes

Signed-off-by: Xiaowei Ren <[email protected]>

* a typo fix

Signed-off-by: Xiaowei Ren <[email protected]>

* fix bias indexing for CP

Signed-off-by: Xiaowei Ren <[email protected]>

* bug fix

Signed-off-by: Xiaowei Ren <[email protected]>

* add attn bias tests

Signed-off-by: Xiaowei Ren <[email protected]>

* change dbias update location

Signed-off-by: Xiaowei Ren <[email protected]>

* fix CP test model configs

Signed-off-by: Xiaowei Ren <[email protected]>

* change CP test sequence length

Signed-off-by: Xiaowei Ren <[email protected]>

* make AttnFuncWithCP support qkv format of sbhd

Signed-off-by: Xiaowei Ren <[email protected]>

* make sure qkv are contiguous for CP in cuDNN fused attn

Signed-off-by: Xiaowei Ren <[email protected]>

* change assert message

Signed-off-by: Xiaowei Ren <[email protected]>

* fix code format

Signed-off-by: Xiaowei Ren <[email protected]>

---------

Signed-off-by: Xiaowei Ren <[email protected]>
Co-authored-by: cyanguwa <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Add support for MoE with FP8.

Signed-off-by: Dennis Liu <[email protected]>

* Fix unittest.

Signed-off-by: Dennis Liu <[email protected]>

* Fix error in linear backward.

Signed-off-by: Dennis Liu <[email protected]>

---------

Signed-off-by: Dennis Liu <[email protected]>
Co-authored-by: Przemyslaw Tredak <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Add module level filter for deprecation warning in common

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

* Fix module

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>

---------

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
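
For illustration, a module-scoped warning filter of the kind this commit describes: the deprecation warning is suppressed only when emitted from the targeted module, not globally. The module pattern below is an assumption, not the exact one used in the commit.

```python
import warnings

# Suppress DeprecationWarning only for warnings originating from the
# targeted module (the `module` argument is a regex on the module name).
warnings.filterwarnings(
    "ignore",
    category=DeprecationWarning,
    module=r"transformer_engine\.common.*",  # assumed pattern
)
```
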
remove tp_size/tp_group as amax reduction is handled by fp8_group()

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
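
A minimal sketch of why tp_size/tp_group became redundant: amax statistics are all-reduced once over the process group returned by fp8_group(), so no per-module tensor-parallel group is needed. `amax_history` and `fp8_group` stand in for TE internals; this is an illustration, not TE's code.

```python
import torch
import torch.distributed as dist

def reduce_amax(amax_history: torch.Tensor, fp8_group) -> None:
    # Single MAX all-reduce over the FP8 process group replaces the
    # per-module tp_group bookkeeping.
    if dist.is_initialized() and fp8_group is not None:
        dist.all_reduce(amax_history, op=dist.ReduceOp.MAX, group=fp8_group)
```
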
…IDIA#799)

restrict context parallel tests to sm80+ as fused/flash attn backends require sm80+

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Fix linter warnings from unused args

Signed-off-by: Tim Moon <[email protected]>

* Update .gitignore

Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* Added pull request template

Signed-off-by: Przemek Tredak <[email protected]>

* Changes from the review

Signed-off-by: Przemek Tredak <[email protected]>

---------

Signed-off-by: Przemek Tredak <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Vasudevan Rengasamy <[email protected]>
Co-authored-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
…nite scale (NVIDIA#786)

* Handle the scaling factor when amax is so tiny that it leads to an infinite scale (see the sketch below)

Signed-off-by: Jinze Xue <[email protected]>

* revert formatting changes

Signed-off-by: Jinze Xue <[email protected]>

* fix comments

Signed-off-by: Jinze Xue <[email protected]>

* Apply review suggestion

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>

* Apply review suggestion

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>

* Apply review suggestion

Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>

* apply review suggestion

Signed-off-by: Jinze Xue <[email protected]>

* add test_recipe.py to qa/L0_pytorch_unittest/test.sh; fix unittest for is_first_microbatch=False

Signed-off-by: Jinze Xue <[email protected]>

* revert changes to update_weight_scale_inv

Signed-off-by: Jinze Xue <[email protected]>

* Debug test failures

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Jinze Xue <[email protected]>
Signed-off-by: Jinze Xue <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: Jinze Xue <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Co-authored-by: Tim Moon <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
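
The guard mentioned in the first bullet above, as a minimal sketch: with delayed scaling, scale = fp8_max / amax (adjusted by a margin), so a near-zero amax would produce an infinite scale. Names and the fallback policy here are illustrative, not TE's exact implementation.

```python
import torch

def compute_scale(amax: torch.Tensor, fp8_max: float, margin: int = 0) -> torch.Tensor:
    scale = (fp8_max / amax) / (2.0 ** margin)
    # Replace non-finite results (amax == 0) with a safe default; TE would
    # fall back to the previous scale, here we use 1.0 for illustration.
    return torch.where(torch.isfinite(scale), scale, torch.ones_like(scale))
```
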
…> 1 on Paxml. (NVIDIA#774)

* Support FP8 Meta Dtype (FM32) and Align FP8 Scale Update with PyTorch.

Signed-off-by: Ming Huang <[email protected]>

* Modify with the feedback of code review

Signed-off-by: Ming Huang <[email protected]>

* Hiding FlaxFloatMeta32 inside fp8.py

Signed-off-by: Ming Huang <[email protected]>

* Make functions JAX-traceable objects.

Signed-off-by: Ming Huang <[email protected]>

* Rebased with main.

Signed-off-by: Ming Huang <[email protected]>

* Update jax images for github workflow.

Signed-off-by: Ming Huang <[email protected]>

---------

Signed-off-by: Ming Huang <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* initialize tp_group for FP8 DPA

Signed-off-by: Charlene Yang <[email protected]>

* fix cuDNN version in unit tests for cuDNN v9

Signed-off-by: Charlene Yang <[email protected]>

* add hook to ignore missing fused_attn._extra_states if training from old checkpoints

Signed-off-by: Charlene Yang <[email protected]>

* remove test and redundant implementation from last commit

Signed-off-by: Charlene Yang <[email protected]>

* remove warning message and replace with docstring

Signed-off-by: Charlene Yang <[email protected]>

* remove tp_size/tp_group in FusedAttention; amax reduction is handled with fp8_group

Signed-off-by: Charlene Yang <[email protected]>

* move core_attention.fused_attention._extra_state to core_attention._extra_state

Signed-off-by: Charlene Yang <[email protected]>

* simplify post_state_dict_hooks between FU and DPA

Signed-off-by: Charlene Yang <[email protected]>

* add temporary test

Signed-off-by: Charlene Yang <[email protected]>

* remove previous attempts to move core_attention.fused_attention to core_attention; keep the test

Signed-off-by: Charlene Yang <[email protected]>

* remove the test

Signed-off-by: Charlene Yang <[email protected]>

* disable pylint's warning about the hook's self argument, which is required by the hook API

Signed-off-by: Charlene Yang <[email protected]>

---------

Signed-off-by: Charlene Yang <[email protected]>
Signed-off-by: cyanguwa <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
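
A sketch of the checkpoint-compatibility hook described above: a load-state-dict post-hook that drops missing `_extra_state` keys so old checkpoints load cleanly. The key suffix and the registration call are assumptions based on the commit message, not the exact TE code.

```python
# Hook signature follows torch.nn.Module.register_load_state_dict_post_hook:
# hook(module, incompatible_keys) -> None.
def _ignore_missing_extra_state(module, incompatible_keys):
    # Filter out missing keys that refer to fused-attention extra state.
    incompatible_keys.missing_keys[:] = [
        k for k in incompatible_keys.missing_keys
        if not k.endswith("_extra_state")
    ]

# Usage (assumed): model.register_load_state_dict_post_hook(_ignore_missing_extra_state)
```
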
* Add layernorm_fp8_dot unit test

Signed-off-by: Reese Wang <[email protected]>

* Update the softmax primitives support conditions

Signed-off-by: Reese Wang <[email protected]>

* Add tests for the softmax primitives

Signed-off-by: Reese Wang <[email protected]>

* Round1 refactor of test_layer

Signed-off-by: Reese Wang <[email protected]>

* Split dropout arguments of ref code and add hidden/intermediate dropout elementwise comparison

Signed-off-by: Reese Wang <[email protected]>

* Add dropout_broadcast_dim and self_attn_mask tests and clean up some code

Signed-off-by: Reese Wang <[email protected]>

* Abstract test layer and fix a rope reference code diff

Signed-off-by: Reese Wang <[email protected]>

* Add bias tests

Signed-off-by: Reese Wang <[email protected]>

* Add epsilon and float32 tests

Signed-off-by: Reese Wang <[email protected]>

* Add relpos_bias and attention dropout tests

Signed-off-by: Reese Wang <[email protected]>

* Loosen the atol

Signed-off-by: Reese Wang <[email protected]>

* Move common fixtures to conftest.py

Signed-off-by: Reese Wang <[email protected]>

* Add doc string for test_layer

Signed-off-by: Reese Wang <[email protected]>

* Add doc string for test_layer

Signed-off-by: Reese Wang <[email protected]>

* Fix conflicts of test_layer

Signed-off-by: Reese Wang <[email protected]>

* Avoid leaving bias parameters in the graph when use_bias=False

Signed-off-by: Reese Wang <[email protected]>

---------

Signed-off-by: Reese Wang <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* templated primitives and respective C++ functions

Signed-off-by: Phuong Nguyen <[email protected]>

* fixes for LayerNormMLP; all tests in test_custom_compute pass

Signed-off-by: Phuong Nguyen <[email protected]>

* added default arg for pybind get_workspace_size funcs

Signed-off-by: Phuong Nguyen <[email protected]>

* fixes for TestTransFormer with non-gated act tests

Signed-off-by: Phuong Nguyen <[email protected]>

* renamed gelu to act

Signed-off-by: Phuong Nguyen <[email protected]>

* improved enum implementation, avoiding magic numbers

Signed-off-by: Phuong Nguyen <[email protected]>

* Exposed C++ ActivationEnum to python side

Signed-off-by: Phuong Nguyen <[email protected]>

* Changed error messages

Signed-off-by: Phuong Nguyen <[email protected]>

* changed conditional check on input shape for dbias_cast_transpose

Signed-off-by: Phuong Nguyen <[email protected]>

* changed dtype (tol) for bias grad tests

Signed-off-by: Phuong Nguyen <[email protected]>

* fixes so that layer_norm_fp8_mlp can take bias = None

Signed-off-by: Phuong Nguyen <[email protected]>

* Set bias = None in flax modules

Signed-off-by: Phuong Nguyen <[email protected]>

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Update FP8 recipe test to handle recipe changes

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Bump FA version to 2.5.8

Signed-off-by: Kirthi Shankar Sivamani <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
* fixes for ActLuPrimitive in PAXML

* changed indices for arg_infos in sharding func in dbias_cast_transpose primitive

---------

Signed-off-by: Phuong Nguyen <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
pggPL and others added 30 commits June 6, 2024 09:20
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Pawel Gadzinski <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>
Signed-off-by: Sudhakar Singh <[email protected]>