diff --git a/src/components/tl/mlx5/alltoall/alltoall_coll.c b/src/components/tl/mlx5/alltoall/alltoall_coll.c index 70439263fb..6e8cb0ec5d 100644 --- a/src/components/tl/mlx5/alltoall/alltoall_coll.c +++ b/src/components/tl/mlx5/alltoall/alltoall_coll.c @@ -243,6 +243,10 @@ static ucc_status_t ucc_tl_mlx5_fanout_start(ucc_coll_task_t *coll_task) tl_debug(UCC_TASK_LIB(task), "fanout start"); /* start task if completion event received */ UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanout_start", 0); + if (team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank) { + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_start", 0); + } /* Start fanout */ ucc_progress_enqueue(UCC_TL_CORE_CTX(team)->pq, coll_task); return UCC_OK; @@ -265,6 +269,8 @@ static void ucc_tl_mlx5_fanout_progress(ucc_coll_task_t *coll_task) coll_task->status = UCC_INPROGRESS; return; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_wait-on-data_complete, fanout_start", 0); } if (UCC_OK == ucc_tl_mlx5_node_fanout(team, task)) { @@ -342,12 +348,14 @@ static ucc_status_t ucc_tl_mlx5_asr_barrier_start(ucc_coll_task_t *coll_task) status = send_done(team, i); } if (status != UCC_OK) { - tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); + tl_error(UCC_TASK_LIB(task), "failed sending barrier notice"); return status; } + UCC_TL_MLX5_PROFILE_REQUEST_EVENT( + task, "mlx5_alltoall_barrier_send_posted", 0); } coll_task->status = UCC_OK; - UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barreir_done", + UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_barrier_done", 0); return ucc_task_complete(coll_task); } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h index 663ee636ed..49f2292166 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast.h @@ -199,22 +199,25 @@ struct pp_packet { uintptr_t buf; // buffer address, initialized once }; +struct mcast_group { + struct ibv_qp *qp; + struct ibv_ah *ah; + uint16_t lid; + union ibv_gid mgid; + struct sockaddr_in6 mcast_addr; +}; + struct mcast_ctx { - struct ibv_qp *qp; - struct ibv_ah *ah; struct ibv_send_wr swr; struct ibv_sge ssg; - + struct ibv_cq *scq; + struct ibv_cq *rcq; + struct ibv_srq *srq; + struct mcast_group groups[MAX_GROUP_COUNT]; // RC connection info for supporing one-sided based relibality struct ibv_qp **rc_qp; uint16_t *rc_lid; union ibv_gid *rc_gid; - - // multiple mcast group - struct ibv_qp **qp_list; - struct ibv_ah **ah_list; - struct ibv_send_wr *swr_list; - struct ibv_sge *ssg_list; }; struct packet { @@ -303,15 +306,10 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { ucc_tl_mlx5_mcast_coll_comm_init_spec_t params; ucc_tl_mlx5_mcast_p2p_interface_t p2p; int tx; - struct ibv_cq *scq; - struct ibv_cq *rcq; - struct ibv_srq *srq; ucc_rank_t rank; ucc_rank_t commsize; char *grh_buf; struct ibv_mr *grh_mr; - uint16_t mcast_lid; - union ibv_gid mgid; unsigned max_inline; size_t max_eager; int max_per_packet; @@ -334,7 +332,6 @@ typedef struct ucc_tl_mlx5_mcast_coll_comm { int comm_id; void *p2p_ctx; ucc_base_lib_t *lib; - struct sockaddr_in6 mcast_addr; int cuda_mem_enabled; ucc_tl_mlx5_mcast_join_info_t *group_setup_info; ucc_service_coll_req_t *group_setup_info_req; @@ -441,6 +438,8 @@ typedef struct ucc_tl_mlx5_mcast_coll_req { ucc_service_coll_req_t *allgather_rkeys_req; ucc_service_coll_req_t *barrier_req; void *recv_rreg; + ucc_ee_executor_task_t *exec_task; + ucc_coll_task_t *coll_task; 
} ucc_tl_mlx5_mcast_coll_req_t; typedef struct ucc_tl_mlx5_mcast_oob_p2p_context { @@ -490,7 +489,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_recv_buffers(ucc_tl_mlx5_mcast } if (i != 0) { rwr[i-1].next = NULL; - if (ibv_post_recv(comm->mcast.qp, &rwr[0], &bad_wr)) { + if (ibv_post_recv(comm->mcast.groups[0].qp, &rwr[0], &bad_wr)) { tl_error(comm->lib, "failed to prepost recvs: errno %d", errno); return UCC_ERR_NO_RESOURCE; } @@ -543,7 +542,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_ if (i > 0) { rwr[i-1].next = NULL; - if (ibv_post_recv(comm->mcast.qp_list[group_id], &rwr[0], &bad_wr)) { + if (ibv_post_recv(comm->mcast.groups[group_id].qp, &rwr[0], &bad_wr)) { tl_error(comm->lib, "Failed to prepost recvs: errno %d buffer count %d", errno, i); return UCC_ERR_NO_RESOURCE; @@ -555,6 +554,21 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_post_user_recv_buffers(ucc_tl_mlx5_ return UCC_OK; } +#define EXEC_TASK_TEST(_errmsg, _etask, _lib) do { \ + if (_etask != NULL) { \ + status = ucc_ee_executor_task_test(_etask); \ + if (status > 0) { \ + return status; \ + } \ + ucc_ee_executor_task_finalize(_etask); \ + _etask = NULL; \ + if (ucc_unlikely(status < 0)) { \ + tl_error(_lib, _errmsg); \ + return status; \ + } \ + } \ +} while(0) + ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *tl_context, ucc_tl_mlx5_mcast_team_t **mcast_team, ucc_tl_mlx5_mcast_context_t *ctx, diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c index b6fbe84e3d..cf813fd5af 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.c @@ -33,6 +33,10 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_r_window_recycle(ucc_tl_mlx5_mcast_ return status; } + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); + } + comm->bcast_comm.n_mcast_reliable++; for (; comm->bcast_comm.last_acked < comm->psn; comm->bcast_comm.last_acked++) { @@ -267,7 +271,10 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_start(ucc_coll_task_t *coll_task) return ucc_task_complete(coll_task); } - coll_task->status = status; + ucc_assert(task->coll_mcast.req_handle != NULL); + + coll_task->status = status; + task->coll_mcast.req_handle->coll_task = coll_task; return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(mlx5_team)->pq, &task->super); } @@ -333,6 +340,7 @@ ucc_status_t ucc_tl_mlx5_mcast_bcast_init(ucc_tl_mlx5_task_t *task) { task->super.post = ucc_tl_mlx5_mcast_bcast_start; task->super.progress = ucc_tl_mlx5_mcast_collective_progress; + task->super.flags = UCC_COLL_TASK_FLAG_EXECUTOR; return UCC_OK; } diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h index f34e8827f4..ccc563ecc7 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_coll.h @@ -16,4 +16,5 @@ ucc_status_t ucc_tl_mlx5_mcast_test(ucc_tl_mlx5_mcast_coll_req_t* _req); ucc_status_t ucc_tl_mlx5_mcast_check_support(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team); + #endif diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c index a116c08cf8..ae736a37a9 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.c @@ -282,11 +282,14 @@ ucc_status_t 
ucc_tl_mlx5_setup_mcast_group_join_post(ucc_tl_mlx5_mcast_coll_comm ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, ucc_tl_mlx5_mcast_coll_comm_t *comm) { - struct ibv_qp_init_attr qp_init_attr = {0}; + int max_inline = INT_MAX; + struct ibv_qp_init_attr qp_init_attr = {0}; + int i; + int j; qp_init_attr.qp_type = IBV_QPT_UD; - qp_init_attr.send_cq = comm->scq; - qp_init_attr.recv_cq = comm->rcq; + qp_init_attr.send_cq = comm->mcast.scq; //cq can be shared between multiple QPs + qp_init_attr.recv_cq = comm->mcast.rcq; qp_init_attr.sq_sig_all = 0; qp_init_attr.cap.max_send_wr = comm->params.sx_depth; qp_init_attr.cap.max_recv_wr = comm->params.rx_depth; @@ -294,41 +297,68 @@ ucc_status_t ucc_tl_mlx5_mcast_init_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, qp_init_attr.cap.max_send_sge = comm->params.sx_sge; qp_init_attr.cap.max_recv_sge = comm->params.rx_sge; - comm->mcast.qp = ibv_create_qp(ctx->pd, &qp_init_attr); - if (!comm->mcast.qp) { - tl_warn(ctx->lib, "failed to create mcast qp, errno %d", errno); - return UCC_ERR_NO_RESOURCE; + for (i = 0; i < comm->mcast_group_count; i++) { + comm->mcast.groups[i].qp = ibv_create_qp(ctx->pd, &qp_init_attr); + if (!comm->mcast.groups[i].qp) { + tl_error(ctx->lib, "Failed to create mcast UD qp index %d, errno %d", i, errno); + goto error; + } + if (qp_init_attr.cap.max_inline_data < max_inline) { + max_inline = qp_init_attr.cap.max_inline_data; + } } if (comm->cuda_mem_enabled) { /* max inline send otherwise it segfault during ibv send */ comm->max_inline = 0; } else { - comm->max_inline = qp_init_attr.cap.max_inline_data; + comm->max_inline = max_inline; } return UCC_OK; + +error: + for (j = 0; j < i; j++) { + ibv_destroy_qp(comm->mcast.groups[j].qp); + comm->mcast.groups[j].qp = NULL; + } + return UCC_ERR_NO_RESOURCE; } static ucc_status_t ucc_tl_mlx5_mcast_create_ah(ucc_tl_mlx5_mcast_coll_comm_t *comm) { + int i, j, ret; struct ibv_ah_attr ah_attr = { .is_global = 1, .grh = {.sgid_index = 0}, - .dlid = comm->mcast_lid, .sl = DEF_SL, .src_path_bits = DEF_SRC_PATH_BITS, .port_num = comm->ctx->ib_port }; - memcpy(ah_attr.grh.dgid.raw, &comm->mgid, sizeof(ah_attr.grh.dgid.raw)); + for (i = 0; i < comm->mcast_group_count; i ++) { + ah_attr.dlid = comm->mcast.groups[i].lid; + memcpy(ah_attr.grh.dgid.raw, &comm->mcast.groups[i].mgid, sizeof(ah_attr.grh.dgid.raw)); - comm->mcast.ah = ibv_create_ah(comm->ctx->pd, &ah_attr); - if (!comm->mcast.ah) { - tl_warn(comm->lib, "failed to create AH"); - return UCC_ERR_NO_RESOURCE; + comm->mcast.groups[i].ah = ibv_create_ah(comm->ctx->pd, &ah_attr); + if (!comm->mcast.groups[i].ah) { + tl_error(comm->lib, "failed to create AH index %d", i); + goto error; + } } + return UCC_OK; + +error: + for (j = 0; j < i; j++) { + ret = ibv_destroy_ah(comm->mcast.groups[j].ah); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + return UCC_ERR_NO_RESOURCE; + } + comm->mcast.groups[j].ah = NULL; + } + return UCC_ERR_NO_RESOURCE; } ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, @@ -337,16 +367,15 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, struct ibv_port_attr port_attr; struct ibv_qp_attr attr; uint16_t pkey; + int i; ibv_query_port(ctx->ctx, ctx->ib_port, &port_attr); - for (ctx->pkey_index = 0; ctx->pkey_index < port_attr.pkey_tbl_len; ++ctx->pkey_index) { ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); if (pkey == DEF_PKEY) break; } - if (ctx->pkey_index >= port_attr.pkey_tbl_len) { ctx->pkey_index = 0; 
ibv_query_pkey(ctx->ctx, ctx->ib_port, ctx->pkey_index, &pkey); @@ -359,43 +388,53 @@ ucc_status_t ucc_tl_mlx5_mcast_setup_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, "index 0 pkey:0x%04x", DEF_PKEY, ctx->ib_port, pkey); } - attr.qp_state = IBV_QPS_INIT; - attr.pkey_index = ctx->pkey_index; - attr.port_num = ctx->ib_port; - attr.qkey = DEF_QKEY; + for (i = 0; i < comm->mcast_group_count; i++) { + attr.qp_state = IBV_QPS_INIT; + attr.pkey_index = ctx->pkey_index; + attr.port_num = ctx->ib_port; + attr.qkey = DEF_QKEY; - if (ibv_modify_qp(comm->mcast.qp, &attr, - IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { - tl_warn(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + if (ibv_modify_qp(comm->mcast.groups[i].qp, &attr, + IBV_QP_STATE | IBV_QP_PKEY_INDEX | IBV_QP_PORT | IBV_QP_QKEY)) { + tl_error(ctx->lib, "failed to move mcast qp to INIT, errno %d", errno); + goto error; + } - if (ibv_attach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid)) { - tl_warn(ctx->lib, "failed to attach QP to the mcast group, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + if (ibv_attach_mcast(comm->mcast.groups[i].qp, &comm->mcast.groups[i].mgid, + comm->mcast.groups[i].lid)) { + tl_error(ctx->lib, "failed to attach QP to the mcast group with mcast_lid %d, errno %d", + comm->mcast.groups[i].lid, errno); + goto error; + } - /* Ok, now cycle to RTR on everyone */ - attr.qp_state = IBV_QPS_RTR; - if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE)) { - tl_warn(ctx->lib, "failed to modify QP to RTR, errno %d", errno); - return UCC_ERR_NO_RESOURCE; - } + attr.qp_state = IBV_QPS_RTR; + if (ibv_modify_qp(comm->mcast.groups[i].qp, &attr, IBV_QP_STATE)) { + tl_error(ctx->lib, "failed to modify QP to RTR, errno %d", errno); + goto error; + } - attr.qp_state = IBV_QPS_RTS; - attr.sq_psn = DEF_PSN; - if (ibv_modify_qp(comm->mcast.qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) { - tl_warn(ctx->lib, "failed to modify QP to RTS, errno %d", errno); - return UCC_ERR_NO_RESOURCE; + attr.qp_state = IBV_QPS_RTS; + attr.sq_psn = DEF_PSN; + if (ibv_modify_qp(comm->mcast.groups[i].qp, &attr, IBV_QP_STATE | IBV_QP_SQ_PSN)) { + tl_error(ctx->lib, "failed to modify QP to RTS, errno %d", errno); + goto error; + } } - /* Create the address handle */ + /* create the address handle */ if (UCC_OK != ucc_tl_mlx5_mcast_create_ah(comm)) { tl_warn(ctx->lib, "failed to create address handle"); - return UCC_ERR_NO_RESOURCE; + goto error; } return UCC_OK; + +error: + for (i = 0; i < comm->mcast_group_count; i++) { + ibv_destroy_qp(comm->mcast.groups[i].qp); + comm->mcast.groups[i].qp = NULL; + } + return UCC_ERR_NO_RESOURCE; } ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *ctx, @@ -410,8 +449,8 @@ ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *c srq_init_attr.attr.max_wr = comm->params.rx_depth; srq_init_attr.attr.max_sge = 2; - comm->srq = ibv_create_srq(ctx->pd, &srq_init_attr); - if (!comm->srq) { + comm->mcast.srq = ibv_create_srq(ctx->pd, &srq_init_attr); + if (!comm->mcast.srq) { tl_error(ctx->lib, "ibv_create_srq() failed"); return UCC_ERR_NO_RESOURCE; } @@ -426,10 +465,10 @@ ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *c for (i = 0; i < comm->commsize; i++) { memset(&qp_init_attr, 0, sizeof(qp_init_attr)); - qp_init_attr.srq = comm->srq; + qp_init_attr.srq = comm->mcast.srq; qp_init_attr.qp_type = IBV_QPT_RC; - qp_init_attr.send_cq = comm->scq; - qp_init_attr.recv_cq =
comm->rcq; + qp_init_attr.send_cq = comm->mcast.scq; + qp_init_attr.recv_cq = comm->mcast.rcq; qp_init_attr.sq_sig_all = 0; qp_init_attr.cap.max_send_wr = comm->params.sx_depth; qp_init_attr.cap.max_recv_wr = 0; // has srq @@ -454,7 +493,7 @@ ucc_status_t ucc_tl_mlx5_mcast_create_rc_qps(ucc_tl_mlx5_mcast_coll_context_t *c } } - if (ibv_destroy_srq(comm->srq)) { + if (ibv_destroy_srq(comm->mcast.srq)) { tl_error(comm->lib, "ibv_destroy_srq failed"); return UCC_ERR_NO_RESOURCE; } @@ -538,7 +577,7 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, char buf[40]; const char *dst; - dst = inet_ntop(AF_INET6, &comm->mcast_addr, buf, 40); + dst = inet_ntop(AF_INET6, &comm->mcast.groups[0].mcast_addr, buf, 40); if (NULL == dst) { tl_error(comm->lib, "inet_ntop failed"); return UCC_ERR_NO_RESOURCE; @@ -546,7 +585,7 @@ ucc_status_t ucc_tl_mlx5_fini_mcast_group(ucc_tl_mlx5_mcast_coll_context_t *ctx, tl_debug(ctx->lib, "mcast leave: ctx %p, comm %p, dgid: %s", ctx, comm, buf); - if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast_addr)) { + if (rdma_leave_multicast(ctx->id, (struct sockaddr*)&comm->mcast.groups[0].mcast_addr)) { tl_error(comm->lib, "mcast rmda_leave_multicast failed"); return UCC_ERR_NO_RESOURCE; } @@ -559,11 +598,10 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) ucc_tl_mlx5_mcast_context_t *mcast_ctx = ucc_container_of(comm->ctx, ucc_tl_mlx5_mcast_context_t, mcast_context); ucc_tl_mlx5_context_t *mlx5_ctx = ucc_container_of(mcast_ctx, ucc_tl_mlx5_context_t, mcast); ucc_context_h context = mlx5_ctx->super.super.ucc_context; - int ret; + int ret, i; ucc_status_t status; - tl_debug(comm->lib, "cleaning mcast comm: %p, id %d, mlid %x", - comm, comm->comm_id, comm->mcast_lid); + tl_debug(comm->lib, "cleaning mcast comm: %p, id %d", comm, comm->comm_id); while (UCC_INPROGRESS == (status = ucc_tl_mlx5_mcast_reliable(comm))) { ucc_context_progress(context); @@ -575,32 +613,48 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) return status; } - if (comm->mcast.qp) { - ret = ibv_detach_mcast(comm->mcast.qp, &comm->mgid, comm->mcast_lid); - if (ret) { - tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); - return UCC_ERR_NO_RESOURCE; + for (i = 0; i < comm->mcast_group_count; i++) { + if (comm->mcast.groups[i].qp) { + ret = ibv_detach_mcast(comm->mcast.groups[i].qp, &(comm->mcast.groups[i].mgid), comm->mcast.groups[i].lid); + if (ret) { + tl_error(comm->lib, "couldn't detach QP, ret %d, errno %d", ret, errno); + return UCC_ERR_NO_RESOURCE; + } + + ret = ibv_destroy_qp(comm->mcast.groups[i].qp); + if (ret) { + tl_error(comm->lib, "failed to destroy QP %d", ret); + return UCC_ERR_NO_RESOURCE; + } + + comm->mcast.groups[i].qp = NULL; + } + if (comm->mcast.groups[i].ah) { + ret = ibv_destroy_ah(comm->mcast.groups[i].ah); + if (ret) { + tl_error(comm->lib, "couldn't destroy ah"); + return UCC_ERR_NO_RESOURCE; + } + comm->mcast.groups[i].ah = NULL; } } - if (comm->mcast.qp) { - ret = ibv_destroy_qp(comm->mcast.qp); - if (ret) { - tl_error(comm->lib, "failed to destroy QP %d", ret); - return UCC_ERR_NO_RESOURCE; - } + status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm); + if (status) { + tl_error(comm->lib, "couldn't leave mcast group"); + return status; } - if (comm->rcq) { - ret = ibv_destroy_cq(comm->rcq); + if (comm->mcast.rcq) { + ret = ibv_destroy_cq(comm->mcast.rcq); if (ret) { tl_error(comm->lib, "couldn't destroy rcq"); return UCC_ERR_NO_RESOURCE; } } - if 
(comm->scq) { - ret = ibv_destroy_cq(comm->scq); + if (comm->mcast.scq) { + ret = ibv_destroy_cq(comm->mcast.scq); if (ret) { tl_error(comm->lib, "couldn't destroy scq"); return UCC_ERR_NO_RESOURCE; @@ -643,22 +697,6 @@ ucc_status_t ucc_tl_mlx5_clean_mcast_comm(ucc_tl_mlx5_mcast_coll_comm_t *comm) ucc_free(comm->call_rsgs); } - if (comm->mcast.ah) { - ret = ibv_destroy_ah(comm->mcast.ah); - if (ret) { - tl_error(comm->lib, "couldn't destroy ah"); - return UCC_ERR_NO_RESOURCE; - } - } - - if (comm->mcast_lid) { - status = ucc_tl_mlx5_fini_mcast_group(comm->ctx, comm); - if (status) { - tl_error(comm->lib, "couldn't leave mcast group"); - return status; - } - } - if (comm->ctx->params.print_nack_stats) { tl_debug(comm->lib, "comm_id %d, comm_size %d, comm->psn %d, rank %d, " "nacks counter %d, n_mcast_rel %d", diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h index d0b1a1ddd3..fc5296568d 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_helper.h @@ -16,7 +16,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_poll_send(ucc_tl_mlx5_mcast_coll_co struct ibv_wc wc; int num_comp; - num_comp = ibv_poll_cq(comm->scq, 1, &wc); + num_comp = ibv_poll_cq(comm->mcast.scq, 1, &wc); tl_trace(comm->lib, "polled send completions: %d", num_comp); @@ -108,7 +108,7 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send(ucc_tl_mlx5_mcast_coll_comm_t tl_trace(comm->lib, "post_send, psn %d, length %d, zcopy %d, signaled %d", pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED); - if (0 != (rc = ibv_post_send(comm->mcast.qp, &swr[0], &bad_wr))) { + if (0 != (rc = ibv_post_send(comm->mcast.groups[0].qp, &swr[0], &bad_wr))) { tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " "to_recv %d, length %d, psn %d, inline %d", rc, req->start_psn, req->to_send, req->to_recv, @@ -202,7 +202,7 @@ static inline int ucc_tl_mlx5_mcast_recv(ucc_tl_mlx5_mcast_coll_comm_t *comm, while (num_left > 0) { memset(wc, 0, sizeof(struct ibv_wc) * POLL_PACKED); - num_comp = ibv_poll_cq(comm->rcq, POLL_PACKED, wc); + num_comp = ibv_poll_cq(comm->mcast.rcq, POLL_PACKED, wc); if (num_comp < 0) { tl_error(comm->lib, "recv queue poll completion failed %d", num_comp); @@ -329,19 +329,19 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_send_collective(ucc_tl_mlx5_mcast_c mcast_group_index = group_id; } - swr[0].wr.ud.ah = comm->mcast.ah_list[mcast_group_index]; + swr[0].wr.ud.ah = comm->mcast.groups[mcast_group_index].ah; tl_trace(comm->lib, "mcast allgather post_send, psn %d, length %d, " "zcopy %d, signaled %d qp->state %d qp->qp_num %d qp->pd %p " "coll_type %d mcast_group_index %d", pp->psn, pp->length, zcopy, swr[0].send_flags & IBV_SEND_SIGNALED, - comm->mcast.qp_list[mcast_group_index]->state, - comm->mcast.qp_list[mcast_group_index]->qp_num, - comm->mcast.qp_list[mcast_group_index]->pd, coll_type, + comm->mcast.groups[mcast_group_index].qp->state, + comm->mcast.groups[mcast_group_index].qp->qp_num, + comm->mcast.groups[mcast_group_index].qp->pd, coll_type, mcast_group_index); - if (0 != (rc = ibv_post_send(comm->mcast.qp_list[mcast_group_index], &swr[0], &bad_wr))) { + if (0 != (rc = ibv_post_send(comm->mcast.groups[mcast_group_index].qp, &swr[0], &bad_wr))) { tl_error(comm->lib, "post send failed: ret %d, start_psn %d, to_send %d, " "to_recv %d, length %d, psn %d, inline %d", rc, req->start_psn, req->to_send, req->to_recv, @@ -398,7 +398,7 @@ static inline int 
ucc_tl_mlx5_mcast_recv_collective(ucc_tl_mlx5_mcast_coll_comm_ while (num_left > recv_progressed) { memset(wc, 0, sizeof(sizeof(struct ibv_wc) * POLL_PACKED)); - num_comp = ibv_poll_cq(comm->rcq, POLL_PACKED, &wc[0]); + num_comp = ibv_poll_cq(comm->mcast.rcq, POLL_PACKED, &wc[0]); if (num_comp < 0) { tl_error(comm->lib, "recv queue poll completion failed %d", num_comp); @@ -460,10 +460,9 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_poll_recv(ucc_tl_mlx5_mcast_coll_co uint32_t psn; do { - num_comp = ibv_poll_cq(comm->rcq, 1, &wc); + num_comp = ibv_poll_cq(comm->mcast.rcq, 1, &wc); if (num_comp > 0) { - if (IBV_WC_SUCCESS != wc.status) { tl_error(comm->lib, "mcast_poll_recv: %s err %d num_comp", ibv_wc_status_str(wc.status), num_comp); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c index a3d16fd6d8..3db2d1a8f7 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_one_sided_reliability.c @@ -119,11 +119,11 @@ static inline ucc_status_t ucc_tl_mlx5_mcast_one_sided_cleanup(ucc_tl_mlx5_mcast comm->mcast.rc_qp = NULL; } - if (comm->srq != NULL && ibv_destroy_srq(comm->srq)) { + if (comm->mcast.srq != NULL && ibv_destroy_srq(comm->mcast.srq)) { tl_error(comm->lib, "ibv_destroy_srq failed"); return UCC_ERR_NO_RESOURCE; } - comm->srq = NULL; + comm->mcast.srq = NULL; if (comm->one_sided.slots_mr) { ibv_dereg_mr(comm->one_sided.slots_mr); diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c index 3620cf629f..8031af6dc0 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_progress.c @@ -391,9 +391,10 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com ucc_tl_mlx5_mcast_coll_req_t *req, struct pp_packet* pp) { - ucc_status_t status = UCC_OK; - void *dest; - ucc_memory_type_t mem_type; + ucc_status_t status = UCC_OK; + void *dest; + ucc_ee_executor_task_args_t eargs; + ucc_ee_executor_t *exec; ucc_assert(pp->psn >= req->start_psn && pp->psn < req->start_psn + req->num_packets); @@ -402,19 +403,30 @@ ucc_status_t ucc_tl_mlx5_mcast_process_packet(ucc_tl_mlx5_mcast_coll_comm_t *com if (pp->length > 0 ) { dest = req->ptr + PSN_TO_RECV_OFFSET(pp->psn, req, comm); - - if (comm->cuda_mem_enabled) { - mem_type = UCC_MEMORY_TYPE_CUDA; - } else { - mem_type = UCC_MEMORY_TYPE_HOST; + while (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to complete the nb memcpy", req->exec_task, comm->lib); } - status = ucc_mc_memcpy(dest, (void*) pp->buf, pp->length, - mem_type, mem_type); + /* for cuda copy, exec is nonblocking but for host copy it is blocking */ + status = ucc_coll_task_get_executor(req->coll_task, &exec); if (ucc_unlikely(status != UCC_OK)) { - tl_error(comm->lib, "failed to copy buffer"); return status; } + + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.src = (void*) pp->buf; + eargs.copy.dst = dest; + eargs.copy.len = pp->length; + + assert(req->exec_task == NULL); + status = ucc_ee_executor_task_post(exec, &eargs, &req->exec_task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + + if (req->exec_task != NULL) { + EXEC_TASK_TEST("failed to progress the memcpy", req->exec_task, comm->lib); + } } comm->r_window[pp->psn & (comm->bcast_comm.wsize-1)] = pp; diff --git a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c 
b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c index b70ca6e2f6..84efb5daf1 100644 --- a/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c +++ b/src/components/tl/mlx5/mcast/tl_mlx5_mcast_team.c @@ -114,8 +114,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, goto cleanup; } - comm->rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0); - if (!comm->rcq) { + comm->mcast.rcq = ibv_create_cq(mcast_context->ctx, comm->params.rx_depth, NULL, NULL, 0); + if (!comm->mcast.rcq) { ibv_dereg_mr(comm->grh_mr); tl_error(mcast_context->lib, "could not create recv cq, rx_depth %d, errno %d", comm->params.rx_depth, errno); @@ -123,10 +123,10 @@ ucc_status_t ucc_tl_mlx5_mcast_team_init(ucc_base_context_t *base_context, goto cleanup; } - comm->scq = ibv_create_cq(mcast_context->ctx, comm->params.sx_depth, NULL, NULL, 0); - if (!comm->scq) { + comm->mcast.scq = ibv_create_cq(mcast_context->ctx, comm->params.sx_depth, NULL, NULL, 0); + if (!comm->mcast.scq) { ibv_dereg_mr(comm->grh_mr); - ibv_destroy_cq(comm->rcq); + ibv_destroy_cq(comm->mcast.rcq); tl_error(mcast_context->lib, "could not create send cq, sx_depth %d, errno %d", comm->params.sx_depth, errno); status = UCC_ERR_NO_RESOURCE; @@ -263,7 +263,7 @@ ucc_status_t ucc_tl_mlx5_mcast_coll_setup_comm_resources(ucc_tl_mlx5_mcast_coll_ ucc_list_add_tail(&comm->bpool, &comm->pp[i].super); } - comm->mcast.swr.wr.ud.ah = comm->mcast.ah; + comm->mcast.swr.wr.ud.ah = comm->mcast.groups[0].ah; comm->mcast.swr.num_sge = 1; comm->mcast.swr.sg_list = &comm->mcast.ssg; comm->mcast.swr.opcode = IBV_WR_SEND_WITH_IMM; @@ -325,8 +325,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return UCC_INPROGRESS; } - comm->mcast_addr = net_addr; - tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + comm->mcast.groups[0].mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; return UCC_INPROGRESS; } @@ -373,11 +373,11 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) if (tl_team->mcast_state == TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_READY) { /* rank 0 bcast the lid/gid to other processes */ - data->status = UCC_OK; - data->dgid = comm->event->param.ud.ah_attr.grh.dgid; - data->dlid = comm->event->param.ud.ah_attr.dlid; - comm->mcast_lid = data->dlid; - comm->mgid = data->dgid; + data->status = UCC_OK; + data->dgid = comm->event->param.ud.ah_attr.grh.dgid; + data->dlid = comm->event->param.ud.ah_attr.dlid; + comm->mcast.groups[0].lid = data->dlid; + comm->mcast.groups[0].mgid = data->dgid; } else { /* rank 0 bcast the failed status to other processes so others do not hang */ data->status = UCC_ERR_NO_RESOURCE; @@ -522,8 +522,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) return status; } - comm->mcast_addr = net_addr; - tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; + comm->mcast.groups[0].mcast_addr = net_addr; + tl_team->mcast_state = TL_MLX5_TEAM_STATE_MCAST_GRP_JOIN_POST; return UCC_INPROGRESS; } @@ -549,8 +549,8 @@ ucc_status_t ucc_tl_mlx5_mcast_team_test(ucc_base_team_t *team) ucc_assert(comm->event != NULL); - comm->mcast_lid = comm->group_setup_info->dlid; - comm->mgid = comm->group_setup_info->dgid; + comm->mcast.groups[0].lid = comm->group_setup_info->dlid; + comm->mcast.groups[0].mgid = comm->group_setup_info->dgid; ucc_free(comm->group_setup_info); if (comm->event) { diff --git a/src/components/tl/mlx5/tl_mlx5_coll.c b/src/components/tl/mlx5/tl_mlx5_coll.c index 94d336ba6e..aabdbf8010 100644 --- 
a/src/components/tl/mlx5/tl_mlx5_coll.c +++ b/src/components/tl/mlx5/tl_mlx5_coll.c @@ -14,8 +14,8 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, ucc_coll_task_t **task_h) { - ucc_status_t status = UCC_OK; - ucc_tl_mlx5_task_t *task = NULL; + ucc_status_t status = UCC_OK; + ucc_tl_mlx5_task_t *task = NULL; status = ucc_tl_mlx5_mcast_check_support(coll_args, team); if (UCC_OK != status) { @@ -35,12 +35,14 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; case UCC_COLL_TYPE_ALLGATHER: status = ucc_tl_mlx5_mcast_allgather_init(task); if (ucc_unlikely(UCC_OK != status)) { goto free_task; } + *task_h = &(task->super); break; default: status = UCC_ERR_NOT_SUPPORTED; @@ -48,8 +50,6 @@ ucc_status_t ucc_tl_mlx5_coll_mcast_init(ucc_base_coll_args_t *coll_args, goto free_task; } - *task_h = &(task->super); - tl_debug(UCC_TASK_LIB(task), "initialized mcast collective task %p", task); return UCC_OK; diff --git a/src/components/tl/mlx5/tl_mlx5_team.c b/src/components/tl/mlx5/tl_mlx5_team.c index 1e5f6ddf56..e5cc29490a 100644 --- a/src/components/tl/mlx5/tl_mlx5_team.c +++ b/src/components/tl/mlx5/tl_mlx5_team.c @@ -117,7 +117,7 @@ ucc_status_t ucc_tl_mlx5_team_destroy(ucc_base_team_t *tl_team) return UCC_OK; } -static inline ucc_status_t ucc_tl_mlx5_a2a_team_test(ucc_base_team_t *team) +static inline ucc_status_t ucc_tl_mlx5_alltoall_team_test(ucc_base_team_t *team) { ucc_tl_mlx5_team_t *tl_team = ucc_derived_of(team, ucc_tl_mlx5_team_t); @@ -198,7 +198,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) tl_warn(UCC_TL_TEAM_LIB(tl_team), "ibv_dereg_mr failed"); } - if (ibv_destroy_cq(comm->rcq)) { + if (ibv_destroy_cq(comm->mcast.rcq)) { tl_warn(UCC_TL_TEAM_LIB(tl_team), "ibv_destroy_cq failed"); } @@ -253,7 +253,7 @@ ucc_status_t ucc_tl_mlx5_team_create_test(ucc_base_team_t *team) goto initial_sync_post; } - a2a_status = ucc_tl_mlx5_a2a_team_test(team); + a2a_status = ucc_tl_mlx5_alltoall_team_test(team); if (a2a_status < 0) { tl_warn(team->context->lib, "ALLTOALL tl team: %p creation failed %d", team, a2a_status); diff --git a/src/components/tl/mlx5/tl_mlx5_wqe.c b/src/components/tl/mlx5/tl_mlx5_wqe.c index cf4d590658..d0ece52902 100644 --- a/src/components/tl/mlx5/tl_mlx5_wqe.c +++ b/src/components/tl/mlx5/tl_mlx5_wqe.c @@ -153,14 +153,14 @@ ucc_status_t ucc_tl_mlx5_post_umr(struct ibv_qp * qp, sizeof(struct mlx5_wqe_mkey_context_seg) + sizeof(struct mlx5_wqe_umr_pointer_seg)) / DS_SIZE; - uint8_t fm_ce_se = - MLX5_WQE_CTRL_INITIATOR_SMALL_FENCE | MLX5_WQE_CTRL_CQ_UPDATE; - struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); - struct mlx5dv_qp_ex * mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - struct mlx5_wqe_ctrl_seg * ctrl; - struct mlx5_wqe_umr_ctrl_seg * umr_ctrl_seg; + uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; + struct ibv_qp_ex *qp_ex = ibv_qp_to_qp_ex(qp); + struct mlx5dv_qp_ex *mqp = + mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); + struct mlx5_wqe_ctrl_seg *ctrl; + struct mlx5_wqe_umr_ctrl_seg *umr_ctrl_seg; struct mlx5_wqe_mkey_context_seg *mk_seg; - struct mlx5_wqe_umr_pointer_seg * pseg; + struct mlx5_wqe_umr_pointer_seg *pseg; char wqe_desc[n_ds * DS_SIZE]; int xlat_size; @@ -270,12 +270,12 @@ ucc_status_t ucc_tl_mlx5_post_wait_on_data(struct ibv_qp *qp, uint64_t value, void *task_ptr) { - uint32_t opcode = MLX5_OPCODE_WAIT; - uint32_t opmode = 0x1; //wait on data - uint32_t n_ds = 3; //CTRL + Wait on Data of Size 2 
- struct ibv_qp_ex * qp_ex = ibv_qp_to_qp_ex(qp); - struct mlx5dv_qp_ex *mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); - uint8_t fm_ce_se = MLX5_WQE_CTRL_FENCE | MLX5_WQE_CTRL_CQ_UPDATE; + uint32_t opcode = MLX5_OPCODE_WAIT; + uint32_t opmode = 0x1; //wait on data + uint32_t n_ds = 3; //CTRL + Wait on Data of Size 2 + struct ibv_qp_ex *qp_ex = ibv_qp_to_qp_ex(qp); + struct mlx5dv_qp_ex *mqp = mlx5dv_qp_ex_from_ibv_qp_ex(qp_ex); + uint8_t fm_ce_se = MLX5_WQE_CTRL_CQ_UPDATE; char wqe_desc[n_ds * DS_SIZE]; struct mlx5_wqe_ctrl_seg *ctrl; wait_on_data_seg_t * wseg; diff --git a/src/components/tl/ucp/allgather/allgather.c b/src/components/tl/ucp/allgather/allgather.c index 769c4fb981..e9acb0a648 100644 --- a/src/components/tl/ucp/allgather/allgather.c +++ b/src/components/tl/ucp/allgather/allgather.c @@ -58,3 +58,37 @@ char *ucc_tl_ucp_allgather_score_str_get(ucc_tl_ucp_team_t *team) UCC_TL_UCP_ALLGATHER_DEFAULT_ALG_SELECT_STR, algo_num); return str; } + +ucc_status_t loopback_self_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task) +{ + ucc_status_t status; + status = ucc_tl_ucp_send_nb(sbuf, data_size, smem, rank, team, task); + if (UCC_OK != status) { + task->super.status = status; + return task->super.status; + } + status = ucc_tl_ucp_recv_nb(rbuf, data_size, rmem, rank, team, task); + if (UCC_OK != status) { + task->super.status = status; + return task->super.status; + } + return UCC_OK; +} +ucc_status_t allgather_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task) +{ + ucc_status_t status; + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; + if (use_loopback) { + status = loopback_self_copy(rbuf, sbuf, data_size, rmem, smem, rank, + team, task); + } else { + status = ucc_mc_memcpy(rbuf, sbuf, data_size, rmem, smem); + } + return status; +} diff --git a/src/components/tl/ucp/allgather/allgather.h b/src/components/tl/ucp/allgather/allgather.h index 61733a4ab7..83f9ab7e26 100644 --- a/src/components/tl/ucp/allgather/allgather.h +++ b/src/components/tl/ucp/allgather/allgather.h @@ -7,6 +7,7 @@ #define ALLGATHER_H_ #include "../tl_ucp.h" #include "../tl_ucp_coll.h" +#include "tl_ucp_sendrecv.h" enum { UCC_TL_UCP_ALLGATHER_ALG_KNOMIAL, @@ -38,6 +39,16 @@ static inline int ucc_tl_ucp_allgather_alg_from_str(const char *str) ucc_status_t ucc_tl_ucp_allgather_init(ucc_tl_ucp_task_t *task); +ucc_status_t loopback_self_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task); + +ucc_status_t allgather_copy(void *rbuf, void *sbuf, size_t data_size, + ucc_memory_type_t rmem, ucc_memory_type_t smem, + ucc_rank_t rank, ucc_tl_ucp_team_t *team, + ucc_tl_ucp_task_t *task); + /* Ring */ ucc_status_t ucc_tl_ucp_allgather_ring_init(ucc_base_coll_args_t *coll_args, ucc_base_team_t *team, diff --git a/src/components/tl/ucp/allgather/allgather_knomial.c b/src/components/tl/ucp/allgather/allgather_knomial.c index 1fbcf773cc..58c2054902 100644 --- a/src/components/tl/ucp/allgather/allgather_knomial.c +++ b/src/components/tl/ucp/allgather/allgather_knomial.c @@ -13,6 +13,7 @@ #include "coll_patterns/sra_knomial.h" #include "utils/ucc_math.h" #include "utils/ucc_coll_utils.h" +#include "allgather.h" #define SAVE_STATE(_phase) \ do { \ @@ -54,8 +55,7 @@ void 
ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, - ucc_tl_ucp_task_t); + ucc_tl_ucp_task_t * task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); ucc_coll_args_t *args = &TASK_ARGS(task); ucc_tl_ucp_team_t *team = TASK_TEAM(task); ucc_kn_radix_t radix = task->allgather_kn.p.radix; @@ -66,10 +66,10 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) size_t dt_size = ucc_dt_size(GET_DT(args)); ucc_rank_t size = task->subset.map.ep_num; size_t data_size = GET_TOTAL_COUNT(args, size); - ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ? - args->root : 0; - ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); - size_t local = GET_LOCAL_COUNT(args, size, rank); + ucc_rank_t broot = args->coll_type == UCC_COLL_TYPE_BCAST ? args->root : 0; + ucc_rank_t rank = VRANK(task->subset.myrank, broot, size); + size_t local = GET_LOCAL_COUNT(args, size, rank); + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; void *sbuf; ptrdiff_t peer_seg_offset, local_seg_offset; ucc_rank_t peer, peer_dist; @@ -78,8 +78,14 @@ void ucc_tl_ucp_allgather_knomial_progress(ucc_coll_task_t *coll_task) ucc_status_t status; size_t extra_count; - EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", - task->allgather_kn.etask); + if (use_loopback) { + if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { + return; + } + } else { + EXEC_TASK_TEST(UCC_KN_PHASE_INIT, "failed during ee task test", + task->allgather_kn.etask); + } task->allgather_kn.etask = NULL; UCC_KN_GOTO_PHASE(task->allgather_kn.phase); if (KN_NODE_EXTRA == node_type) { @@ -209,6 +215,7 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) ct == UCC_COLL_TYPE_BCAST ? 
args->root : 0, size); ucc_ee_executor_task_args_t eargs = {0}; + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; ucc_status_t status; ptrdiff_t offset; ucc_ee_executor_t *exec; @@ -225,21 +232,34 @@ ucc_status_t ucc_tl_ucp_allgather_knomial_start(ucc_coll_task_t *coll_task) ucc_dt_size(args->dst.info.datatype); rbuf = args->dst.info.buffer; if (!UCC_IS_INPLACE(*args)) { - status = ucc_coll_task_get_executor(&task->super, &exec); - if (ucc_unlikely(status != UCC_OK)) { - task->super.status = status; - return status; - } - eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; - eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset); - eargs.copy.src = args->src.info.buffer; - eargs.copy.len = args->src.info.count * - ucc_dt_size(args->src.info.datatype); - status = ucc_ee_executor_task_post(exec, &eargs, - &task->allgather_kn.etask); - if (ucc_unlikely(status != UCC_OK)) { - task->super.status = status; - return status; + if (use_loopback) { + status = loopback_self_copy( + PTR_OFFSET(args->dst.info.buffer, offset), + args->src.info.buffer, + args->src.info.count * ucc_dt_size(args->src.info.datatype), + args->dst.info.mem_type, args->src.info.mem_type, rank, + team, task); + if (ucc_unlikely(status != UCC_OK)) { + return status; + } + } else { + /* Executer */ + status = ucc_coll_task_get_executor(&task->super, &exec); + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } + eargs.task_type = UCC_EE_EXECUTOR_TASK_COPY; + eargs.copy.dst = PTR_OFFSET(args->dst.info.buffer, offset); + eargs.copy.src = args->src.info.buffer; + eargs.copy.len = + args->src.info.count * ucc_dt_size(args->src.info.datatype); + status = ucc_ee_executor_task_post(exec, &eargs, + &task->allgather_kn.etask); + if (ucc_unlikely(status != UCC_OK)) { + task->super.status = status; + return status; + } } } } else if (ct == UCC_COLL_TYPE_ALLGATHERV) { diff --git a/src/components/tl/ucp/allgather/allgather_neighbor.c b/src/components/tl/ucp/allgather/allgather_neighbor.c index 534c197e4e..6520344425 100644 --- a/src/components/tl/ucp/allgather/allgather_neighbor.c +++ b/src/components/tl/ucp/allgather/allgather_neighbor.c @@ -81,9 +81,11 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task) ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; size_t count = TASK_ARGS(task).dst.info.count; size_t data_size = (count / tsize) * ucc_dt_size(dt); + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; ucc_rank_t neighbors[2], i; int i_parity, even_rank; void *tmprecv, *tmpsend; + int counter; if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; @@ -98,8 +100,13 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task) neighbors[1] = (trank + 1) % tsize; } - while (task->tagged.send_posted < (tsize / 2)) { - i = task->tagged.send_posted; + if ((!UCC_IS_INPLACE(TASK_ARGS(task))) && use_loopback) { + counter = task->tagged.send_posted - 1; + } else { + counter = task->tagged.send_posted; + } + while (counter < (tsize / 2)) { + i = counter; i_parity = i % 2; tmprecv = @@ -118,6 +125,11 @@ void ucc_tl_ucp_allgather_neighbor_progress(ucc_coll_task_t *coll_task) if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } + if ((!UCC_IS_INPLACE(TASK_ARGS(task))) && use_loopback) { + counter = task->tagged.send_posted - 1; + } else { + counter = task->tagged.send_posted; + } } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); @@ -150,13 +162,15 @@ ucc_status_t ucc_tl_ucp_allgather_neighbor_start(ucc_coll_task_t *coll_task) 
ucc_tl_ucp_task_reset(task, UCC_INPROGRESS); if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf, - data_size, rmem, smem); + status = allgather_copy(PTR_OFFSET(rbuf, data_size * trank), sbuf, + data_size, rmem, smem, trank, team, task); if (ucc_unlikely(UCC_OK != status)) { return status; } } + while ((!UCC_IS_INPLACE(TASK_ARGS(task))) && (UCC_INPROGRESS == ucc_tl_ucp_test(task))) { + } if (trank % 2) { neighbor = (trank - 1 + tsize) % tsize; } else { diff --git a/src/components/tl/ucp/allgather/allgather_ring.c b/src/components/tl/ucp/allgather/allgather_ring.c index 07178aea25..46a42663b6 100644 --- a/src/components/tl/ucp/allgather/allgather_ring.c +++ b/src/components/tl/ucp/allgather/allgather_ring.c @@ -31,15 +31,16 @@ static ucc_rank_t ucc_tl_ucp_allgather_ring_get_recv_block(ucc_subset_t *subset, void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) { - ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); - ucc_tl_ucp_team_t *team = TASK_TEAM(task); - ucc_rank_t trank = task->subset.myrank; - ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; - void *rbuf = TASK_ARGS(task).dst.info.buffer; - ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; - size_t count = TASK_ARGS(task).dst.info.count; - ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; - size_t data_size = (count / tsize) * ucc_dt_size(dt); + ucc_tl_ucp_task_t *task = ucc_derived_of(coll_task, ucc_tl_ucp_task_t); + ucc_tl_ucp_team_t *team = TASK_TEAM(task); + ucc_rank_t trank = task->subset.myrank; + ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + void * rbuf = TASK_ARGS(task).dst.info.buffer; + ucc_memory_type_t rmem = TASK_ARGS(task).dst.info.mem_type; + size_t count = TASK_ARGS(task).dst.info.count; + ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; + size_t data_size = (count / tsize) * ucc_dt_size(dt); + int use_loopback = UCC_TL_UCP_TEAM_LIB(team)->cfg.allgather_use_loopback; ucc_rank_t sendto, recvfrom, sblock, rblock; int step; void *buf; @@ -49,9 +50,13 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) } sendto = ucc_ep_map_eval(task->subset.map, (trank + 1) % tsize); recvfrom = ucc_ep_map_eval(task->subset.map, (trank - 1 + tsize) % tsize); - - while (task->tagged.send_posted < tsize - 1) { + if (use_loopback && (!UCC_IS_INPLACE(TASK_ARGS(task)))) { + step = task->tagged.send_posted - 1; + } else { step = task->tagged.send_posted; + } + + while (step < tsize - 1) { sblock = task->allgather_ring.get_send_block(&task->subset, trank, tsize, step); rblock = task->allgather_ring.get_recv_block(&task->subset, trank, @@ -67,6 +72,11 @@ void ucc_tl_ucp_allgather_ring_progress(ucc_coll_task_t *coll_task) if (UCC_INPROGRESS == ucc_tl_ucp_test(task)) { return; } + if (use_loopback && (!UCC_IS_INPLACE(TASK_ARGS(task)))) { + step = task->tagged.send_posted - 1; + } else { + step = task->tagged.send_posted; + } } ucc_assert(UCC_TL_UCP_TASK_P2P_COMPLETE(task)); task->super.status = UCC_OK; @@ -86,6 +96,7 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) ucc_datatype_t dt = TASK_ARGS(task).dst.info.datatype; ucc_rank_t trank = task->subset.myrank; ucc_rank_t tsize = (ucc_rank_t)task->subset.map.ep_num; + ucc_rank_t rank = ucc_ep_map_eval(task->subset.map, trank); size_t data_size = (count / tsize) * ucc_dt_size(dt); ucc_status_t status; ucc_rank_t block; @@ -96,13 +107,12 @@ ucc_status_t ucc_tl_ucp_allgather_ring_start(ucc_coll_task_t *coll_task) if 
(!UCC_IS_INPLACE(TASK_ARGS(task))) { block = task->allgather_ring.get_send_block(&task->subset, trank, tsize, 0); - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * block), - sbuf, data_size, rmem, smem); + status = allgather_copy(PTR_OFFSET(rbuf, data_size * block), sbuf, + data_size, rmem, smem, rank, team, task); if (ucc_unlikely(UCC_OK != status)) { return status; } } - return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); } diff --git a/src/components/tl/ucp/allgather/allgather_sparbit.c b/src/components/tl/ucp/allgather/allgather_sparbit.c index 0edfc4d4a3..f33292e2ef 100644 --- a/src/components/tl/ucp/allgather/allgather_sparbit.c +++ b/src/components/tl/ucp/allgather/allgather_sparbit.c @@ -131,8 +131,8 @@ ucc_status_t ucc_tl_ucp_allgather_sparbit_start(ucc_coll_task_t *coll_task) task->allgather_sparbit.data_expected = 1; if (!UCC_IS_INPLACE(TASK_ARGS(task))) { - status = ucc_mc_memcpy(PTR_OFFSET(rbuf, data_size * trank), sbuf, - data_size, rmem, smem); + status = allgather_copy(PTR_OFFSET(rbuf, data_size * trank), sbuf, + data_size, rmem, smem, trank, team, task); if (ucc_unlikely(UCC_OK != status)) { return status; } diff --git a/src/components/tl/ucp/tl_ucp.c b/src/components/tl/ucp/tl_ucp.c index 7db99bdaf2..a0e49de3c9 100644 --- a/src/components/tl/ucp/tl_ucp.c +++ b/src/components/tl/ucp/tl_ucp.c @@ -48,7 +48,7 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoallv_pairwise_num_posts), UCC_CONFIG_TYPE_ULUNITS}, -/* TODO: add radix to config once it's fully supported by the algorithm + /* TODO: add radix to config once it's fully supported by the algorithm {"ALLTOALLV_HYBRID_RADIX", "2", "Radix of the Hybrid Alltoallv algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, alltoallv_hybrid_radix), @@ -140,6 +140,12 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_kn_radix), UCC_CONFIG_TYPE_UINT}, + {"ALLGATHER_USE_LOOPBACK", "0", + "If set to 1 performs network loopback for self copy, otherwise uses mc " + "cuda copy", + ucc_offsetof(ucc_tl_ucp_lib_config_t, allgather_use_loopback), + UCC_CONFIG_TYPE_BOOL}, + {"BCAST_KN_RADIX", "4", "Radix of the recursive-knomial bcast algorithm", ucc_offsetof(ucc_tl_ucp_lib_config_t, bcast_kn_radix), UCC_CONFIG_TYPE_UINT}, @@ -196,10 +202,8 @@ ucc_config_field_t ucc_tl_ucp_lib_config_table[] = { ucc_offsetof(ucc_tl_ucp_lib_config_t, reduce_scatterv_ring_bidirectional), UCC_CONFIG_TYPE_BOOL}, - {"USE_TOPO", "try", - "Allow usage of tl ucp topo", - ucc_offsetof(ucc_tl_ucp_lib_config_t, use_topo), - UCC_CONFIG_TYPE_TERNARY}, + {"USE_TOPO", "try", "Allow usage of tl ucp topo", + ucc_offsetof(ucc_tl_ucp_lib_config_t, use_topo), UCC_CONFIG_TYPE_TERNARY}, {"RANKS_REORDERING", "y", "Use topology information in TL UCP to reorder ranks. 
Requires topo info", diff --git a/src/components/tl/ucp/tl_ucp.h b/src/components/tl/ucp/tl_ucp.h index 3c439f4ae5..6d31c5aead 100644 --- a/src/components/tl/ucp/tl_ucp.h +++ b/src/components/tl/ucp/tl_ucp.h @@ -55,6 +55,7 @@ typedef struct ucc_tl_ucp_lib_config { ucc_mrange_uint_t allreduce_sra_kn_radix; uint32_t reduce_scatter_kn_radix; uint32_t allgather_kn_radix; + int allgather_use_loopback; uint32_t bcast_kn_radix; ucc_mrange_uint_t bcast_sag_kn_radix; uint32_t reduce_kn_radix; diff --git a/src/core/ucc_global_opts.h b/src/core/ucc_global_opts.h index 203ca65e9d..85cd9d4835 100644 --- a/src/core/ucc_global_opts.h +++ b/src/core/ucc_global_opts.h @@ -28,7 +28,7 @@ typedef struct ucc_global_config { char *install_path; int initialized; /* Profiling mode */ - unsigned profile_mode; + uint64_t profile_mode; /* Profiling output file name */ char *profile_file; diff --git a/test/gtest/coll/test_allgather.cc b/test/gtest/coll/test_allgather.cc index c48bb8303d..def6b64e03 100644 --- a/test/gtest/coll/test_allgather.cc +++ b/test/gtest/coll/test_allgather.cc @@ -9,7 +9,8 @@ using Param_0 = std::tuple; using Param_1 = std::tuple; -using Param_2 = std::tuple; +using Param_2 = std::tuple; class test_allgather : public UccCollArgs, public ucc::test { @@ -265,10 +266,12 @@ UCC_TEST_P(test_allgather_alg, alg) const gtest_ucc_inplace_t inplace = std::get<3>(GetParam()); int n_procs = 5; char tune[32]; + std::string use_loopback = std::get<5>(GetParam()); sprintf(tune, "allgather:@%s:inf", std::get<4>(GetParam()).c_str()); - ucc_job_env_t env = {{"UCC_CL_BASIC_TUNE", "inf"}, - {"UCC_TL_UCP_TUNE", tune}}; + ucc_job_env_t env = {{"UCC_CL_BASIC_TUNE", "inf"}, + {"UCC_TL_UCP_TUNE", tune}, + {"UCC_TL_UCP_ALLGATHER_USE_LOOPBACK", use_loopback}}; UccJob job(n_procs, UccJob::UCC_JOB_CTX_GLOBAL, env); UccTeam_h team = job.create_team(n_procs); UccCollCtxVec ctxs; @@ -294,15 +297,20 @@ INSTANTIATE_TEST_CASE_P( #else ::testing::Values(UCC_MEMORY_TYPE_HOST), #endif - ::testing::Values(1,3,8192), // count + ::testing::Values(1, 3, 8192), // count ::testing::Values(TEST_INPLACE, TEST_NO_INPLACE), - ::testing::Values("knomial", "ring", "neighbor", "bruck", "sparbit")), - [](const testing::TestParamInfo& info) { - std::string name; - name += ucc_datatype_str(std::get<0>(info.param)); - name += std::string("_") + std::string(ucc_mem_type_str(std::get<1>(info.param))); - name += std::string("_count_")+std::to_string(std::get<2>(info.param)); - name += std::string("_inplace_")+std::to_string(std::get<3>(info.param)); - name += std::string("_")+std::get<4>(info.param); - return name; - }); + ::testing::Values("knomial", "ring", "neighbor", "bruck", "sparbit"), + ::testing::Values("1", "0")), + [](const testing::TestParamInfo &info) { + std::string name; + name += ucc_datatype_str(std::get<0>(info.param)); + name += std::string("_") + + std::string(ucc_mem_type_str(std::get<1>(info.param))); + name += + std::string("_count_") + std::to_string(std::get<2>(info.param)); + name += + std::string("_inplace_") + std::to_string(std::get<3>(info.param)); + name += std::string("_") + std::get<4>(info.param); + name += std::string("_use_loopback_") + std::get<5>(info.param); + return name; + }); \ No newline at end of file