Skip to content

Commit 3bfba6b

Browse files
authored
Merge pull request Azure#3 from Azure/binyli/refactory
Refactor msccl-tools code. - Add regression test which compares generated results between current commit and main branch - For rank_dag.py, change it to instruciton_dag.py. Add base class `InstructionDAG` with common functions. Some msccl related function move to `MscclInstructionDAG` class - Refactor ir.py. Move common types to types.py - Format some file with black
2 parents fa5accc + 06d7776 commit 3bfba6b

17 files changed

+783
-487
lines changed

.github/workflows/codeql.yml

+3-3
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,12 @@ jobs:
1919

2020
steps:
2121
- name: Checkout repository
22-
uses: actions/checkout@v2
22+
uses: actions/checkout@v4
2323

2424
- name: Initialize CodeQL
25-
uses: github/codeql-action/init@v1
25+
uses: github/codeql-action/init@v3
2626
with:
2727
languages: python
2828

2929
- name: Perform CodeQL Analysis
30-
uses: github/codeql-action/analyze@v1
30+
uses: github/codeql-action/analyze@v3

.github/workflows/tests.yaml

+57-4
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
11
name: Tests
22

33
on:
4+
workflow_dispatch:
5+
inputs:
6+
commit_hash:
7+
description: 'The git commit hash to compare against'
8+
required: true
9+
default: 'fa5accc63ac39840422ff0d6b0ee875706c95e90' # legacy main branch commit hash
410
push:
511
pull_request:
612
branches: [ main ]
@@ -11,20 +17,67 @@ jobs:
1117

1218
strategy:
1319
matrix:
14-
python-version: [3.6, 3.7, 3.8, 3.9]
20+
python-version: ['3.8', '3.9', '3.10']
1521

1622
name: Test with Python ${{ matrix.python-version }}
1723

1824
steps:
19-
- uses: actions/checkout@v2
25+
- uses: actions/checkout@v4
2026
- name: Set up Python ${{ matrix.python-version }}
21-
uses: actions/setup-python@v2
27+
uses: actions/setup-python@v5
2228
with:
2329
python-version: ${{ matrix.python-version }}
24-
- name: Install msccl and dependencies
30+
- name: Install msccl-tools and dependencies
2531
run: |
2632
pip install --upgrade pip
2733
pip install -r requirements.txt
2834
- name: Run tests and check at least 90% coverage
2935
run: |
3036
pytest
37+
38+
compare_outputs:
39+
runs-on: ubuntu-latest
40+
strategy:
41+
matrix:
42+
python-version: ['3.8', '3.9', '3.10']
43+
name: Compare outputs with Python ${{ matrix.python-version }}
44+
45+
steps:
46+
- name: Set up Python ${{ matrix.python-version }}
47+
uses: actions/setup-python@v2
48+
with:
49+
python-version: ${{ matrix.python-version }}
50+
- name: Checkout current branch
51+
uses: actions/checkout@v4
52+
- name: Install msccl-tools and dependencies
53+
run: |
54+
pip install --upgrade pip
55+
pip install -r requirements.txt
56+
- name: Copy test script/config to temp directory
57+
run: |
58+
cp tests/generate_test_results.py $RUNNER_TEMP/
59+
cp tests/configs/test-config.json $RUNNER_TEMP/
60+
- name: generate outputs
61+
run: |
62+
python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/pr-outputs/
63+
- name: Checkout specific branch
64+
if: github.event_name == 'workflow_dispatch'
65+
uses: actions/checkout@v4
66+
with:
67+
ref: ${{ github.event.inputs.commit_hash }}
68+
- name: Checkout main branch
69+
uses: actions/checkout@v4
70+
if: github.event_name == 'pull_request' || github.event_name == 'push'
71+
with:
72+
ref: main
73+
- name: Install msccl and dependencies
74+
run: |
75+
pip install --upgrade pip
76+
pip install -r requirements.txt
77+
- name: generate outputs
78+
run: |
79+
python $RUNNER_TEMP/generate_test_results.py examples/mscclang/ $RUNNER_TEMP/test-config.json $RUNNER_TEMP/tests/main-outputs/
80+
- name: Compare outputs
81+
run: |
82+
diff -rw $RUNNER_TEMP/tests/main-outputs/ $RUNNER_TEMP/tests/pr-outputs/
83+

.gitignore

+3
Original file line numberDiff line numberDiff line change
@@ -131,3 +131,6 @@ dmypy.json
131131

132132
# Pyre type checker
133133
.pyre/
134+
135+
# vscode
136+
.vscode/

msccl/autosynth/ndv4_plans.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from msccl.programs.alltoall_a100_yifan import alltoall_hierarchical
88
from msccl.programs.alltoall_a100_8kp1 import alltoall_three_step
99
from msccl.topologies import fully_connected
10-
from msccl.language.ir import ThreadblockPolicy
10+
from msccl.language.types import ThreadblockPolicy
1111

1212
def register_ndv4_plans():
1313

@@ -47,4 +47,4 @@ def ndv4_alltoall_three_step(prog, nodes):
4747
def ndv4_alltoall_hierarchical_config2(prog, nodes):
4848
alltoall_hierarchical(num_nodes=nodes, gpus_per_node=8)
4949

50-
50+

msccl/autosynth/registry.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
import humanfriendly
1010

1111
from msccl.language import MSCCLProgram, ir_to_xml
12-
from msccl.language.ir import ThreadblockPolicy
12+
from msccl.language.types import ThreadblockPolicy
1313
import msccl.language.collectives as lang_collectives
1414
from msccl.topologies import distributed_fully_connected
1515

@@ -62,7 +62,7 @@ def wrapped(machines):
6262
return decorator
6363

6464

65-
def register_msccl_program(local_topology, collective, machine_type, machines=lambda x: True, sizes=None, protocol='Simple',
65+
def register_msccl_program(local_topology, collective, machine_type, machines=lambda x: True, sizes=None, protocol='Simple',
6666
chunk_factor=1, priority=0, collective_obj=None, instances=1, inplace=False, threadblock_policy=ThreadblockPolicy.auto,
6767
interleaved_replication=True, dependence_nop=False):
6868
def decorator(fun):
@@ -81,7 +81,7 @@ def wrapped(machines):
8181
co = lang_collectives.ReduceScatter(topology.num_nodes(), chunk_factor, inplace)
8282
else:
8383
raise RuntimeError(f'No collective_obj in msccl.language.collectives known for "{collective}"')
84-
prog = MSCCLProgram(name, topology, co, instances, protocol, threadblock_policy=threadblock_policy,
84+
prog = MSCCLProgram(name, topology, co, instances, protocol, threadblock_policy=threadblock_policy,
8585
interleaved_replication=interleaved_replication, dependence_nop=dependence_nop)
8686
with prog:
8787
fun(prog, machines)
@@ -96,4 +96,4 @@ def wrapped(machines):
9696
machine_type, machines, sizes, protocol, priority)
9797
# Return the original function to not break other usage
9898
return fun
99-
return decorator
99+
return decorator

0 commit comments

Comments
 (0)