@@ -20,7 +20,6 @@
 #include "tensorrt_llm/common/dataType.h"
 #include "tensorrt_llm/common/opUtils.h"
 #include "tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h"
-#include "tensorrt_llm/kernels/communicationKernels/customLowPrecisionAllReduceKernels.h"
 #include "tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h"
 #include "tensorrt_llm/kernels/customAllReduceKernels.h"
 #include "tensorrt_llm/kernels/internal_cutlass_kernels/include/fp4_gemm.h"
@@ -178,8 +177,6 @@ class AllreduceOp
         case AllReduceStrategyType::ONESHOT:
         case AllReduceStrategyType::TWOSHOT:
             return runFusionAllReduce(input, residual, norm_weight, scale, bias, workspace, runtime_strategy);
-        case AllReduceStrategyType::LOWPRECISION:
-            return runLowPrecisionAllReduce(input, residual, norm_weight, scale, bias);
         default: TORCH_CHECK(false, "Invalid runtime strategy"); return {};
         }
     }
@@ -299,73 +296,6 @@ class AllreduceOp
         return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
     }

-    std::vector<torch::Tensor> runLowPrecisionAllReduce(torch::Tensor const& input,
-        torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
-        torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias) noexcept
-    {
-#ifdef ENABLE_FP8
-        auto stream = at::cuda::getCurrentCUDAStream(input.get_device());
-        int size = input.numel();
-        int hidden_size = input.size(-1);
-
-        auto const tp_size = mGroup.size();
-        auto const cur_rank = COMM_SESSION.getRank();
-        int tp_rank = 0;
-
-        for (auto const& currentRank : mGroup)
-        {
-            if (cur_rank == currentRank)
-                break;
-            ++tp_rank;
-        }
-
-        int bytes_per_element = input.element_size();
-
-        int token_num = size / hidden_size;
-
-        auto parts = tensorrt_llm::kernels::splitNumber(size);
-
-        torch::Tensor reduce_output = torch::empty_like(input);
-
-        size_t global_offset = 0;
-        for (size_t i = 0; i < parts.size(); ++i)
-        {
-            size_t tmp_size = parts[i];
-            tensorrt_llm::kernels::LowPrecisionAllReduceParams tmp_param;
-            if (tp_size <= 4)
-            {
-                tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize(
-                    tp_size, tp_rank, mType, token_num, hidden_size);
-            }
-            else
-            {
-                tmp_param = tensorrt_llm::kernels::LowPrecisionAllReduceParams::deserialize_hier(
-                    tp_size, tp_rank, mType, token_num, hidden_size);
-            }
-
-            tmp_param.local_input_buffer_ptr = reinterpret_cast<void const*>(
-                reinterpret_cast<char const*>(input.data_ptr()) + global_offset * bytes_per_element);
-            tmp_param.local_output_buffer_ptr = reinterpret_cast<void*>(
-                reinterpret_cast<char*>(reduce_output.mutable_data_ptr()) + global_offset * bytes_per_element);
-            tmp_param.elts_total = tmp_size;
-            tensorrt_llm::kernels::customLowPrecisionAllReduce(tmp_param, mType, stream);
-
-            global_offset += tmp_size;
-        }
-
-        if (mOp == AllReduceFusionOp::NONE)
-        {
-            return {reduce_output};
-        }
-
-        // Treat any other patterns as fallback cases.
-        return fallbackRunSubsequentOps(input, residual, norm_weight, scale, bias, reduce_output);
-
-#else
-        C10_THROW_ERROR(NotImplementedError, "Can't use LOWPRECISION without compile with ENABLE FP8.");
-#endif
-    }
-
     std::vector<torch::Tensor> runFusionAllReduce(torch::Tensor const& input,
         torch::optional<torch::Tensor> const& residual, torch::optional<torch::Tensor> const& norm_weight,
         torch::optional<torch::Tensor> const& scale, torch::optional<torch::Tensor> const& bias,
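For reference, the removed runLowPrecisionAllReduce above boils down to a chunked launch: the flat tensor is split into parts via splitNumber, and each launch receives input/output pointers advanced by the running element offset converted to bytes. Below is a minimal, self-contained sketch of that offset arithmetic only; splitIntoParts and processInChunks are hypothetical stand-ins rather than TensorRT-LLM API, and a plain byte copy takes the place of the deleted customLowPrecisionAllReduce launch.

    #include <algorithm>
    #include <cstddef>
    #include <cstdint>
    #include <vector>

    // Hypothetical stand-in for tensorrt_llm::kernels::splitNumber(): split a total
    // element count into bounded parts.
    std::vector<std::size_t> splitIntoParts(std::size_t total, std::size_t maxPart)
    {
        std::vector<std::size_t> parts;
        for (std::size_t done = 0; done < total; done += maxPart)
        {
            parts.push_back(std::min(maxPart, total - done));
        }
        return parts;
    }

    // Mirrors the removed loop's pointer arithmetic: advance by the element offset
    // times the element size, process one part, then move the offset forward.
    void processInChunks(void const* input, void* output, std::size_t numElements, std::size_t bytesPerElement)
    {
        std::size_t globalOffset = 0;
        for (std::size_t part : splitIntoParts(numElements, std::size_t{1} << 20))
        {
            auto const* src = static_cast<std::uint8_t const*>(input) + globalOffset * bytesPerElement;
            auto* dst = static_cast<std::uint8_t*>(output) + globalOffset * bytesPerElement;
            std::copy(src, src + part * bytesPerElement, dst); // kernel launch in the removed code
            globalOffset += part;
        }
    }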
@@ -664,11 +594,6 @@ class AllreduceOp
             TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: UB", rank);
             break;
         }
-        case AllReduceStrategyType::LOWPRECISION:
-        {
-            TLLM_LOG_DEBUG("AllReducePlugin strategy for rank %d: LOWPRECISION", rank);
-            break;
-        }
         default: break;
         }
     }
@@ -841,21 +766,7 @@ class AllreduceOp
     AllReduceStrategyType selectImplementation(
         size_t seq_len, size_t message_size, int world_size, nvinfer1::DataType type) noexcept
     {
-
-        if (isUsingLowPrecision(message_size))
-        {
-            return AllReduceStrategyType::LOWPRECISION;
-        }
-        else
-        {
-            if (mStrategy == AllReduceStrategyType::LOWPRECISION)
-            {
-                mStrategy = AllReduceStrategyType::AUTO;
-            }
-        }
-
         // Check that heuristic is only applied when AUTO is set.
-        // Use Auto select
         bool const is_auto = (mStrategy == AllReduceStrategyType::AUTO);
         auto const message_size_bytes = message_size * tensorrt_llm::common::getDTypeSize(type);
         auto const max_workspace_size
@@ -936,24 +847,6 @@ class AllreduceOp
         return strategy;
     }

-    bool isUsingLowPrecision(size_t message_size) const noexcept
-    {
-        static char* force_low_precision_allreduce_strategy_char
-            = std::getenv("FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY");
-        bool force_low_precision = (force_low_precision_allreduce_strategy_char != nullptr)
-            || (mStrategy == AllReduceStrategyType::LOWPRECISION);
-
-#ifdef ENABLE_FP8
-        // Use LowPrecision if PCIe and p2p support and message size is larger than 2MB
-        constexpr int LowPrecisionMinMessageSize = 2 * 1024 * 1024;
-        return force_low_precision && !mIsNVLINKSupported && mIsP2PSupported
-            && message_size >= LowPrecisionMinMessageSize;
-#else
-        // Low precision is not available when FP8 is not enabled
-        return false;
-#endif
-    }
-
 private:
     std::set<int> mGroup;
     bool mIsNVLINKSupported;
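For reference, the removed isUsingLowPrecision above reduces to a single predicate: the strategy had to be forced (the FORCE_LOW_PRECISION_ALL_REDUCE_STRATEGY environment variable or an explicit LOWPRECISION strategy), NVLink had to be absent, P2P available, FP8 support compiled in, and the message at least 2 * 1024 * 1024. A standalone restatement of that gate, with hypothetical parameters in place of the member variables:

    #include <cstddef>

    // Standalone restatement of the removed gate; parameter names are illustrative.
    bool wouldUseLowPrecision(std::size_t messageSize, bool forced, bool hasNvlink, bool hasP2p)
    {
    #ifdef ENABLE_FP8
        constexpr std::size_t kMinMessageSize = 2 * 1024 * 1024; // LowPrecisionMinMessageSize in the removed code
        return forced && !hasNvlink && hasP2p && messageSize >= kMinMessageSize;
    #else
        // Without FP8 support the low-precision strategy was never selected.
        (void) messageSize;
        (void) forced;
        (void) hasNvlink;
        (void) hasP2p;
        return false;
    #endif
    }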
@@ -1073,22 +966,10 @@ TORCH_LIBRARY_FRAGMENT(trtllm, m)
         "int rank,"
         "int nranks,"
         "float eps) -> Tensor[]");
-    m.def("initialize_static_lowprecision_buffers(Tensor workspace, int tp_size) -> Tensor[]");
 }

 TORCH_LIBRARY_IMPL(trtllm, CUDA, m)
 {
     m.impl("allreduce", &torch_ext::allreduce);
     m.impl("moe_allreduce", &torch_ext::moe_allreduce);
 }
-
-TORCH_LIBRARY_IMPL(trtllm, CPU, m)
-{
-    m.impl("initialize_static_lowprecision_buffers",
-        [](at::Tensor const& workspace, int64_t tp_size)
-        {
-            tensorrt_llm::kernels::initialize_static_lowprecision_buffers(
-                reinterpret_cast<int64_t*>(workspace.data_ptr()), (int) tp_size);
-            return std::vector<at::Tensor>{};
-        });
-}
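The removed m.def(...) schema and the removed TORCH_LIBRARY_IMPL(trtllm, CPU, m) block are two halves of one registration in the PyTorch custom-op API: m.def declares the operator's signature, and m.impl binds a backend kernel to it, so dropping the op means dropping both. A minimal illustration of that pairing, using a hypothetical op name rather than anything this file actually registers:

    #include <torch/library.h>
    #include <vector>

    // Schema registration: declare the operator and its signature in the "trtllm" namespace.
    TORCH_LIBRARY_FRAGMENT(trtllm, m)
    {
        m.def("example_noop(Tensor workspace, int tp_size) -> Tensor[]");
    }

    // Backend registration: bind a concrete (here, trivial) CPU kernel to that schema.
    TORCH_LIBRARY_IMPL(trtllm, CPU, m)
    {
        m.impl("example_noop",
            [](at::Tensor const& workspace, int64_t tp_size)
            {
                (void) workspace;
                (void) tp_size;
                return std::vector<at::Tensor>{};
            });
    }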