Skip to content

Commit 44ef929

Browse files
committed
Fix latency time measurement
1 parent 56ad603 commit 44ef929

File tree

1 file changed

+21
-8
lines changed

1 file changed

+21
-8
lines changed

benchmarks/bench_load_latency.py

Lines changed: 21 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -11,7 +11,19 @@
1111
import triton.language as tl
1212
import iris
1313
from iris._mpi_helpers import mpi_allgather
14-
from examples.common.utils import read_realtime
14+
# from examples.common.utils import read_realtime
15+
16+
@triton.jit
17+
def read_realtime():
18+
tmp = tl.inline_asm_elementwise(
19+
asm="mov.u64 $0, %globaltimer;",
20+
constraints=("=l"),
21+
args=[],
22+
dtype=tl.int64,
23+
is_pure=False,
24+
pack=1,
25+
)
26+
return tmp
1527

1628

1729
@triton.jit()
@@ -38,10 +50,10 @@ def load_remote(
3850
if i == skip:
3951
start = read_realtime()
4052
tl.store(mm_begin_timestamp_ptr + peer_rank * BLOCK_SIZE + offsets, start, time_stmp_mask)
41-
53+
4254
# iris.load(data + offsets, curr_rank, peer_rank,heap_bases, data_mask)
4355
from_base = tl.load(heap_bases + curr_rank)
44-
to_base = tl.load(heap_bases + peer_rank)
56+
to_base = tl.load(heap_bases + peer_rank)
4557
offset = tl.cast(data + offsets, tl.uint64) - from_base
4658
translated_ptr = tl.cast(tl.cast(to_base, tl.pointer_type(tl.int8)) + offset, (data + offsets).dtype)
4759
result = tl.load(translated_ptr, mask=data_mask, cache_modifier=".cv", volatile=True)
@@ -240,15 +252,14 @@ def print_run_settings(
240252
grid = lambda meta: (1,)
241253
for source_rank in range(num_ranks):
242254
for destination_rank in range(num_ranks):
243-
if cur_rank in [source_rank, destination_rank]:
244-
peer_for_me = destination_rank if cur_rank == source_rank else source_rank
255+
if cur_rank == source_rank:
245256
load_remote[grid](
246257
source_buffer,
247258
BUFFER_LEN,
248259
skip,
249260
niter,
250261
cur_rank,
251-
peer_for_me,
262+
destination_rank,
252263
BLOCK_SIZE,
253264
heap_bases,
254265
mm_begin_timestamp,
@@ -258,10 +269,12 @@ def print_run_settings(
258269

259270
mm_begin_cpu = mm_begin_timestamp.cpu().numpy()
260271
mm_end_cpu = mm_end_timestamp.cpu().numpy()
272+
273+
gpu_freq = iris.hip.get_wall_clock_rate(cur_rank) * 1e-3
261274
for destination_rank in range(num_ranks):
262275
delta = mm_end_cpu[destination_rank, :] - mm_begin_cpu[destination_rank, :]
263-
avg_ns = float(delta.sum() / max(1, delta.size) / max(1, niter))
264-
local_latency[destination_rank] = avg_ns
276+
avg_cc = float(delta.sum() / max(1, delta.size) / max(1, niter))
277+
local_latency[destination_rank] = avg_cc / gpu_freq
265278

266279
latency_matrix = mpi_allgather(local_latency.cpu())
267280

0 commit comments

Comments
 (0)