Commit 7eb3ff7

Supporting New Packet Kernel Operations at Executor (#677)
This PR introduces three new operations to enhance flexibility and performance at the executor. One operation can be invoked directly via the DSL API, and two are created through fusion of existing operations, reducing overhead and improving efficiency.

1. Port Channel Put Packet (direct DSL API call): Sends data in packet format to the remote side, also in packet format, via the port channel. Both the source and destination buffers must be scratch buffers.

2. Reduce Copy Packet (fusion): Reduce Packet + Copy Packet = Reduce Copy Packet. Triggered when the destination buffer of Reduce Packet matches the source buffer of Copy Packet. Purpose: combine reduction and copy into a single step for better performance.

3. Reduce Copy Send Packet (fusion):
   - Reduce Copy Packet + Put Packet = Reduce Copy Send Packet (when the destination buffer of Reduce Copy Packet matches the source buffer of Put Packet).
   - Reduce Copy Packet + Read Put Packet = Reduce Copy Send Packet (when the destination packet buffer of Reduce Copy Packet matches the source buffer of Read Put Packet).
   Purpose: combine reduction, copy, and send into one optimized pipeline.

Fusion diagram:
Reduce Packet + Copy Packet → Reduce Copy Packet
Reduce Copy Packet + Put Packet → Reduce Copy Send Packet
Reduce Copy Packet + Read Put Packet → Reduce Copy Send Packet

Beyond this, the PR adjusts the AllReduce 2-node algorithm. Measured latency by message size:

Message Size | Latency (µs)
1K           | 15.34
2K           | 15.88
4K           | 15.71
8K           | 16.01
16K          | 15.88
32K          | 16.21
64K          | 16.90
128K         | 18.24
256K         | 20.39
512K         | 25.26
1M           | 32.74
2M           | 53.64
1 parent eb20278 commit 7eb3ff7
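
For orientation, here is a minimal, hypothetical DSL snippet in the style of the docstring examples further below; program, buffer, and channel setup are omitted, and port_channel, memory_channel, rank, input_buffer, local_scratch, remote_scratch, peer_data, tbg, and n are placeholder names rather than anything defined in this PR. It shows the new port-channel put_packets call and a reduce/copy/read-put chain of the kind the executor can now fuse:

# Sketch only: setup of the program, buffers, channels, and thread block
# groups is omitted; every variable name here is a placeholder.

# 1. Port Channel Put Packet, the new direct DSL call (see channel.py below):
#    writes the source chunk into the remote scratch buffer in packet format.
port_channel.put_packets(remote_scratch[0:n], local_scratch[0:n], tb=0)

# 2. A chain the executor can fuse:
#    reduce peer data into the input buffer in packet format ...
rank.reduce(input_buffer[0:n], peer_data, tb_group=tbg, packet=True)
#    ... copy the reduced packets to scratch; the reduce destination matches
#    the copy source, so Reduce Packet + Copy Packet -> Reduce Copy Packet ...
rank.copy_packets(local_scratch[0:n], input_buffer[0:n], tb_group=tbg)
#    ... then forward over a memory channel; the copy destination matches the
#    read-put source, so Reduce Copy Packet + Read Put Packet -> Reduce Copy Send Packet.
memory_channel.read_put_packets(remote_scratch[0:n], local_scratch[0:n], tb_group=tbg)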

12 files changed: +376 -31 lines changed

include/mscclpp/npkit/npkit_event.hpp

Lines changed: 1 addition & 1 deletion
@@ -41,6 +41,6 @@
 #define NPKIT_EVENT_KERNEL_ALLREDUCE_EXIT 0x1C
 
 #define NPKIT_EVENT_EXECUTOR_OP_BASE_ENTRY 0x1D
-#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x37
+#define NPKIT_EVENT_EXECUTOR_OP_BASE_EXIT 0x39
 
 #endif

python/mscclpp/__main__.py

Lines changed: 28 additions & 7 deletions
@@ -12,10 +12,10 @@
 
 default_algo_configs = [
     {
-        "filename": "allreduce_2nodes.json",
+        "filename": "allreduce_2nodes_1K_64K.json",
         "function": def_algo.allreduce_2nodes,
         "spec": AlgoSpec(
-            name="allreduce_2nodes",
+            name="allreduce_2nodes_1K_64K",
             collective=AllReduce(16, 1, True),
             nranks_per_node=8,
             world_size=16,
@@ -27,11 +27,32 @@
             reuse_resources=True,
             use_double_scratch_buffer=True,
             min_message_size=1 << 10,
+            max_message_size=64 << 10,
+            tags={"default": 1},
+        ),
+        "additional_kwargs": {"thread_block_group_size": 1},
+    },
+    {
+        "filename": "allreduce_2nodes_128K_2M.json",
+        "function": def_algo.allreduce_2nodes,
+        "spec": AlgoSpec(
+            name="allreduce_2nodes_128K_2M",
+            collective=AllReduce(16, 1, True),
+            nranks_per_node=8,
+            world_size=16,
+            in_place=True,
+            instances=1,
+            protocol="LL",
+            auto_sync=False,
+            num_threads_per_block=1024,
+            reuse_resources=True,
+            use_double_scratch_buffer=True,
+            min_message_size=128 << 10,
             max_message_size=2 << 20,
             tags={"default": 1},
         ),
-        "additional_args": [4],
-    }
+        "additional_kwargs": {"thread_block_group_size": 4},
+    },
 ]
 
 
@@ -46,12 +67,12 @@ def create_default_plans():
         filename = config["filename"]
         func = config["function"]
         spec = config["spec"]
-        additional_args = config.get("additional_args", [])
+        additional_kwargs = config.get("additional_kwargs", {})
         plan_path = os.path.join(plan_dir, filename)
 
         try:
-            if additional_args:
-                prog = func(spec, *additional_args)
+            if additional_kwargs:
+                prog = func(spec, **additional_kwargs)
             else:
                 prog = func(spec)
 
python/mscclpp/language/channel.py

Lines changed: 49 additions & 0 deletions
@@ -682,6 +682,55 @@ def put_with_signal_and_flush(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int)
 
         get_program().add_operation(self.src_rank, tb, op)
 
+    def put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
+        """Transfer data from local buffer to remote scratch buffer in packet format.
+
+        Performs a specialized put operation that transfers data from the source rank's buffer
+        to the destination rank's scratch buffer in packet format through the port channel.
+        The destination chunk must be a scratch buffer.
+
+        Args:
+            dst_chunk (Chunk): The destination scratch chunk on the destination rank.
+            src_chunk (Chunk): The source chunk on the source rank (any buffer type).
+            tb (int): The thread block ID that will execute this operation.
+
+        Raises:
+            RuntimeError: If chunk ranks don't match channel configuration, if destination
+                chunk is not a scratch buffer, or if chunk sizes don't match.
+
+        Example:
+            >>> channel.put_packets(dst_chunk, src_chunk, tb=0)
+        """
+        if src_chunk.rank != self.src_rank:
+            raise RuntimeError(
+                f"Source chunk rank {src_chunk.rank} does not match current channel source rank {self.src_rank}."
+            )
+        if dst_chunk.rank != self.dst_rank:
+            raise RuntimeError(
+                f"Dst chunk rank {dst_chunk.rank} does not match current channel dst rank {self.dst_rank}."
+            )
+        if dst_chunk.buffer != BufferType.scratch:
+            raise RuntimeError(f"Destination chunk must be of type scratch.")
+        if dst_chunk.size != src_chunk.size:
+            raise RuntimeError(
+                f"Destination chunk size {dst_chunk.size} does not match source chunk size {src_chunk.size}."
+            )
+
+        remote_chunk = RemoteBuffer(src_chunk.rank, dst_chunk.rank, dst_chunk.buffer, self.channel_type)
+        tb_chunk_id = get_program().setup_remote_chunk(self.src_rank, tb, remote_chunk, self.channel_type)
+        tb_channel_ids = get_program().setup_channel(tb, self)
+
+        op = PutOperation(
+            src_buff=[LocalChunk(src_chunk.buffer, src_chunk.index, src_chunk.size)],
+            dst_buff=[RemoteChunk(dst_chunk.buffer, dst_chunk.index, dst_chunk.size, tb_chunk_id)],
+            channel_ids=tb_channel_ids,
+            channel_type=self.channel_type,
+            from_packet=False,
+            to_packet=True,
+        )
+
+        get_program().add_operation(self.src_rank, tb, op)
+
     def read_put_packets(self, dst_chunk: Chunk, src_chunk: Chunk, tb: int):
         """Transfer data in packet format from local to remote scratch buffer.
 
python/mscclpp/language/default_algos/allreduce_2nodes.py

Lines changed: 43 additions & 14 deletions
@@ -34,15 +34,31 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
     inter_node_port_channels = {}
     scratch_buffers = []
     thread_block_offset = 1
-    thread_block_group = ThreadBlockGroup(
-        tb_list=[i for i in range(thread_block_offset, thread_block_offset + thread_block_group_size)]
+    thread_block_groups = []
+    global_intra_node_tbg = ThreadBlockGroup(
+        tb_list=[
+            i
+            for i in range(thread_block_offset, thread_block_offset + (gpus_per_node - 1) * thread_block_group_size)
+        ]
     )
+    for i in range(gpus_per_node - 1):
+        thread_block_groups.append(
+            ThreadBlockGroup(
+                tb_list=[
+                    i
+                    for i in range(
+                        thread_block_offset + i * thread_block_group_size,
+                        thread_block_offset + (i + 1) * thread_block_group_size,
+                    )
+                ]
+            )
+        )
 
+    scratch_buffer_size = packets_per_gpu * (total_gpus + 1)
     for node_id in range(num_nodes):
        for local_gpu_id in range(gpus_per_node):
            current_rank_id = local_gpu_id + gpus_per_node * node_id
            next_node_rank_id = (local_gpu_id + gpus_per_node * (node_id + 1)) % total_gpus
-            scratch_buffer_size = 2 * total_gpus
            scratch_buffers.append(Buffer(current_rank_id, scratch_buffer_size))
            for peer_gpu_id in range(gpus_per_node):
                if peer_gpu_id != local_gpu_id:
@@ -64,13 +80,14 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
             for peer_gpu_id in range(gpus_per_node):
                 peer_rank_id = peer_gpu_id + gpus_per_node * node_id
                 peer_data_offset = peer_gpu_id * packets_per_gpu
+                tbg_id = peer_gpu_id if peer_gpu_id < local_gpu_id else peer_gpu_id - 1
                 if peer_gpu_id != local_gpu_id:
                     intra_node_memory_channels[(peer_rank_id, current_rank_id)].put_packets(
                         scratch_buffers[peer_rank_id][
                             local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
                         ],
                         input_buffer[peer_data_offset : peer_data_offset + packets_per_gpu],
-                        tb_group=thread_block_group,
+                        tb_group=thread_block_groups[tbg_id],
                     )
 
             # Intra Node Reduce
@@ -84,20 +101,24 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
             current_rank.reduce(
                 input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
                 other_gpu_data,
-                tb_group=thread_block_group,
+                tb_group=global_intra_node_tbg,
                 packet=True,
             )
 
-            # Copy Reduced Data to Scratch Buffer and send to Next Node
             current_rank.copy_packets(
                 scratch_buffers[current_rank_id][
                     local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
                 ],
                 input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
-                tb_group=thread_block_group,
+                tb_group=global_intra_node_tbg,
             )
+
+            current_rank.barrier(
+                tb_list=[i for i in range(thread_block_offset + (gpus_per_node - 1) * thread_block_group_size)]
+            )
+
             inter_node_offset = total_gpus
-            inter_node_port_channels[current_rank_id].read_put_packets(
+            inter_node_port_channels[current_rank_id].put_packets(
                 scratch_buffers[next_node_rank_id][
                     inter_node_offset
                     + local_gpu_id * packets_per_gpu : inter_node_offset
@@ -122,31 +143,39 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
             current_rank.reduce(
                 input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
                 inter_node_data,
-                tb_group=thread_block_group,
+                tb_group=global_intra_node_tbg,
                 packet=True,
             )
 
+            current_rank.copy_packets(
+                scratch_buffers[current_rank_id][scratch_buffer_size - packets_per_gpu : scratch_buffer_size],
+                input_buffer[local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu],
+                tb_group=global_intra_node_tbg,
+            )
+
             # Broadcast Reduced Data
             for peer_gpu_id in range(gpus_per_node):
                 peer_rank_id = peer_gpu_id + gpus_per_node * node_id
 
                 if peer_gpu_id != local_gpu_id:
-                    intra_node_memory_channels[(peer_rank_id, current_rank_id)].put_packets(
+                    tbg_id = peer_gpu_id if peer_gpu_id < local_gpu_id else peer_gpu_id - 1
+                    intra_node_memory_channels[(peer_rank_id, current_rank_id)].read_put_packets(
                         scratch_buffers[peer_rank_id][
                             inter_node_offset
                             + local_gpu_id * packets_per_gpu : inter_node_offset
                             + local_gpu_id * packets_per_gpu
                             + packets_per_gpu
                         ],
-                        input_buffer[
-                            local_gpu_id * packets_per_gpu : local_gpu_id * packets_per_gpu + packets_per_gpu
+                        scratch_buffers[current_rank_id][
+                            scratch_buffer_size - packets_per_gpu : scratch_buffer_size
                         ],
-                        tb_group=thread_block_group,
+                        tb_group=thread_block_groups[tbg_id],
                     )
 
             # Unpack Data Received from other GPUs in the same node
             for peer_gpu_id in range(gpus_per_node):
                 if peer_gpu_id != local_gpu_id:
+                    tbg_id = peer_gpu_id if peer_gpu_id < local_gpu_id else peer_gpu_id - 1
                     current_rank.unpack_packets(
                         input_buffer[
                             peer_gpu_id * packets_per_gpu : peer_gpu_id * packets_per_gpu + packets_per_gpu
@@ -157,7 +186,7 @@ def allreduce_2nodes(spec: AlgoSpec, thread_block_group_size: int) -> Collective
                             + peer_gpu_id * packets_per_gpu
                             + packets_per_gpu
                         ],
-                        tb_group=thread_block_group,
+                        tb_group=thread_block_groups[tbg_id],
                     )
 
     return prog
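
To make the new thread-block layout above concrete, the following standalone sketch (assumed values: gpus_per_node = 8, thread_block_group_size = 4, as in the 128K-2M default plan; it is not part of the diff) prints the thread blocks that global_intra_node_tbg and thread_block_groups would cover:

# Standalone sketch of the thread-block partitioning used above.
# Assumed values; not part of the commit.
gpus_per_node = 8
thread_block_group_size = 4
thread_block_offset = 1

# One global group spanning every intra-node worker block: tb 1..28.
global_tbs = list(
    range(thread_block_offset, thread_block_offset + (gpus_per_node - 1) * thread_block_group_size)
)

# One slice of thread_block_group_size blocks per peer GPU (gpus_per_node - 1 slices).
per_peer_tbs = [
    list(
        range(
            thread_block_offset + i * thread_block_group_size,
            thread_block_offset + (i + 1) * thread_block_group_size,
        )
    )
    for i in range(gpus_per_node - 1)
]

# A peer GPU is mapped to its slice with the same formula as in the algorithm,
# tbg_id = peer_gpu_id if peer_gpu_id < local_gpu_id else peer_gpu_id - 1,
# which simply skips the local GPU.
print(global_tbs)        # [1, 2, ..., 28]
print(per_peer_tbs[0])   # [1, 2, 3, 4]
print(per_peer_tbs[-1])  # [25, 26, 27, 28]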

python/mscclpp/language/internal/operations.py

Lines changed: 56 additions & 2 deletions
@@ -604,6 +604,7 @@ def __init__(
         self,
         local_src_buff: List[LocalChunk],
         local_dst_buff: List[LocalChunk],
+        local_pkt_dst_buff: List[LocalChunk] = None,
         remote_src_buff: List[RemoteChunk] = None,
         remote_dst_buff: List[RemoteChunk] = None,
         channel_ids: List[int] = None,
@@ -613,19 +614,26 @@ def __init__(
         tbg_info: ThreadBlockGroupInfo = None,
         packet: bool = False,
     ):
+        local_pkt_dst_buff = local_pkt_dst_buff if local_pkt_dst_buff is not None else []
         remote_src_buff = remote_src_buff if remote_src_buff is not None else []
         remote_dst_buff = remote_dst_buff if remote_dst_buff is not None else []
         channel_ids = channel_ids if channel_ids is not None else []
         put_channel_ids = put_channel_ids if put_channel_ids is not None else []
 
         if len(remote_src_buff) == 0 and len(remote_dst_buff) == 0:
             if packet:
-                super().__init__(Instruction.reduce_packet)
+                if len(local_pkt_dst_buff) == 0:
+                    super().__init__(Instruction.reduce_packet)
+                else:
+                    super().__init__(Instruction.reduce_copy_packet)
             else:
                 super().__init__(Instruction.reduce)
         elif len(remote_src_buff) == 0:
             if packet:
-                super().__init__(Instruction.reduce_send_packet)
+                if len(local_pkt_dst_buff) == 0:
+                    super().__init__(Instruction.reduce_send_packet)
+                else:
+                    super().__init__(Instruction.reduce_copy_send_packet)
             else:
                 super().__init__(Instruction.reduce_send)
         elif len(remote_dst_buff) == 0 and not packet:
@@ -637,6 +645,7 @@ def __init__(
 
         self.local_src_buff = local_src_buff
         self.local_dst_buff = local_dst_buff
+        self.local_pkt_dst_buff = local_pkt_dst_buff
         self.remote_src_buff = remote_src_buff
         self.remote_dst_buff = remote_dst_buff
         self.channel_ids = channel_ids
@@ -741,6 +750,49 @@ def __add__(self, other):
                 tbg_info=self.tbg_info,
                 packet=self.packet,
             )
+        if (
+            isinstance(other, CopyOperation)
+            and self.name == Instruction.reduce_packet
+            and other.name == Instruction.copy_packet
+            and self.local_dst_buff[0] == other.src_buff[0]
+            and self.tbg_info == other.tbg_info
+        ):
+            fused_operation = ReduceOperation(
+                self.local_src_buff,
+                self.local_dst_buff,
+                local_pkt_dst_buff=other.dst_buff,
+                remote_src_buff=self.remote_src_buff,
+                remote_dst_buff=self.remote_dst_buff,
+                channel_ids=self.channel_ids,
+                put_channel_ids=self.put_channel_ids,
+                channel_type=self.channel_type,
+                reduce_operation=self.reduce_operation,
+                tbg_info=self.tbg_info,
+                packet=self.packet,
+            )
+        if (
+            isinstance(other, PutOperation)
+            and (self.name == Instruction.reduce_copy_packet or self.name == Instruction.reduce_copy_send_packet)
+            and (
+                (other.name == Instruction.put_packet and self.local_dst_buff[0] == other.src_buff[0])
+                or (other.name == Instruction.read_put_packet and self.local_pkt_dst_buff[0] == other.src_buff[0])
+            )
+            and other.channel_type == ChannelType.memory
+            and self.tbg_info == other.tbg_info
+        ):
+            fused_operation = ReduceOperation(
+                self.local_src_buff,
+                self.local_dst_buff,
+                local_pkt_dst_buff=self.local_pkt_dst_buff,
+                remote_src_buff=self.remote_src_buff,
+                remote_dst_buff=self.remote_dst_buff + other.dst_buff,
+                channel_ids=self.channel_ids,
+                put_channel_ids=self.put_channel_ids + other.channel_ids,
+                channel_type=other.channel_type,
+                reduce_operation=self.reduce_operation,
+                tbg_info=self.tbg_info,
+                packet=self.packet,
+            )
 
         return fused_operation
 
@@ -752,6 +804,8 @@ def to_dict(self):
         result["dst_buff"] = []
         for chunk in self.local_dst_buff:
             result["dst_buff"].append(chunk.to_dict())
+        for chunk in self.local_pkt_dst_buff:
+            result["dst_buff"].append(chunk.to_dict())
 
         if len(self.remote_src_buff) > 0:
             for chunk in self.remote_src_buff:
python/mscclpp/language/internal/types.py

Lines changed: 2 additions & 0 deletions
@@ -69,6 +69,7 @@ class Instruction(Enum):
     unpack_packet = "upkt"
     reduce = "re"
     reduce_packet = "repkt"
+    reduce_copy_packet = "recpkt"
     sem_acquire = "sem_acquire"
     sem_release = "sem_release"
     signal = "signal"
@@ -85,6 +86,7 @@ class Instruction(Enum):
     put_with_signal_and_flush = "pwsf"
     reduce_send = "res"
     reduce_send_packet = "respkt"
+    reduce_copy_send_packet = "recspkt"
     read_reduce = "rre"
     read_reduce_send = "rres"
     group_store = "gstore"
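
Putting operations.py and types.py together, the branch structure added above can be distilled into a small standalone helper (an illustrative sketch, not the library API) that returns the instruction tag a ReduceOperation would now carry:

# Distilled sketch of the instruction selection added in ReduceOperation.__init__;
# illustrative only, not the real class.
def pick_instruction(packet, local_pkt_dst_buff, remote_src_buff, remote_dst_buff):
    if not remote_src_buff and not remote_dst_buff:
        if packet:
            # reduce_copy_packet when a packet copy destination is attached
            return "recpkt" if local_pkt_dst_buff else "repkt"
        return "re"
    if not remote_src_buff:
        if packet:
            # reduce_copy_send_packet when the fused op also sends to a remote buffer
            return "recspkt" if local_pkt_dst_buff else "respkt"
        return "res"
    return "(unchanged by this PR)"

assert pick_instruction(True, [], [], []) == "repkt"                         # Reduce Packet
assert pick_instruction(True, ["pkt_dst"], [], []) == "recpkt"               # Reduce Copy Packet
assert pick_instruction(True, ["pkt_dst"], [], ["remote_dst"]) == "recspkt"  # Reduce Copy Send Packet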
