-
Notifications
You must be signed in to change notification settings - Fork 62
Merge mscclpp-lang to mscclpp project #442
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Merged
Merged
Changes from all commits
Commits
Show all changes
37 commits
Select commit
Hold shift + click to select a range
9376fb3
try to add mscclpp-lang
Binyang2014 671008c
WIP
Binyang2014 46f8dcd
Fix
Binyang2014 740e008
WIP
Binyang2014 35eeed9
WIP
Binyang2014 84e1f41
WIP
Binyang2014 8378af6
Merge branch 'main' into binyli/mscclpp-lang
Binyang2014 27e82b9
WIP
Binyang2014 a0eed7b
add examples
Binyang2014 b8a6e41
WIP
Binyang2014 5ee7cc6
update
Binyang2014 b6ad968
WIP
Binyang2014 8800ff1
fix lint
Binyang2014 e2c4df5
update
Binyang2014 d4bfa03
WIP
Binyang2014 5ed3210
fix
Binyang2014 05a5925
fix
Binyang2014 320d258
update
Binyang2014 254aad6
WIP
Binyang2014 4cfa765
WIP
Binyang2014 d1e9872
Merge branch 'main' into binyli/mscclpp-lang
Binyang2014 a249c04
update
Binyang2014 c92392f
update
Binyang2014 d9d8152
fix comment
Binyang2014 92eeab9
Merge branch 'main' into binyli/mscclpp-lang
Binyang2014 98ba12a
add broadcast
Binyang2014 a799959
Merge branch 'main' into binyli/mscclpp-lang
Binyang2014 607d6c3
add doc
Binyang2014 a9bd223
Merge branch 'binyli/mscclpp-lang' of https://github.com/microsoft/ms…
Binyang2014 0f3f433
update doc link
Binyang2014 7a3427e
address comments
Binyang2014 e179646
add comments
Binyang2014 11f3723
update
Binyang2014 2cfc915
address comments
Binyang2014 7afadee
address comments
Binyang2014 613bdd1
fix
Binyang2014 4585515
update ci
Binyang2014 File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
name: MSCCLPPLang | ||
|
||
on: | ||
pull_request: | ||
branches: | ||
- main | ||
- release/* | ||
|
||
jobs: | ||
compare-diffs: | ||
runs-on: 'ubuntu-latest' | ||
container: | ||
image: ghcr.io/microsoft/mscclpp/mscclpp:base-dev-${{ matrix.version }} | ||
|
||
strategy: | ||
fail-fast: false | ||
matrix: | ||
version: [ 'cuda11.8', 'cuda12.2' ] | ||
|
||
steps: | ||
- uses: actions/checkout@v4 | ||
- name: Install mscclpp | ||
run: | | ||
CMAKE_ARGS="-DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON" pip3 install . | ||
|
||
- name: Copy test script/config to temp directory | ||
run: | | ||
cp python/test/test_generate_mscclpp_lang_result.py $RUNNER_TEMP/ | ||
cp python/test/configs/mscclpp_lang_test_config.json $RUNNER_TEMP/ | ||
- name: generate outputs | ||
run: | | ||
python3 $RUNNER_TEMP/test_generate_mscclpp_lang_result.py python/examples/ $RUNNER_TEMP/mscclpp_lang_test_config.json $RUNNER_TEMP/tests/pr-outputs/ | ||
- name: Checkout main branch | ||
uses: actions/checkout@v4 | ||
if: github.event_name == 'pull_request' || github.event_name == 'push' | ||
with: | ||
ref: main | ||
- name: Install msccl and dependencies | ||
run: | | ||
CMAKE_ARGS="-DMSCCLPP_BYPASS_GPU_CHECK=ON -DMSCCLPP_USE_CUDA=ON" pip3 install . | ||
- name: generate outputs | ||
run: | | ||
python3 $RUNNER_TEMP/test_generate_mscclpp_lang_result.py python/examples/ $RUNNER_TEMP/mscclpp_lang_test_config.json $RUNNER_TEMP/tests/main-outputs/ | ||
- name: Compare outputs | ||
run: | | ||
diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/ |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# MSCCL++ DSL | ||
## MSCCLPPLang Introduction | ||
MSCCLPPLang is a Python moudule for writing high-performance commnunication algorithms. It is designed to be easy to use and efficient, while providing a high-level interface for writing communication algorithms. MSCCLPPLang program will be compiled to json based execution plan, which can be executed by MSCCL++ executor. | ||
|
||
## How to use MSCCLPPLang | ||
### Install mscclpp package | ||
```bash | ||
git clone https://github.com/microsoft/mscclpp.git | ||
cd mscclpp | ||
pip install . | ||
``` | ||
|
||
### Import mscclpp language module | ||
```python | ||
import mscclpp.language * | ||
from mscclpp.language.types import ChannelType, ReplicationPolicy | ||
from mscclpp.language.collectives import AllGather | ||
|
||
instances = 1 | ||
size = gpus | ||
collective = AllGather(size, chunk_factor=1, inplace=True) | ||
with MSCCLPPProgram( | ||
"allgather", | ||
collective, | ||
size, | ||
instances, | ||
protocol="Simple", | ||
replication_policy=ReplicationPolicy.interleaved, | ||
): | ||
pass | ||
``` | ||
|
||
## How MSCCLPPLang Works | ||
MSCCLPPLang provides a high-level interface for writing communication algorithms. We treat the communication algorithm as a graph, where the nodes are the data and the edges are the communication operations. The graph is represented as a Python program, which is compiled to a json based execution plan. | ||
|
||
### Core Concepts | ||
|
||
#### MSCCLPPProgram | ||
A MSCCLPPProgram provides the context to write MSCCLPPLang program, which can be initialized with `with` statement in Python. Its parameters include: | ||
|
||
- `name`: Name of this program. | ||
- `collective`: Collective type of this program, should be from `mscclpp.language.collectives`. | ||
- `instances`: Number of parallel instances of this program. Please see the [Instance](#instance) section for more details. | ||
- `protocol`: Data transmission protocol used in this program, can be `LL` or `Simple`. Optional, default is `Simple`. | ||
- `instr_fusion`: Whether low-level instruction fusion is enabled. Optional, default is `True`. | ||
- `replication_policy`: Data replication policy, should be from `mscclpp.language.types.ReplicationPolicy`. Optional, default is `duplicated`. Please see the [Instance](#instance) section for more details. | ||
- `num_threads_per_block`: Thread block size. Optional, default is `1024`. | ||
- `use_double_scratch_buffer`: Whether requires double scratch buffer during execution. Optional, default is `False`. | ||
|
||
### Collective: | ||
A collective is a communication operation that involves multiple GPUs. We provide a set of collective operations for users to utilize. For example, the `AllGather` operation gathers data from all GPUs to all GPUs. To instantiate a collective, the user needs to specify the number of ranks, the chunk factor (how many chunks the input buffer will be split into), and whether the operation is in-place. | ||
|
||
#### Chunk | ||
A chunk is a piece of data that is sent between GPUs. It is the basic unit of data in MSCCLPPLang. Chunk can be a piece of data from input buffer, output buffer or intermediate buffer. | ||
Example of creating a chunk: | ||
```python | ||
c = chunk(rank, Buffer.input, index, size) | ||
``` | ||
- rank: the rank of the GPU that the chunk belongs to. | ||
- buffer: the buffer that the chunk belongs to. It can be Buffer.input, Buffer.output or Buffer.scratch. | ||
- index: the index of the chunk in the buffer. | ||
- size: the number of unit chunks. | ||
|
||
Assume we split the input data in the buffer into 4 chunks. On GPU rank 0, we can retrieve the chunks from indices 0 to 2 using the following command: | ||
```python | ||
c = chunk(0, Buffer.input, 0, 2) | ||
``` | ||
|
||
#### Operation | ||
The operation can only be applied to the chunks. We provide a set of communications operations for the users to use. For example, the `put` operation is used to send the data from one GPU to another GPU. The `get` operation is used to receive the data from another GPU. | ||
|
||
***Please notice***: MSCCLPPLang only provides one-sided communication operations. The user needs to make sure that the data is ready to be sent or received before calling the communication operations. Also we provides `wait/signal` operations to synchronize the communication across GPUs. | ||
|
||
#### Channel | ||
A channel is a communication channel between two GPUs. It is used to send and receive data between GPUs. We supports three types of channel: `ChannelType.sm`, `ChannelType.proxy` and `ChannelType.nvls`. | ||
|
||
`ChannelType.sm` is used for communication between GPUs on the same node. This channel uses GPU processors to transfer data. | ||
|
||
`ChannelType.proxy` is used for communication between GPUs, whether they are on different nodes or the same node. This channel will offload the data transfer to CPU processors, which can provide better throughput compared to `ChannelType.sm`. However, this comes at the cost of higher latency compared to `ChannelType.sm`. | ||
|
||
`ChannelType.nvls` is used for communication between GPUs on the same node. This feature offloads the data processing task to the switch, requiring specific hardware support. Refer [nvdia documentation](https://www.nvidia.com/en-us/data-center/nvlink/) for more details. | ||
|
||
#### Thread Block | ||
We can assign operations to a thread block. The thread block is a group of threads that are executed together on the GPU. In the operation function, we can specify the thread block that the operation belongs to via `sendtb` or `recvtb` parameter. | ||
|
||
#### Instance | ||
An instance is a parallel execution of the program. For example, if a collective algorithm is designed to run on `n` chunks with `m` thread blocks, setting the instance to 2 will run the algorithm on `2n` chunks with `2m` thread blocks. Serveral replication policies are supported, including `duplicated` and `interleaved`. | ||
- `duplicated`: Each chunk is split into smaller parts based on the number of instances, duplicating the same instructions for all parts. For example, ChunkA is split into ChunkA0 and ChunkA1, while ChunkB is split into ChunkB0 and ChunkB1. Both ChunkA0 and ChunkA1 belong to Instance 0, and both ChunkB0 and ChunkB1 belong to Instance 1. | ||
- `interleaved`: Assign chunks to instances in an interleaved manner. For example, ChunkA and ChunkB are split into to ChunkA0, ChunkA1, ChunkB0, and ChunkB1. ChunkA0 and ChunkB0 belong to Instance 0, while ChunkA1 and ChunkB1 belong to Instance 1. | ||
|
||
#### Instruction Fusion | ||
MSCCLPPLang provides the instruction fusion mechanism to fuse multiple operations into a single kernel. This can reduce the overhead of launching multiple instructions. When users create the MSCCLPPLang program, they can specify the `instr_fusion` parameter to enable the instruction fusion. By default, the instruction fusion is enabled. | ||
|
||
## MSCCLPPLang APIs | ||
|
||
### Basic APIs | ||
- `chunk(rank, buffer, index, size)`: create a chunk. | ||
- `put(self, dst, buffer, index, sendtb, chan_type)`: send the data from one GPU to another GPU. User can specify the index of the chunk in the destination buffer, the sendtb and the channel type. | ||
- `get(self, src, buffer, index, recvtb, chan_type)`: receive the data from another GPU. User can specify the index of the chunk in the destination buffer, the recvtb and the channel type. | ||
- `signal(self, dst, buffer, index, sendtb, chan_type)`: send a signal to another GPU. | ||
- `wait(self, src, buffer, index, recvtb, chan_type)`: wait for a signal from another GPU. | ||
- `flush(self, dst, buffer, index, sendtb, chan_type)`: flush the data in the buffer to the destination GPU. This is used to make sure the data is sent to the destination GPU. | ||
- `copy(self, dst, buffer, index, sendtb)`: copy the data from one buffer to another buffer in the same GPU. | ||
- `reduce(self, other_chunkref, recvtb, channel_type)`: Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref | ||
|
||
### Packet APIs | ||
Packet APIs are used when user wants to use LL algorithm. The packet APIs are similar to the basic APIs, it will packet the data and flags into a packet and send the packet to the destination GPU. The destination GPU will unpack the packet and get the data and flags. So no synchronization is needed when using packet APIs. (`ChannelType.nvls` does not support packet APIs) | ||
- `packet_put(self, dst, buffer, index, sendtb, chan_type)`: send the data from one GPU to another GPU using packet. | ||
- `copy_packet(self, dst, buffer, index, sendtb)`: copy the data from one buffer to another buffer in the same GPU using packet. | ||
- `reduce_packet(self, other_chunkref, recvtb)`: Reduces the chunk(s) referenced by other_chunkref into the chunk(s) referenced by this chunkref using packet. | ||
|
||
|
||
### Examples | ||
We provide several examples demonstrating how to use the MSCCL++ DSL to write communication collective algorithms. For more details, please refer to the [examples](https://github.com/microsoft/mscclpp/tree/main/mscclpp-lang/python/examples) folder. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Binyang2014 marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
import argparse | ||
from mscclpp.language import * | ||
from mscclpp.language.buffer import Buffer | ||
from mscclpp.language.collectives import AllGather | ||
from mscclpp.language.types import ChannelType, ReplicationPolicy | ||
|
||
|
||
def allgather_test(gpus, instances): | ||
""" | ||
Demonstrates how to use barrier in the MSCCL++ DSL with an allgather collective. | ||
This example uses an allpairs algorithm for the allgather operation. | ||
Steps: | ||
1. Each rank sends a chunk to all other ranks' output buffers and copies the chunk to its own output buffer. | ||
2. A barrier is called to synchronize the send and copy operations, and signal peers that the data has been sent. | ||
3. Wait for all the chunks from other ranks to be received. | ||
""" | ||
size = gpus | ||
collective = AllGather(size, 1, False) | ||
with MSCCLPPProgram( | ||
"allgather_with_barrier", | ||
collective, | ||
size, | ||
instances, | ||
protocol="Simple", | ||
replication_policy=ReplicationPolicy.interleaved, | ||
): | ||
for n in range(gpus): | ||
c = chunk(n, Buffer.input, 0, 1) | ||
for peer in range(gpus): | ||
if n != peer: | ||
c.put(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm) | ||
else: | ||
c.copy(n, Buffer.output, n, sendtb=peer) | ||
# explicit barrier | ||
r = rank(n) | ||
r.barrier(tb_list=list(range(gpus))) | ||
for peer in range(gpus): | ||
if n != peer: | ||
c.signal(peer, Buffer.output, n, sendtb=peer, chan_type=ChannelType.sm) | ||
|
||
for n in range(gpus): | ||
for peer in range(gpus): | ||
c = chunk(n, Buffer.output, peer, 1) | ||
if n != peer: | ||
c.wait(peer, Buffer.input, peer, recvtb=peer, chan_type=ChannelType.sm) | ||
|
||
Json() | ||
Check() | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("num_gpus", type=int, help="number of gpus") | ||
parser.add_argument("instances", type=int, help="number of instances") | ||
args = parser.parse_args() | ||
allgather_test(args.num_gpus, args.instances) | ||
Binyang2014 marked this conversation as resolved.
Show resolved
Hide resolved
|
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
# Copyright (c) Microsoft Corporation. | ||
# Licensed under the MIT License. | ||
|
||
import argparse | ||
from mscclpp.language import * | ||
from mscclpp.language.collectives import AllReduce | ||
from mscclpp.language.buffer import Buffer | ||
|
||
|
||
def allreduce_allpairs(gpus, instances, protocol): | ||
""" | ||
Demonstrate allreduce with all pairs algorithm using put semantics. | ||
Steps: | ||
1. Sync all ranks to ensure the data is ready. | ||
2. Each rank reads chunks from all peers and reduces the data. | ||
3. Put the reduced data to all peers. | ||
4. Sync all ranks to ensure the data is received. | ||
""" | ||
size = gpus | ||
chunksperloop = gpus * gpus | ||
collective = AllReduce(size, chunksperloop, True) | ||
with MSCCLPPProgram("allreduce_pairs", collective, size, instances, protocol=protocol): | ||
for rank in range(size): | ||
for tb in range(size): | ||
index = rank * size | ||
c = chunk(rank, Buffer.input, index + tb) | ||
# step1 make sure the data is ready | ||
for nghr in range(size): | ||
peer_index = nghr * size | ||
if rank != nghr: | ||
# signal peer the buffer is ready | ||
c_peer = chunk(rank, Buffer.input, peer_index + tb) | ||
c_peer.signal(nghr, Buffer.input, peer_index + tb, sendtb=tb) | ||
for nghr in range(size): | ||
if rank != nghr: | ||
c.wait(nghr, Buffer.input, index + tb, recvtb=tb) | ||
# step2 reduce the chunks and send to peers | ||
for nghr in range(size): | ||
if rank != nghr: | ||
c.reduce(chunk(nghr, Buffer.input, index + tb), recvtb=tb) | ||
for nghr in range(size): | ||
if rank != nghr: | ||
c.put(nghr, Buffer.input, index + tb, sendtb=tb) | ||
# step3 signal the peers buffer is ready | ||
for nghr in range(size): | ||
if rank != nghr: | ||
c.signal(nghr, Buffer.input, index + tb, sendtb=tb) | ||
for nghr in range(size): | ||
if rank != nghr: | ||
peer_index = nghr * size | ||
c_peer = chunk(rank, Buffer.input, peer_index + tb) | ||
c_peer.wait(nghr, Buffer.input, peer_index + tb, recvtb=tb) | ||
|
||
Json() | ||
Check() | ||
|
||
|
||
parser = argparse.ArgumentParser() | ||
parser.add_argument("num_gpus", type=int, help="number of gpus") | ||
parser.add_argument("instances", type=int, help="number of instances") | ||
parser.add_argument("--protocol", type=str, default="Simple", choices=["Simple"], help="Protocol") | ||
|
||
args = parser.parse_args() | ||
|
||
allreduce_allpairs(args.num_gpus, args.instances, args.protocol) |
Oops, something went wrong.
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
Uh oh!
There was an error while loading. Please reload this page.