Enhance GroupedLinear by integrating AITER Triton kernels #413
Base: dev
```diff
@@ -10,6 +10,7 @@
 import pytest
 import random

+from triton_kernels.test_common import get_tolerances
 import torch
 import torch.nn as nn
 from torch.nn import Parameter
```
```diff
@@ -2010,6 +2011,113 @@ def _test_grouped_linear_accuracy(
     return outputs


+@pytest.mark.parametrize("dtype", param_types, ids=str)
+@pytest.mark.parametrize("num_gemms", [3, 6])
+@pytest.mark.parametrize("bs", batch_sizes)
+@pytest.mark.parametrize("model", ["126m"])
+@pytest.mark.parametrize("recipe", [None])
+@pytest.mark.parametrize("fp8_model_params", [False])
+@pytest.mark.parametrize("fuse_wgrad_accumulation", [False])
+@pytest.mark.parametrize("bias", all_boolean)
+@pytest.mark.parametrize("delay_wgrad_compute", all_boolean)
+def test_grouped_linear_triton_accuracy(
+    dtype,
+    num_gemms,
+    bs,
+    model,
+    recipe,
+    fp8_model_params,
+    fuse_wgrad_accumulation,
+    bias,
+    delay_wgrad_compute,
+    parallel_mode=None,
+):
+    os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] = "1"
+    try:
+        fp8 = recipe is not None
+
+        if IS_HIP_EXTENSION:
```
Collaborator: This whole new test series is ROCm-only. Maybe we should add `not IS_HIP_EXTENSION` to the skip condition?
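A minimal sketch of that suggestion, assuming the `IS_HIP_EXTENSION` flag already imported at the top of this test file (the stand-in definition below exists only to keep the sketch self-contained; it is not code from this PR):

```python
import pytest

IS_HIP_EXTENSION = True  # stand-in for the flag imported in the real test file


# Skipping via a marker keeps the whole Triton-only test out of non-ROCm runs
# instead of relying on checks inside the test body.
@pytest.mark.skipif(not IS_HIP_EXTENSION, reason="Triton grouped GEMM tests are ROCm-only")
def test_grouped_linear_triton_accuracy_sketch():
    ...
```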
```diff
+            if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8:
+                pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.")
+        if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED:
+            pytest.skip("FP8 parameters are not supported in debug mode.")
+
+        config = model_configs[model]
+        if config.max_seqlen_q % 16 != 0 and fp8:
+            pytest.skip("FP8 requires sequence length to be divisible by 16.")
+
+        with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe):
+            grouped_linear = GroupedLinear(
+                num_gemms,
+                config.hidden_size,
+                4 * config.hidden_size,
+                bias=bias,
+                params_dtype=dtype,
+                parallel_mode=parallel_mode,
+                device="cuda",
+                fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+                delay_wgrad_compute=delay_wgrad_compute,
+                save_original_input=False,
+            ).eval()
+            sequential_linear = torch.nn.ModuleList(
+                [
+                    Linear(
+                        config.hidden_size,
+                        4 * config.hidden_size,
+                        bias=bias,
+                        params_dtype=dtype,
+                        parallel_mode=parallel_mode,
+                        device="cuda",
+                        fuse_wgrad_accumulation=fuse_wgrad_accumulation,
+                    ).eval()
+                    for _ in range(num_gemms)
+                ]
+            )
+
+        # Share params
+        with torch.no_grad():
+            for i in range(num_gemms):
+                sequential_linear[i].weight = Parameter(getattr(grouped_linear, f"weight{i}").clone())
+                if bias:
+                    sequential_linear[i].bias = Parameter(getattr(grouped_linear, f"bias{i}").clone())
+                if fuse_wgrad_accumulation:
+                    weight_i = getattr(grouped_linear, f"weight{i}")
+                    weight_i.main_grad = torch.rand_like(weight_i, dtype=torch.float32)
+                    sequential_linear[i].weight.main_grad = weight_i.main_grad.clone()
+
+        outputs_ref = _test_grouped_linear_accuracy(
+            sequential_linear,
+            num_gemms,
+            bs,
+            dtype,
+            config,
+            recipe,
+            fp8,
+            fuse_wgrad_accumulation,
+            delay_wgrad_compute,
+        )
+        outputs = _test_grouped_linear_accuracy(
+            grouped_linear,
+            num_gemms,
+            bs,
+            dtype,
+            config,
+            recipe,
+            fp8,
+            fuse_wgrad_accumulation,
+            delay_wgrad_compute,
+        )
+
+        atol, rtol = get_tolerances(dtype)
+        if dtype == torch.float32:
+            atol = 2.6e-6
+            rtol = 5e-2
+        for o, o_ref in zip(outputs, outputs_ref):
+            torch.testing.assert_close(o, o_ref, rtol=rtol, atol=atol)
+    finally:
```
Collaborator: This newly added pytest overlaps heavily with the following test_grouped_linear_accuracy. Can we consolidate the two?

Contributor (Author): We can, but I don't want to make future IFUs harder. Since this is just a unit test, though, we can probably unify them. I'll make the changes.
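One possible shape for that consolidation (a sketch under assumed names, not the change made in this PR): parametrize the existing accuracy test over the backend and let pytest's built-in `monkeypatch` fixture set and restore the environment variable. `use_triton_gemm`, `_run_accuracy_case`, and the stand-in `IS_HIP_EXTENSION` are illustrative, not identifiers from this PR.

```python
import os

import pytest

IS_HIP_EXTENSION = True  # stand-in for the flag imported in the real test file


def _run_accuracy_case() -> None:
    # Placeholder for the shared body of test_grouped_linear_accuracy.
    assert os.environ.get("NVTE_USE_GROUPED_GEMM_TRITON") in (None, "1")


@pytest.mark.parametrize("use_triton_gemm", [False, True])
def test_grouped_linear_accuracy_sketch(use_triton_gemm, monkeypatch):
    if use_triton_gemm:
        if not IS_HIP_EXTENSION:
            pytest.skip("Triton grouped GEMM path is ROCm-only.")
        # monkeypatch undoes this during teardown, even on failure or skip.
        monkeypatch.setenv("NVTE_USE_GROUPED_GEMM_TRITON", "1")
    else:
        monkeypatch.delenv("NVTE_USE_GROUPED_GEMM_TRITON", raising=False)
    _run_accuracy_case()
```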
```diff
+        os.environ.pop("NVTE_USE_GROUPED_GEMM_TRITON", None)
+
+
 @pytest.mark.parametrize("dtype", param_types, ids=str)
 @pytest.mark.parametrize("num_gemms", [3, 6])
 @pytest.mark.parametrize("bs", batch_sizes)
```
Review comment: This env var won't be cleared if the test is skipped or fails.
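A common way to guarantee the variable is restored regardless of skips or failures is a fixture built on pytest's `monkeypatch` (a sketch; the fixture and test names are illustrative, not from this PR):

```python
import os

import pytest


@pytest.fixture
def grouped_gemm_triton_env(monkeypatch):
    # monkeypatch teardown runs after the test whether it passes, fails, or
    # skips, so NVTE_USE_GROUPED_GEMM_TRITON never leaks into other tests.
    monkeypatch.setenv("NVTE_USE_GROUPED_GEMM_TRITON", "1")
    yield


def test_uses_triton_grouped_gemm(grouped_gemm_triton_env):
    assert os.environ["NVTE_USE_GROUPED_GEMM_TRITON"] == "1"
```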