From ebc57c5a2bf560770424bee5169fac09a2a3a9d1 Mon Sep 17 00:00:00 2001 From: ferrol aderholdt Date: Wed, 19 Feb 2025 15:09:35 -0800 Subject: [PATCH] REVIEW: fix a2a without memmap --- src/components/cl/basic/cl_basic_context.c | 8 ++-- .../tl/ucp/alltoall/alltoall_onesided.c | 37 ++++++++++--------- src/components/tl/ucp/tl_ucp_sendrecv.h | 11 +++--- 3 files changed, 30 insertions(+), 26 deletions(-) diff --git a/src/components/cl/basic/cl_basic_context.c b/src/components/cl/basic/cl_basic_context.c index 66a2ecf0d0..3b10eefeff 100644 --- a/src/components/cl/basic/cl_basic_context.c +++ b/src/components/cl/basic/cl_basic_context.c @@ -62,20 +62,20 @@ UCC_CLASS_CLEANUP_FUNC(ucc_cl_basic_context_t) } ucc_status_t ucc_cl_basic_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ - void *address, size_t len, void *memh, - void *tl_h) + void *address, size_t len, void *memh, /* NOLINT */ + void *tl_h) /* NOLINT */ { return UCC_ERR_NOT_SUPPORTED; } ucc_status_t ucc_cl_basic_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ - void *tl_h) + void *tl_h) /* NOLINT */ { return UCC_ERR_NOT_SUPPORTED; } ucc_status_t ucc_cl_basic_memh_pack(const ucc_base_context_t *context, /* NOLINT */ - void *memh, void **packed_buffer) + void *memh, void **packed_buffer) /* NOLINT */ { return UCC_ERR_NOT_SUPPORTED; } diff --git a/src/components/tl/ucp/alltoall/alltoall_onesided.c b/src/components/tl/ucp/alltoall/alltoall_onesided.c index 8d5b57b2e5..a1be136d91 100644 --- a/src/components/tl/ucp/alltoall/alltoall_onesided.c +++ b/src/components/tl/ucp/alltoall/alltoall_onesided.c @@ -33,24 +33,27 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask) nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype); dest = dest + grank * nelems; - UCPCHECK_GOTO( - ucc_tl_ucp_put_nb((void *)(src + start * nelems), (void *)dest, nelems, - start, *src_memh, dst_memh[start], team, task), - task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, *src_memh, - dst_memh[start], team), - task, out); - - for (peer = (start + 1) % gsize; peer != start; peer = (peer + 1) % gsize) { - UCPCHECK_GOTO(ucc_tl_ucp_put_nb( - (void *)(src + peer * nelems), (void *)dest, nelems, - peer, *src_memh, dst_memh[peer], team, task), - task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh, - dst_memh[peer], team), - task, out); + if (ucc_likely(!(src_memh && dst_memh))) { + for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) { + UCPCHECK_GOTO(ucc_tl_ucp_put_nb( + (void *)(src + peer * nelems), (void *)dest, nelems, + peer, NULL, NULL, team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, NULL, + NULL, team), + task, out); + } + } else { + for (peer = start; task->onesided.put_posted < gsize; peer = (peer + 1) % gsize) { + UCPCHECK_GOTO(ucc_tl_ucp_put_nb( + (void *)(src + peer * nelems), (void *)dest, nelems, + peer, *src_memh, dst_memh[peer], team, task), + task, out); + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh, + dst_memh[peer], team), + task, out); + } } - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); out: return task->super.status; diff --git a/src/components/tl/ucp/tl_ucp_sendrecv.h b/src/components/tl/ucp/tl_ucp_sendrecv.h index d304ec5089..0e34de80bb 100644 --- a/src/components/tl/ucp/tl_ucp_sendrecv.h +++ b/src/components/tl/ucp/tl_ucp_sendrecv.h @@ -421,19 +421,21 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target, int segment = 0; ucp_rkey_h rkey = NULL; uint64_t rva = 0; + void *ucp_memh = NULL; ucs_status_ptr_t ucp_status; ucc_status_t status; ucp_ep_h ep; - void *ucp_memh = NULL; status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep); if (ucc_unlikely(UCC_OK != status)) { return status; } - status = ucc_tl_ucp_get_memh(team, src_memh, &ucp_memh); - if (ucc_unlikely(UCC_OK != status)) { - return status; + if (src_memh) { + status = ucc_tl_ucp_get_memh(team, src_memh, &ucp_memh); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } } status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank, @@ -442,7 +444,6 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target, return status; } - req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; req_param.cb.send = ucc_tl_ucp_put_completion_cb;