Skip to content

Commit 8708802

Browse files
syed-ahmed authored
and pytorchmergebot committed
Enables configuration of NCCL communicators (pytorch#97394)
NCCL 2.17+ introduces some user configurable parameters for NCCL communicators using [ncclConfig_t](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#c.ncclConfig_t) datatype and [ncclCommInitRankConfig](https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/comms.html#ncclcomminitrankconfig). This PR enables that feature. A user can tune the parameters as follows: ``` import torch.distributed as dist nccl_options = dist.ProcessGroupNCCL.Options() nccl_options.config.max_ctas = 32 nccl_options.config.min_ctas = 8 nccl_options.config.cga_cluster_size = 2 dist.init_process_group(backend='nccl', init_method='env://', pg_options=nccl_options) my_group = dist.new_group(pg_options=nccl_options) ``` The default values of these parameters are what is initialized by `NCCL_CONFIG_INITIALIZER`. Only for DistributedDataParallel, this PR sets the default value of cga_cluster_size to 2 (a heuristic that works well especially for DDP workloads). Tuning these parameters can lead to improvement in end-to-end performance, since it affects the communication-computation overlap for NCCL kernels. CC: @ptrblck @kwen2501 Pull Request resolved: pytorch#97394 Approved by: https://github.com/kwen2501
1 parent 3cae6d2 commit 8708802

File tree

5 files changed

+106
-15
lines changed

5 files changed

+106
-15
lines changed

test/distributed/test_c10d_nccl.py

Lines changed: 33 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
import math
55
import os
66
import random
7+
import re
78
import signal
89
import sys
910
import tempfile
@@ -2713,12 +2714,7 @@ def test_sequence_num_set_nccl_new_group(self):
27132714
torch.cuda.set_device(self.rank)
27142715
self._test_sequence_num_set_new_group(backend="nccl")
27152716

2716-
@requires_nccl()
2717-
@skip_if_lt_x_gpu(2)
2718-
def test_pass_nccl_options_high_priority_stream(self):
2719-
pg_opts = c10d.ProcessGroupNCCL.Options()
2720-
pg_opts.is_high_priority_stream = True
2721-
2717+
def _test_pass_nccl_options(self, pg_opts):
27222718
store = c10d.FileStore(self.file_name, self.world_size)
27232719
# Test init_process_group accepts options
27242720
dist.init_process_group(
@@ -2737,6 +2733,37 @@ def test_pass_nccl_options_high_priority_stream(self):
27372733
expected_tensor = torch.tensor([3] * 10).cuda(self.rank)
27382734
self.assertEqual(expected_tensor, t)
27392735

2736+
@requires_nccl()
2737+
@skip_if_lt_x_gpu(2)
2738+
def test_pass_nccl_options_high_priority_stream(self):
2739+
pg_opts = c10d.ProcessGroupNCCL.Options()
2740+
pg_opts.is_high_priority_stream = True
2741+
self._test_pass_nccl_options(pg_opts)
2742+
2743+
@requires_nccl()
2744+
@requires_nccl_version((2, 17), "Need NCCL 2.17+ for configuring NCCL communicators")
2745+
@skip_if_lt_x_gpu(2)
2746+
def test_pass_nccl_options_config(self):
2747+
pg_opts = c10d.ProcessGroupNCCL.Options()
2748+
pg_opts.config.max_ctas = 4
2749+
pg_opts.config.min_ctas = 2
2750+
pg_opts.config.cga_cluster_size = 2
2751+
nccl_debug_file = tempfile.NamedTemporaryFile()
2752+
os.environ["NCCL_DEBUG"] = "INFO"
2753+
os.environ["NCCL_DEBUG_FILE"] = nccl_debug_file.name
2754+
2755+
# Tests functionality when passing nccl config
2756+
self._test_pass_nccl_options(pg_opts)
2757+
2758+
# Tests if comms were configured
2759+
nccl_debug_file_content = nccl_debug_file.read()
2760+
max_ctas = re.search(rb'Max CTAs.*(\d+)|$', nccl_debug_file_content).group(1)
2761+
min_ctas = re.search(rb'Min CTAs.*(\d+)|$', nccl_debug_file_content).group(1)
2762+
cga_cluster_size = re.search(rb'CGA cluster.*(\d+)|$', nccl_debug_file_content).group(1)
2763+
self.assertEqual(pg_opts.config.max_ctas, int(max_ctas))
2764+
self.assertEqual(pg_opts.config.min_ctas, int(min_ctas))
2765+
self.assertEqual(pg_opts.config.cga_cluster_size, int(cga_cluster_size))
2766+
27402767
@requires_nccl()
27412768
@skip_if_lt_x_gpu(4)
27422769
def test_nccl_barrier(self):

torch/csrc/distributed/c10d/NCCLUtils.hpp

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,12 @@
5252
#define ENABLE_NCCL_PREMUL_SUM_SUPPORT
5353
#endif
5454

55+
#if defined(NCCL_MAJOR) && (NCCL_MAJOR == 2) && defined(NCCL_MINOR) && (NCCL_MINOR >= 17)
56+
#define NCCL_HAS_COMM_CTA_CGA
57+
#elif defined(NCCL_MAJOR) && (NCCL_MAJOR >= 3)
58+
#define NCCL_HAS_COMM_CTA_CGA
59+
#endif
60+
5561
// Macro to throw on a non-successful NCCL return value.
5662
#define C10D_NCCL_CHECK(cmd, failureReason) \
5763
do { \
@@ -179,22 +185,34 @@ class NCCLComm {
179185
int rank,
180186
ncclUniqueId commId) {
181187
auto comm = std::make_shared<NCCLComm>();
182-
#ifndef NCCL_HAS_COMM_NONBLOCKING
183188
C10D_NCCL_CHECK(
184189
ncclCommInitRank(&(comm->ncclComm_), numRanks, commId, rank), c10::nullopt);
185-
#else
186-
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
187-
if (nccl_use_nonblocking()) {
188-
config.blocking = 0;
189-
}
190-
C10D_NCCL_CHECK_TIMEOUT(
191-
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
192-
#endif
193190
comm->ncclId_ = commId;
194191
comm->rank_ = rank;
195192
return comm;
196193
}
197194

195+
#ifdef NCCL_HAS_COMM_NONBLOCKING
196+
static std::shared_ptr<NCCLComm> create(
197+
int numRanks,
198+
int rank,
199+
ncclUniqueId commId,
200+
ncclConfig_t& config) {
201+
auto comm = std::make_shared<NCCLComm>();
202+
if (nccl_use_nonblocking()) {
203+
config.blocking = 0;
204+
C10D_NCCL_CHECK_TIMEOUT(
205+
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), comm->ncclComm_, c10::nullopt);
206+
} else {
207+
C10D_NCCL_CHECK(
208+
ncclCommInitRankConfig(&(comm->ncclComm_), numRanks, commId, rank, &config), c10::nullopt);
209+
}
210+
comm->ncclId_ = commId;
211+
comm->rank_ = rank;
212+
return comm;
213+
}
214+
#endif
215+
198216
ncclUniqueId getNcclId() {
199217
return ncclId_;
200218
}

torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1156,7 +1156,11 @@ std::vector<std::shared_ptr<NCCLComm>>& ProcessGroupNCCL::getNCCLComm(
11561156
int deviceIndex = devices[i].index();
11571157

11581158
gpuGuard.set_index(deviceIndex);
1159+
#ifdef NCCL_HAS_COMM_NONBLOCKING
1160+
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID, options_->config);
1161+
#else
11591162
ncclComms[i] = NCCLComm::create(numRanks, rank, ncclID);
1163+
#endif
11601164

11611165
// Creates the NCCL streams
11621166
streamVal.push_back(

torch/csrc/distributed/c10d/ProcessGroupNCCL.hpp

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,11 @@ class TORCH_API ProcessGroupNCCL : public Backend {
279279

280280
// Schedule NCCL operations on high priority CUDA streams
281281
bool is_high_priority_stream;
282+
283+
#ifdef NCCL_HAS_COMM_NONBLOCKING
284+
// Configure ranks
285+
ncclConfig_t config = NCCL_CONFIG_INITIALIZER;
286+
#endif
282287
};
283288

284289
// If you wish to create multiple process groups, each with a potentially

torch/csrc/distributed/c10d/init.cpp

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2135,6 +2135,23 @@ options :class:`~torch.distributed.ProcessGroupNCCL.Options`).
21352135
.def_property_readonly(
21362136
"is_ucc_available", &::c10d::ProcessGroupNCCL::isUCCAvailable);
21372137

2138+
#ifdef NCCL_HAS_COMM_CTA_CGA
2139+
py::class_<ncclConfig_t>(
2140+
processGroupNCCL,
2141+
"NCCLConfig",
2142+
R"(
2143+
ncclConfig_t data type for configuring NCCL communicators.
2144+
See https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
2145+
for details.
2146+
)")
2147+
.def(py::init<>())
2148+
.def_readwrite("blocking", &ncclConfig_t::blocking)
2149+
.def_readwrite("cga_cluster_size", &ncclConfig_t::cgaClusterSize)
2150+
.def_readwrite("min_ctas", &ncclConfig_t::minCTAs)
2151+
.def_readwrite("max_ctas", &ncclConfig_t::maxCTAs)
2152+
.def_readwrite("net_name", &ncclConfig_t::netName);
2153+
#endif
2154+
21382155
intrusive_ptr_class_<::c10d::ProcessGroupNCCL::Options>(
21392156
processGroupNCCL,
21402157
"Options",
@@ -2148,19 +2165,39 @@ ProcessGroup options for the NCCL backend
21482165
to prioritize NCCL kernels when there are compute kernels waiting.
21492166
Default is False.
21502167
2168+
Attributes:
2169+
config (NCCLConfig): configures NCCL communicators (only available for
2170+
builds using NCCL 2.17+). This can be used to improve
2171+
communication-computation overlap for NCCL kernels by tuning
2172+
available parameters in the config. See
2173+
https://docs.nvidia.com/deeplearning/nccl/user-guide/docs/api/types.html#ncclconfig-t
2174+
for details.
2175+
21512176
Example::
21522177
>>> import torch.distributed as dist
21532178
>>>
21542179
>>> nccl_options = dist.ProcessGroupNCCL.Options(is_high_priority_stream=True)
2180+
>>> # For builds using NCCL 2.17+, configure communicators
2181+
>>> nccl_options.config.cga_cluster_size = 2
2182+
>>> nccl_options.config.max_ctas = 4
2183+
>>> nccl_options.config.min_ctas = 2
21552184
>>> # initialize a nccl process group with the options just created
21562185
>>> dist.init_process_group("nccl", pg_options=nccl_options)
21572186
)")
21582187
.def(py::init<bool>(), py::arg("is_high_priority_stream") = false)
2188+
#ifdef NCCL_HAS_COMM_CTA_CGA
2189+
.def_readwrite(
2190+
"is_high_priority_stream",
2191+
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream)
2192+
.def_readwrite("config", &::c10d::ProcessGroupNCCL::Options::config);
2193+
#else
21592194
.def_readwrite(
21602195
"is_high_priority_stream",
21612196
&::c10d::ProcessGroupNCCL::Options::is_high_priority_stream);
21622197
#endif
21632198

2199+
#endif
2200+
21642201
#ifdef USE_C10D_MPI
21652202
auto processGroupMPI =
21662203
intrusive_ptr_no_gil_destructor_class_<::c10d::ProcessGroupMPI>(

0 commit comments

Comments
 (0)