Skip to content

Commit 1c7afa4

Browse files
authored
Fix bugs in MSCCL XML generation (Azure#17)
1. Rename reduce_scatter to reducescatter in XML generation to match the naming style of the other collectives and the executor-side implementation. 2. Avoid a valid thread-block channel being overwritten with 0.
1 parent ea828bf commit 1c7afa4

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
# Copyright (c) Microsoft Corporation.
2+
# Licensed under the MIT License.
3+
4+
import argparse
5+
from msccl.language import *
6+
from msccl.topologies import *
7+
from msccl.language.collectives import ReduceScatter
8+
9+
def allreduce_allpairs(gpus, protocol):
    """Generate an all-pairs ReduceScatter MSCCL program and emit its XML.

    NOTE(review): despite the name, this builds a ReduceScatter (not an
    AllReduce) — the name is kept for backward compatibility with callers.

    Args:
        gpus: number of GPUs (ranks); also used as the chunk factor.
        protocol: transport protocol name ('Simple', 'LL128', or 'LL').

    Side effects: prints the generated MSCCL XML and runs the correctness
    check inside the program context.
    """
    size = gpus
    topology = fully_connected(size)
    collective = ReduceScatter(gpus, gpus, True)
    with MSCCLProgram("reducescatter_pairs", topology, collective, 1, protocol=protocol,
        threadblock_policy=ThreadblockPolicy.manual, dependence_nop=True):

        # Each rank sends the nth chunk to the nth rank into scratch space
        for r1 in range(size):
            for r2 in range(size):
                if r1 != r2:
                    index = r2 * size
                    c = chunk(r1, Buffer.input, index, size=size)
                    c.copy(r2, 'scratch', sendtb=r2, recvtb=r1)

        # Each rank performs a local reduction on its own chunk.
        # Use `size` threadblocks (one per peer slot) for better parallelism.
        for r in range(size):
            for index in range(size * (size - 1)):
                c = chunk(r, Buffer.input, r * size + (index % size))
                c.reduce(chunk(r, 'scratch', index), sendtb=(index % size))

        XML()
        Check()
33+
34+
# Script entry point: parse CLI arguments and generate the program.
# Guarded so that importing this module does not trigger argument parsing.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument('num_gpus', type=int, help='number of gpus')
    parser.add_argument('--protocol', type=str, default='LL',
                        choices=['Simple', 'LL128', 'LL'], help='Protocol')

    args = parser.parse_args()

    allreduce_allpairs(args.num_gpus, args.protocol)

msccl/language/collectives.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -174,7 +174,7 @@ def get_buffer_index(self, rank, buffer, index):
174174
class ReduceScatter(Collective):
175175
def __init__(self, num_ranks, chunk_factor, inplace):
176176
Collective.__init__(self, num_ranks, chunk_factor, inplace)
177-
self.name = "reduce_scatter"
177+
self.name = "reducescatter"
178178

179179
def init_buffers(self):
180180
rank_buffers = []

msccl/language/tb_assignment.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ def manual_assign_tbs(rank_dag):
3030
tb = rank_dag.tbs[rank][tbid]
3131
if _verify_tb_op_compatible(tb, op):
3232
tb.ops.append(op)
33-
tb.channel = op.channel if op.channel != -1 else 0
33+
if tb.channel == -1:
34+
tb.channel = op.channel if op.channel != -1 else 0
3435
tb.send = op.dst.rank if op.is_send() else tb.send
3536
tb.recv = op.src.rank if op.is_recv() else tb.recv
3637
op.step = len(tb.ops)-1

tests/configs/test-config.json

+4
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,10 @@
6363
"filename": "pipeline_a100_ring.py",
6464
"args": ["8", "4", "2"]
6565
},
66+
{
67+
"filename": "reducescatter_allpairs.py",
68+
"args": ["8"]
69+
},
6670
{
6771
"filename": "mscclpp/allreduce_a100_allpairs_packet_mscclpp.py",
6872
"args": ["8", "8"]

0 commit comments

Comments
 (0)