-
Notifications
You must be signed in to change notification settings - Fork 74
support dynamic shapes in warp specialized inner outer persistent scheduler #5765
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -129,9 +129,15 @@ TensorView* scheduleReductionTV( | |||||
| // Reduction: [Persistent, TIDx, Vect] | ||||||
| vectorize(inner_reduce_axis, rparams->unroll_factor_inner_reduction); | ||||||
|
|
||||||
| // static bdimx is required for TMA warp specialization | ||||||
| int64_t compute_bdimx = getComputeBdimx(option, rparams->lparams.bdimx()); | ||||||
| inner_parallel_static(inner_reduce_axis, ParallelType::TIDx, compute_bdimx); | ||||||
| reduction_tv->split( | ||||||
| inner_reduce_axis, rparams->batches_per_block_inner_reduction, false); | ||||||
| reduction_tv->axis(inner_reduce_axis + 1)->parallelize(ParallelType::TIDx); | ||||||
| reduction_tv->axis(inner_reduce_axis + 1)->padToMultipleOfWarp(); | ||||||
|
|
||||||
| // // static bdimx is required for TMA warp specialization | ||||||
| // int64_t compute_bdimx = getComputeBdimx(option, | ||||||
| // rparams->lparams.bdimx()); inner_parallel_static(inner_reduce_axis, | ||||||
| // ParallelType::TIDx, compute_bdimx); | ||||||
|
|
||||||
| // Iteration: [I/Unroll/BIDy, BIDy, Unroll] | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Wrong comment prefix - should be
Suggested change
|
||||||
| if (rparams->unroll_factor_iter_dom > 1) { | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -1087,8 +1087,11 @@ TEST_P(TmaWarpSpecializedTest, SimpleFusion) { | |||||
|
|
||||||
| auto fusion = std::make_unique<Fusion>(); | ||||||
| FusionGuard fg(fusion.get()); | ||||||
| auto tv0 = makeContigConcreteTensor({dim0, dim1}, dtype); | ||||||
| auto tv1 = makeContigConcreteTensor({dim0, dim1}, dtype); | ||||||
| // For case contig_1_dtype_float_batch_2048_hidden_8192 | ||||||
| // the performance is 59.7% SOL uisng concrete inputs | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Typo: 'uisng' should be 'using'.
Suggested change
|
||||||
| // for symbolic inputs, the performance is 59.1% SOL | ||||||
| auto tv0 = makeContigTensor(2, dtype); | ||||||
| auto tv1 = makeContigTensor(2, dtype); | ||||||
| fusion->addInput(tv0); | ||||||
| fusion->addInput(tv1); | ||||||
| tv0 = maybeCastOp(DataType::Float, tv0); | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -2196,8 +2196,8 @@ TEST_P(TmaPersistentTestP, TmaInnerPersistentRmsNorm) { | |||||||||||||||||||||||||||||
| const float kEps = 1e-6; | ||||||||||||||||||||||||||||||
| Val* eps_ptr = IrBuilder::create<Val>(kEps); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto tv0 = makeContigConcreteTensor({x, y}, dtype); | ||||||||||||||||||||||||||||||
| auto tv1 = makeContigConcreteTensor({y}, dtype); | ||||||||||||||||||||||||||||||
| auto tv0 = makeContigTensor(2, dtype); | ||||||||||||||||||||||||||||||
| auto tv1 = makeContigTensor(1, dtype); | ||||||||||||||||||||||||||||||
| fusion.addInput(tv0); | ||||||||||||||||||||||||||||||
| fusion.addInput(tv1); | ||||||||||||||||||||||||||||||
| tv0 = maybeCastOp(DataType::Float, tv0); | ||||||||||||||||||||||||||||||
|
|
@@ -2271,7 +2271,7 @@ TEST_P(TmaPersistentTestP, TmaInnerPersistentSoftmax) { | |||||||||||||||||||||||||||||
| auto& fusion = *fusion_ptr; | ||||||||||||||||||||||||||||||
| FusionGuard fg(fusion_ptr.get()); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto tv0 = makeContigTensor(2, dtype); | ||||||||||||||||||||||||||||||
| auto tv0 = makeContigConcreteTensor({x, y}, dtype); | ||||||||||||||||||||||||||||||
| fusion.addInput(tv0); | ||||||||||||||||||||||||||||||
| tv0 = maybeCastOp(DataType::Float, tv0); | ||||||||||||||||||||||||||||||
| auto res = softmax(tv0, 1); | ||||||||||||||||||||||||||||||
|
|
@@ -2297,8 +2297,8 @@ INSTANTIATE_TEST_SUITE_P( | |||||||||||||||||||||||||||||
| ::testing::Combine( | ||||||||||||||||||||||||||||||
| testing::Values(DataType::BFloat16), | ||||||||||||||||||||||||||||||
| testing::Values( | ||||||||||||||||||||||||||||||
| deviceSMCount() / 2, | ||||||||||||||||||||||||||||||
| 1024), // batch size, less or larger than sm count | ||||||||||||||||||||||||||||||
| deviceSMCount() / 2, // small batch, can't do grid stride loop | ||||||||||||||||||||||||||||||
| 2048), // batch size, less or larger than sm count | ||||||||||||||||||||||||||||||
| testing::ValuesIn(Pow2Vals1to1Million)), // hidden size | ||||||||||||||||||||||||||||||
| [](const testing::TestParamInfo<TmaPersistentTestParams>& info) { | ||||||||||||||||||||||||||||||
| auto dtype = std::get<0>(info.param); | ||||||||||||||||||||||||||||||
|
|
@@ -2308,4 +2308,85 @@ INSTANTIATE_TEST_SUITE_P( | |||||||||||||||||||||||||||||
| os << dtype << "_" << x << "_" << y; | ||||||||||||||||||||||||||||||
| return os.str(); | ||||||||||||||||||||||||||||||
| }); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // Test that kernels with different launch parameters are not incorrectly reused | ||||||||||||||||||||||||||||||
| // This ensures that LaunchParams is properly included in the cache key | ||||||||||||||||||||||||||||||
| TEST_F(TmaPersistentTestF, KernelReuse) { | ||||||||||||||||||||||||||||||
| auto fusion_ptr = std::make_unique<Fusion>(); | ||||||||||||||||||||||||||||||
| auto& fusion = *fusion_ptr; | ||||||||||||||||||||||||||||||
| FusionGuard fg(fusion_ptr.get()); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // Create an RMS norm fusion that will use inner persistent scheduler | ||||||||||||||||||||||||||||||
| const float kEps = 1e-6; | ||||||||||||||||||||||||||||||
| Val* eps_ptr = IrBuilder::create<Val>(kEps); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto tv0 = makeContigTensor(2, DataType::BFloat16); | ||||||||||||||||||||||||||||||
| auto tv1 = makeContigTensor(1, DataType::BFloat16); | ||||||||||||||||||||||||||||||
| fusion.addInput(tv0); | ||||||||||||||||||||||||||||||
| fusion.addInput(tv1); | ||||||||||||||||||||||||||||||
| tv0 = maybeCastOp(DataType::Float, tv0); | ||||||||||||||||||||||||||||||
| tv1 = maybeCastOp(DataType::Float, tv1); | ||||||||||||||||||||||||||||||
| auto rms_norm_results = rms_norm(tv0, 1, tv1, eps_ptr); | ||||||||||||||||||||||||||||||
| auto output = maybeCastOp(DataType::BFloat16, rms_norm_results.output); | ||||||||||||||||||||||||||||||
| fusion.addOutput(output); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| FusionExecutorCache executor_cache(std::move(fusion_ptr)); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // Helper to get the number of compiled kernel runtimes | ||||||||||||||||||||||||||||||
| auto numRuntimes = [&executor_cache]() -> size_t { | ||||||||||||||||||||||||||||||
| // this is map<pair<device, conc_info>, vector<FusionKernelRuntime>> | ||||||||||||||||||||||||||||||
| const auto& runtime_map = executor_cache.getKernelRuntimes(); | ||||||||||||||||||||||||||||||
| if (runtime_map.empty()) { | ||||||||||||||||||||||||||||||
| return 0; | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
| return runtime_map | ||||||||||||||||||||||||||||||
| .begin() // There should be only one device/concretization pair | ||||||||||||||||||||||||||||||
| ->second.size(); | ||||||||||||||||||||||||||||||
| }; | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // First run with specific dimensions that will produce launch config A | ||||||||||||||||||||||||||||||
| auto input1 = at::randn({2048, 4096}, options); | ||||||||||||||||||||||||||||||
| auto weight1 = at::randn({4096}, options); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto output1 = executor_cache.runFusionWithInputs({input1, weight1}); | ||||||||||||||||||||||||||||||
| testValidate( | ||||||||||||||||||||||||||||||
| executor_cache.fusion(), output1, {input1, weight1}, __LINE__, __FILE__); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| EXPECT_EQ(numRuntimes(), 1) << "First run should compile one kernel"; | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| FusionKernelRuntime* first_runtime = | ||||||||||||||||||||||||||||||
| executor_cache.getMostRecentKernelRuntime(); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // Second run with different outer dimension - should reuse the kernel | ||||||||||||||||||||||||||||||
| auto input2 = at::randn({2048 + 8, 4096}, options); | ||||||||||||||||||||||||||||||
| auto weight2 = at::randn({4096}, options); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto output2 = executor_cache.runFusionWithInputs({input2, weight2}); | ||||||||||||||||||||||||||||||
| testValidate( | ||||||||||||||||||||||||||||||
| executor_cache.fusion(), output2, {input2, weight2}, __LINE__, __FILE__); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| EXPECT_EQ(numRuntimes(), 1) | ||||||||||||||||||||||||||||||
| << "Same dimensions should reuse the existing kernel"; | ||||||||||||||||||||||||||||||
|
Comment on lines
+2370
to
+2371
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Comment says "Same dimensions should reuse" but the test uses different outer dimension (2048 + 8 vs 2048) - comment is misleading. |
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| FusionKernelRuntime* second_runtime = | ||||||||||||||||||||||||||||||
| executor_cache.getMostRecentKernelRuntime(); | ||||||||||||||||||||||||||||||
| EXPECT_EQ(first_runtime, second_runtime) | ||||||||||||||||||||||||||||||
| << "Should reuse the same runtime for identical shapes"; | ||||||||||||||||||||||||||||||
|
Comment on lines
+2370
to
+2376
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. style: the comments are misleading here. Line 2370 says "Same dimensions should reuse the existing kernel" but the shapes are different (
Suggested change
|
||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // Third run with slightly different inner dimension - should reuse the kernel | ||||||||||||||||||||||||||||||
| auto input3 = at::randn({2048 + 8, 4096 - 8}, options); | ||||||||||||||||||||||||||||||
| auto weight3 = at::randn({4096 - 8}, options); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| auto output3 = executor_cache.runFusionWithInputs({input3, weight3}); | ||||||||||||||||||||||||||||||
| testValidate( | ||||||||||||||||||||||||||||||
| executor_cache.fusion(), output3, {input3, weight3}, __LINE__, __FILE__); | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| // If launch params are properly included in cache, this should compile a new | ||||||||||||||||||||||||||||||
| // kernel | ||||||||||||||||||||||||||||||
| EXPECT_GE(numRuntimes(), 1) | ||||||||||||||||||||||||||||||
| << "Different dimensions may create new kernel if launch params differ"; | ||||||||||||||||||||||||||||||
| } | ||||||||||||||||||||||||||||||
|
|
||||||||||||||||||||||||||||||
| } // namespace nvfuser | ||||||||||||||||||||||||||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Comment formatting is broken - line breaks should be after the comment prefix.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!