1111import triton .language as tl
1212import iris
1313from iris ._mpi_helpers import mpi_allgather
14- from examples .common .utils import read_realtime
14+ # from examples.common.utils import read_realtime
15+
@triton.jit
def read_realtime():
    """Read a 64-bit device timestamp via inline assembly.

    Issues a single ``mov.u64`` that copies the ``%globaltimer`` register
    into the output operand and returns it as a ``tl.int64`` value.

    NOTE(review): ``%globaltimer`` is an NVIDIA PTX special register, which
    the PTX ISA documents as a nanosecond counter -- confirm the unit (and
    the target backend) before converting timestamp deltas elsewhere.
    """
    # The asm takes no inputs (args=[]); "=l" binds $0 to a 64-bit output.
    # is_pure=False marks the asm as having side effects, so the compiler
    # must not assume repeated reads yield the same value.
    return tl.inline_asm_elementwise(
        asm="mov.u64 $0, %globaltimer;",
        constraints="=l",
        args=[],
        dtype=tl.int64,
        is_pure=False,
        pack=1,
    )
1527
1628
1729@triton .jit ()
@@ -38,10 +50,10 @@ def load_remote(
3850 if i == skip :
3951 start = read_realtime ()
4052 tl .store (mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets , start , time_stmp_mask )
41-
53+
4254 # iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
4355 from_base = tl .load (heap_bases + curr_rank )
44- to_base = tl .load (heap_bases + peer_rank )
56+ to_base = tl .load (heap_bases + peer_rank )
4557 offset = tl .cast (data + offsets , tl .uint64 ) - from_base
4658 translated_ptr = tl .cast (tl .cast (to_base , tl .pointer_type (tl .int8 )) + offset , (data + offsets ).dtype )
4759 result = tl .load (translated_ptr , mask = data_mask , cache_modifier = ".cv" , volatile = True )
@@ -240,15 +252,14 @@ def print_run_settings(
240252 grid = lambda meta : (1 ,)
241253 for source_rank in range (num_ranks ):
242254 for destination_rank in range (num_ranks ):
243- if cur_rank in [source_rank , destination_rank ]:
244- peer_for_me = destination_rank if cur_rank == source_rank else source_rank
255+ if cur_rank == source_rank :
245256 load_remote [grid ](
246257 source_buffer ,
247258 BUFFER_LEN ,
248259 skip ,
249260 niter ,
250261 cur_rank ,
251- peer_for_me ,
262+ destination_rank ,
252263 BLOCK_SIZE ,
253264 heap_bases ,
254265 mm_begin_timestamp ,
@@ -258,10 +269,12 @@ def print_run_settings(
258269
259270 mm_begin_cpu = mm_begin_timestamp .cpu ().numpy ()
260271 mm_end_cpu = mm_end_timestamp .cpu ().numpy ()
272+
273+ gpu_freq = iris .hip .get_wall_clock_rate (cur_rank ) * 1e-3
261274 for destination_rank in range (num_ranks ):
262275 delta = mm_end_cpu [destination_rank , :] - mm_begin_cpu [destination_rank , :]
263- avg_ns = float (delta .sum () / max (1 , delta .size ) / max (1 , niter ))
264- local_latency [destination_rank ] = avg_ns
276+ avg_cc = float (delta .sum () / max (1 , delta .size ) / max (1 , niter ))
277+ local_latency [destination_rank ] = avg_cc / gpu_freq
265278
266279 latency_matrix = mpi_allgather (local_latency .cpu ())
267280
0 commit comments