diff --git a/src/components/cl/basic/cl_basic_context.c b/src/components/cl/basic/cl_basic_context.c index 3dc3a608ff..66a2ecf0d0 100644 --- a/src/components/cl/basic/cl_basic_context.c +++ b/src/components/cl/basic/cl_basic_context.c @@ -61,20 +61,20 @@ UCC_CLASS_CLEANUP_FUNC(ucc_cl_basic_context_t) ucc_free(self->super.tl_ctxs); } -ucc_status_t ucc_cl_basic_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_cl_basic_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_cl_basic_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_cl_basic_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_cl_basic_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_cl_basic_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **packed_buffer) { return UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/cl/hier/cl_hier_context.c b/src/components/cl/hier/cl_hier_context.c index 44bb0636f5..fc10358b3e 100644 --- a/src/components/cl/hier/cl_hier_context.c +++ b/src/components/cl/hier/cl_hier_context.c @@ -92,20 +92,20 @@ ucc_cl_hier_get_context_attr(const ucc_base_context_t *context, /* NOLINT */ return UCC_OK; } -ucc_status_t ucc_cl_hier_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_cl_hier_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_cl_hier_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_cl_hier_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_cl_hier_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_cl_hier_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **packed_buffer) { return UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/tl/cuda/tl_cuda_context.c b/src/components/tl/cuda/tl_cuda_context.c index e6859a1471..44da168702 100644 --- a/src/components/tl/cuda/tl_cuda_context.c +++ b/src/components/tl/cuda/tl_cuda_context.c @@ -76,19 +76,19 @@ UCC_CLASS_INIT_FUNC(ucc_tl_cuda_context_t, return status; } -ucc_status_t ucc_tl_cuda_mem_map(const ucc_base_context_t *context, +ucc_status_t ucc_tl_cuda_mem_map(const ucc_base_context_t *context, /* NOLINT */ void *address, size_t len, void *memh) { return UCC_ERR_NOT_IMPLEMENTED; } -ucc_status_t ucc_tl_cuda_mem_unmap(const ucc_base_context_t *context, +ucc_status_t ucc_tl_cuda_mem_unmap(const ucc_base_context_t *context, /* NOLINT */ void *memh) { return UCC_ERR_NOT_IMPLEMENTED; } -ucc_status_t ucc_tl_cuda_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_tl_cuda_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **pack_buffer) { return UCC_ERR_NOT_IMPLEMENTED; diff --git a/src/components/tl/mlx5/tl_mlx5_context.c b/src/components/tl/mlx5/tl_mlx5_context.c index 1fdf3f1093..b31bd5198f 100644 --- a/src/components/tl/mlx5/tl_mlx5_context.c +++ b/src/components/tl/mlx5/tl_mlx5_context.c @@ -303,20 +303,20 @@ ucc_status_t ucc_tl_mlx5_context_create_epilog(ucc_base_context_t *context) return ucc_tl_mlx5_context_ib_ctx_pd_setup(context); } -ucc_status_t ucc_tl_mlx5_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_mlx5_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_IMPLEMENTED; } -ucc_status_t ucc_tl_mlx5_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_mlx5_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *memh) { return UCC_ERR_NOT_IMPLEMENTED; } -ucc_status_t ucc_tl_mlx5_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_tl_mlx5_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **pack_buffer) { return UCC_ERR_NOT_IMPLEMENTED; diff --git a/src/components/tl/nccl/tl_nccl_context.c b/src/components/tl/nccl/tl_nccl_context.c index 14396d7e1c..84fdca9f6b 100644 --- a/src/components/tl/nccl/tl_nccl_context.c +++ b/src/components/tl/nccl/tl_nccl_context.c @@ -224,20 +224,20 @@ ucc_tl_nccl_get_context_attr(const ucc_base_context_t *context, /* NOLINT */ return UCC_OK; } -ucc_status_t ucc_tl_nccl_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_nccl_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_nccl_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_nccl_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *memh) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_nccl_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_tl_nccl_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **pack_buffer) { return UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/tl/rccl/tl_rccl_context.c b/src/components/tl/rccl/tl_rccl_context.c index 36c2cf6184..4d52c38756 100644 --- a/src/components/tl/rccl/tl_rccl_context.c +++ b/src/components/tl/rccl/tl_rccl_context.c @@ -121,20 +121,20 @@ ucc_tl_rccl_get_context_attr(const ucc_base_context_t *context, /* NOLINT */ return UCC_OK; } -ucc_status_t ucc_tl_rccl_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_rccl_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_rccl_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_rccl_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *memh) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_rccl_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_tl_rccl_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **pack_buffer) { return UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/tl/self/tl_self_context.c b/src/components/tl/self/tl_self_context.c index 6c11f3a428..d092019dde 100644 --- a/src/components/tl/self/tl_self_context.c +++ b/src/components/tl/self/tl_self_context.c @@ -50,20 +50,20 @@ ucc_tl_self_get_context_attr(const ucc_base_context_t *context, /* NOLINT */ return UCC_OK; } -ucc_status_t ucc_tl_self_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_self_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_self_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_self_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *memh) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_self_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_tl_self_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **pack_buffer) { return UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/tl/sharp/tl_sharp_context.c b/src/components/tl/sharp/tl_sharp_context.c index 8b23cbbe15..add0c617f5 100644 --- a/src/components/tl/sharp/tl_sharp_context.c +++ b/src/components/tl/sharp/tl_sharp_context.c @@ -512,20 +512,20 @@ ucc_status_t ucc_tl_sharp_get_context_attr(const ucc_base_context_t *context, /* return UCC_OK; } -ucc_status_t ucc_tl_sharp_mem_map(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_sharp_mem_map(const ucc_base_context_t *context, int type, /* NOLINT */ void *address, size_t len, void *memh, void *tl_h) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_sharp_mem_unmap(const ucc_base_context_t *context, int type, +ucc_status_t ucc_tl_sharp_mem_unmap(const ucc_base_context_t *context, int type, /* NOLINT */ void *memh) { return UCC_ERR_NOT_SUPPORTED; } -ucc_status_t ucc_tl_sharp_memh_pack(const ucc_base_context_t *context, +ucc_status_t ucc_tl_sharp_memh_pack(const ucc_base_context_t *context, /* NOLINT */ void *memh, void **pack_buffer) { return UCC_ERR_NOT_SUPPORTED; diff --git a/src/components/tl/ucp/alltoall/alltoall_onesided.c b/src/components/tl/ucp/alltoall/alltoall_onesided.c index 3a6666908b..8d5b57b2e5 100644 --- a/src/components/tl/ucp/alltoall/alltoall_onesided.c +++ b/src/components/tl/ucp/alltoall/alltoall_onesided.c @@ -35,18 +35,18 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask) UCPCHECK_GOTO( ucc_tl_ucp_put_nb((void *)(src + start * nelems), (void *)dest, nelems, - start, src_memh[start], dst_memh[start], team, task), + start, *src_memh, dst_memh[start], team, task), task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, src_memh[start], + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, *src_memh, dst_memh[start], team), task, out); for (peer = (start + 1) % gsize; peer != start; peer = (peer + 1) % gsize) { UCPCHECK_GOTO(ucc_tl_ucp_put_nb( (void *)(src + peer * nelems), (void *)dest, nelems, - peer, src_memh[peer], dst_memh[peer], team, task), + peer, *src_memh, dst_memh[peer], team, task), task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, src_memh[peer], + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh, dst_memh[peer], team), task, out); } diff --git a/src/components/tl/ucp/alltoallv/alltoallv_onesided.c b/src/components/tl/ucp/alltoallv/alltoallv_onesided.c index ffbec5ee9a..fa92011c8b 100644 --- a/src/components/tl/ucp/alltoallv/alltoallv_onesided.c +++ b/src/components/tl/ucp/alltoallv/alltoallv_onesided.c @@ -48,11 +48,10 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask) UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp), PTR_OFFSET(dest, dd_disp), - data_size, peer, src_memh[peer], + data_size, peer, *src_memh, dst_memh[peer], team, task), task, out); - UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, src_memh[peer], - dst_memh[peer], team), task, out); + UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh, dst_memh[peer], team), task, out); } return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super); out: diff --git a/src/components/tl/ucp/tl_ucp.h b/src/components/tl/ucp/tl_ucp.h index 83658a62d5..eccf9c5c0e 100644 --- a/src/components/tl/ucp/tl_ucp.h +++ b/src/components/tl/ucp/tl_ucp.h @@ -197,6 +197,24 @@ extern ucc_config_field_t ucc_tl_ucp_lib_config_table[]; #define UCC_TL_UCP_REMOTE_RKEY(_ctx, _rank, _seg) \ ((_ctx)->rkeys[_rank * _ctx->n_rinfo_segs + _seg]) +#define UCC_TL_UCP_MEMH_TL_HEADERS 2 + +#define UCC_TL_UCP_MEMH_TL_PACKED_HEADERS 2 + +#define UCC_TL_UCP_MEMH_TL_HEADER_SIZE \ + (sizeof(size_t) * UCC_TL_UCP_MEMH_TL_HEADERS) + +#define UCC_TL_UCP_MEMH_TL_PACKED_HEADER_SIZE \ + (sizeof(size_t) * \ + (UCC_TL_UCP_MEMH_TL_HEADERS + UCC_TL_UCP_MEMH_TL_PACKED_HEADERS)) + +#define UCC_TL_UCP_MEMH_TL_KEY_SIZE(buffer) \ + (*(size_t *)(PTR_OFFSET(buffer, UCC_TL_UCP_MEMH_TL_HEADER_SIZE))) + +#define UCC_TL_UCP_MEMH_TL_PACKED_MEMH(buffer) \ + (PTR_OFFSET(buffer, UCC_TL_UCP_MEMH_TL_PACKED_HEADER_SIZE + \ + UCC_TL_UCP_MEMH_TL_KEY_SIZE(buffer))) + extern ucs_memory_type_t ucc_memtype_to_ucs[UCC_MEMORY_TYPE_LAST+1]; void ucc_tl_ucp_pre_register_mem(ucc_tl_ucp_team_t *team, void *addr, diff --git a/src/components/tl/ucp/tl_ucp_context.c b/src/components/tl/ucp/tl_ucp_context.c index 74dc39b74c..08c994ce57 100644 --- a/src/components/tl/ucp/tl_ucp_context.c +++ b/src/components/tl/ucp/tl_ucp_context.c @@ -166,10 +166,8 @@ UCC_CLASS_INIT_FUNC(ucc_tl_ucp_context_t, ucp_params.field_mask = UCP_PARAM_FIELD_FEATURES | UCP_PARAM_FIELD_TAG_SENDER_MASK | UCP_PARAM_FIELD_NAME; - ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_AM; - if (params->params.mask & UCC_CONTEXT_PARAM_FIELD_MEM_PARAMS) { - ucp_params.features |= UCP_FEATURE_RMA | UCP_FEATURE_AMO64; - } + ucp_params.features = UCP_FEATURE_TAG | UCP_FEATURE_AM | UCP_FEATURE_EXPORTED_MEMH | + UCP_FEATURE_RMA | UCP_FEATURE_AMO64; ucp_params.tag_sender_mask = UCC_TL_UCP_TAG_SENDER_MASK; ucp_params.name = "UCC_UCP_CONTEXT"; @@ -560,11 +558,13 @@ ucc_status_t ucc_tl_ucp_mem_map_memhbuf(ucc_tl_ucp_context_t *ctx, { ucp_mem_map_params_t mmap_params; ucs_status_t status; + void *packed_memh; *mh = NULL; /* unpack here */ - size_t *key_size = (size_t *)pack_buffer; - void *packed_memh = PTR_OFFSET(pack_buffer, sizeof(size_t) * 2 + *key_size); + packed_memh = UCC_TL_UCP_MEMH_TL_PACKED_MEMH(pack_buffer); +// size_t *key_size = (size_t *)PTR_OFFSET(pack_buffer, sizeof(size_t) * 2); +// void *packed_memh = PTR_OFFSET(pack_buffer, sizeof(size_t) * 4 + *key_size); mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER; mmap_params.exported_memh_buffer = packed_memh; @@ -577,37 +577,24 @@ ucc_status_t ucc_tl_ucp_mem_map_memhbuf(ucc_tl_ucp_context_t *ctx, ucs_status_string(status)); return ucs_status_to_ucc_status(status); } else { - tl_debug(ctx->super.super.lib, - "ucp_mem_map could not map exported memory handle"); + tl_warn(ctx->super.super.lib, + "ucp_mem_map cannot map exported memory handles"); } } return UCC_OK; } -ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type, - void *address, size_t len, void *memh, - void *tl_h) +ucc_status_t ucc_tl_ucp_mem_map_export(ucc_tl_ucp_context_t *ctx, void *address, + size_t len, + int type, + ucc_tl_ucp_memh_data_t *m_data) { - ucc_tl_ucp_context_t *ctx = ucc_derived_of(context, ucc_tl_ucp_context_t); - ucc_status_t ucc_status = UCC_OK; - ucc_mem_map_tl_t *p_memh = (ucc_mem_map_tl_t *)tl_h; - ucc_tl_ucp_memh_data_t *m_data = (ucc_tl_ucp_memh_data_t *)p_memh->tl_data; - ucp_mem_h mh = NULL; - ucc_mem_map_memh_t *l_memh = (ucc_mem_map_memh_t *)memh; - size_t offset = 0; - ucp_mem_map_params_t mmap_params; - ucs_status_t status; - ucp_memh_pack_params_t pack_params; + ucc_status_t ucc_status = UCC_OK; + ucp_mem_h mh; + ucp_mem_map_params_t mmap_params; + ucp_memh_pack_params_t pack_params; + ucs_status_t status; - if (type == UCC_MEM_MAP_TYPE_GLOBAL || !m_data) { - /* either we are importing or m_data is null */ - m_data = ucc_calloc(1, sizeof(ucc_tl_ucp_memh_data_t), "tl data"); - if (!m_data) { - tl_error(ctx->super.super.lib, "failed to allocate tl data"); - return UCC_ERR_NO_MEMORY; - } - p_memh->tl_data = m_data; - } if (type == UCC_MEM_MAP_TYPE_LOCAL) { mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH; @@ -616,35 +603,18 @@ ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type, status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh); if (UCS_OK != status) { - tl_error(ctx->super.super.lib, - "ucp_mem_map failed with error code: %d", status); + tl_error(ctx->super.super.lib, "ucp_mem_map failed with error code: %d", + status); ucc_status = ucs_status_to_ucc_status(status); - } - } else { - for (int i = 0; i < l_memh->num_tls; i++) { - size_t *p = (size_t *)PTR_OFFSET(l_memh->pack_buffer, offset); - - if (tl_h == (void *)&l_memh->tl_h[i]) { - break; - } - /* this is not the index, skip this section of buffer if exists */ - if (p[0] == i) { - offset += p[1]; - } - } - - ucc_status = ucc_tl_ucp_mem_map_memhbuf( - ctx, PTR_OFFSET(l_memh->pack_buffer, offset), &mh); - if (ucc_status != UCC_OK) { - tl_error(ctx->super.super.lib, "ucp_mem_map failed to map memh"); return ucc_status; } - } - m_data->rinfo.mem_h = mh; - m_data->rinfo.va_base = address; - m_data->rinfo.len = len; + m_data->rinfo.mem_h = mh; + m_data->rinfo.va_base = address; + m_data->rinfo.len = len; + + pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS; + pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT; - if (type == UCC_MEM_MAP_TYPE_LOCAL) { status = ucp_memh_pack(mh, &pack_params, &m_data->packed_memh, &m_data->packed_memh_len); if (status != UCS_OK) { @@ -654,18 +624,101 @@ ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type, m_data->packed_memh = 0; m_data->packed_memh_len = 0; } - // pack rkey - status = ucp_rkey_pack(ctx->worker.ucp_context, mh, - &m_data->rinfo.packed_key, - &m_data->rinfo.packed_key_len); - if (status != UCS_OK) { - tl_error(ctx->super.super.lib, "unable to pack rkey with error %s", - ucs_status_string(status)); + } + + if (!m_data->rinfo.mem_h) { + tl_error(ctx->super.super.lib, + "attempting to pack keys with invalid memory handle"); + return UCC_ERR_INVALID_PARAM; + } + + // pack rkey (mem_h should be valid by this point) + status = + ucp_rkey_pack(ctx->worker.ucp_context, m_data->rinfo.mem_h, + &m_data->rinfo.packed_key, &m_data->rinfo.packed_key_len); + if (status != UCS_OK) { + tl_error(ctx->super.super.lib, "unable to pack rkey with error %s", + ucs_status_string(status)); + ucp_mem_unmap(ctx->worker.ucp_context, m_data->rinfo.mem_h); + ucc_status = ucs_status_to_ucc_status(status); + return ucc_status; + } + + return ucc_status; +} + +ucc_status_t ucc_tl_ucp_mem_map_dpu_import(ucc_tl_ucp_context_t *ctx, + void *address, size_t len, + ucc_tl_ucp_memh_data_t *m_data, + ucc_mem_map_memh_t *l_memh, + void *tl_h) +{ + size_t offset = 0; + ucp_mem_h mh; + ucc_status_t ucc_status; + + for (int i = 0; i < l_memh->num_tls; i++) { + size_t *p = (size_t *)PTR_OFFSET(l_memh->pack_buffer, offset); + + if (tl_h == (void *)&l_memh->tl_h[i]) { + break; + } + /* this is not the index, skip this section of buffer if exists */ + if (p[0] == i) { + offset += p[1]; + } + } + + ucc_status = ucc_tl_ucp_mem_map_memhbuf( + ctx, PTR_OFFSET(l_memh->pack_buffer, offset), &mh); + if (ucc_status != UCC_OK) { + tl_error(ctx->super.super.lib, "ucp_mem_map failed to map memh"); + return ucc_status; + } + m_data->rinfo.mem_h = mh; /* used for xgvmi */ + m_data->rinfo.va_base = address; + m_data->rinfo.len = len; + + return UCC_OK; +} + +ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type, + void *address, size_t len, void *memh, + void *tl_h) +{ + ucc_tl_ucp_context_t *ctx = ucc_derived_of(context, ucc_tl_ucp_context_t); + ucc_mem_map_tl_t *p_memh = (ucc_mem_map_tl_t *)tl_h; + ucc_tl_ucp_memh_data_t *m_data = (ucc_tl_ucp_memh_data_t *)p_memh->tl_data; + ucc_mem_map_memh_t *l_memh = (ucc_mem_map_memh_t *)memh; + ucc_status_t ucc_status = UCC_OK; + + /* technically, only need to do this on import */ + if (type == UCC_MEM_MAP_TYPE_GLOBAL || + type == UCC_MEM_MAP_TYPE_DPU_IMPORT || !m_data) { + /* either we are importing or m_data is null */ + m_data = ucc_calloc(1, sizeof(ucc_tl_ucp_memh_data_t), "tl data"); + if (!m_data) { + tl_error(ctx->super.super.lib, "failed to allocate tl data"); + return UCC_ERR_NO_MEMORY; + } + p_memh->tl_data = m_data; + } + + if (type == UCC_MEM_MAP_TYPE_LOCAL || type == UCC_MEM_MAP_TYPE_DPU_EXPORT) { + ucc_status = ucc_tl_ucp_mem_map_export(ctx, address, len, type, m_data); + if (UCC_OK != ucc_status) { + tl_error(ctx->super.super.lib, "failed to export memory handles"); + } + } else if (type == UCC_MEM_MAP_TYPE_DPU_IMPORT) { + ucc_status = ucc_tl_ucp_mem_map_dpu_import(ctx, address, len, m_data, l_memh, tl_h); + if (UCC_OK != ucc_status) { + tl_error(ctx->super.super.lib, "failed to import memory handle"); } - p_memh->packed_size = - m_data->packed_memh_len + m_data->rinfo.packed_key_len; } + p_memh->packed_size = + m_data->packed_memh_len + m_data->rinfo.packed_key_len; + return ucc_status; } diff --git a/src/components/tl/ucp/tl_ucp_sendrecv.h b/src/components/tl/ucp/tl_ucp_sendrecv.h index bcdb43cb3b..a3f30b05f4 100644 --- a/src/components/tl/ucp/tl_ucp_sendrecv.h +++ b/src/components/tl/ucp/tl_ucp_sendrecv.h @@ -239,6 +239,24 @@ static inline ucc_status_t find_tl_index(ucc_mem_map_mem_h map_memh, int *tl_ind return UCC_ERR_NOT_FOUND; } +static inline ucc_status_t ucc_tl_ucp_get_memh(ucc_tl_ucp_team_t *team, ucc_mem_map_mem_h map_memh, void **ucp_memh) +{ + ucc_mem_map_memh_t *memh = map_memh; + ucc_tl_ucp_memh_data_t *tl_data;// = (ucc_tl_ucp_memh_data_t *)memh->tl_h[tl_index].tl_data; + int tl_index = 0; + ucc_status_t status; + + status = find_tl_index(memh, &tl_index); + if (status == UCC_ERR_NOT_FOUND) { + tl_error(UCC_TL_TEAM_LIB(team), + "attempt to perform one-sided operation with malformed mem map handle"); + return status; + } + tl_data = (ucc_tl_ucp_memh_data_t *)memh->tl_h[tl_index].tl_data; + *ucp_memh = tl_data->rinfo.mem_h; + return UCC_OK; +} + static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, void *va, uint64_t *rva, ucp_rkey_h *rkey, int tl_index, ucc_mem_map_mem_h map_memh) { @@ -249,6 +267,7 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, void *va, uint64_ ucs_status_t ucs_status; int i; size_t offset; + uint64_t *key_size; base = (uint64_t)memh->address; end = base + memh->len; @@ -271,6 +290,9 @@ static inline ucc_status_t ucc_tl_ucp_check_memh(ucp_ep_h *ep, void *va, uint64_ if (UCS_OK != ucs_status) { return ucs_status_to_ucc_status(ucs_status); } + /* if they don't have a key, they don't have a memh */ + key_size = (uint64_t *)PTR_OFFSET(memh->pack_buffer, offset + sizeof(size_t) * 2); + tl_data->packed_memh = PTR_OFFSET(memh->pack_buffer, offset + sizeof(size_t) * 4 + *key_size); } *rkey = tl_data->rkey; return UCC_OK; @@ -322,7 +344,6 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep, ucc_status_t status; if (src_memh) { - //status = UCC_OK; status = find_tl_index(src_memh, &tl_index); if (status == UCC_ERR_NOT_FOUND) { tl_error(UCC_TL_TEAM_LIB(team), @@ -332,6 +353,7 @@ ucc_tl_ucp_resolve_p2p_by_va(ucc_tl_ucp_team_t *team, void *va, ucp_ep_h *ep, status = ucc_tl_ucp_check_memh(ep, va, rva, rkey, tl_index, src_memh); if (status == UCC_OK) { + printf("found %d va %p rva %lx\n", peer, va, *rva); return UCC_OK; } } @@ -422,22 +444,33 @@ static inline ucc_status_t ucc_tl_ucp_put_nb(void *buffer, void *target, ucs_status_ptr_t ucp_status; ucc_status_t status; ucp_ep_h ep; + void *ucp_memh = NULL; status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep); if (ucc_unlikely(UCC_OK != status)) { return status; } + status = ucc_tl_ucp_get_memh(team, src_memh, &ucp_memh); + if (ucc_unlikely(UCC_OK != status)) { + return status; + } + status = ucc_tl_ucp_resolve_p2p_by_va(team, target, &ep, dest_group_rank, &rva, &rkey, &segment, src_memh, dest_memh); if (ucc_unlikely(UCC_OK != status)) { return status; } + req_param.op_attr_mask = UCP_OP_ATTR_FIELD_CALLBACK | UCP_OP_ATTR_FIELD_USER_DATA; req_param.cb.send = ucc_tl_ucp_put_completion_cb; req_param.user_data = (void *)task; + if (ucp_memh) { + req_param.op_attr_mask |= UCP_OP_ATTR_FIELD_MEMH; + req_param.memh = ucp_memh; + } ucp_status = ucp_put_nbx(ep, buffer, msglen, rva, rkey, &req_param); @@ -467,6 +500,7 @@ static inline ucc_status_t ucc_tl_ucp_get_nb(void *buffer, void *target, ucs_status_ptr_t ucp_status; ucc_status_t status; ucp_ep_h ep; + //void *packed_memh; status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep); if (ucc_unlikely(UCC_OK != status)) { diff --git a/src/core/ucc_context.c b/src/core/ucc_context.c index d63c144d60..341ad1d60a 100644 --- a/src/core/ucc_context.c +++ b/src/core/ucc_context.c @@ -1133,12 +1133,16 @@ ucc_status_t ucc_context_get_attr(ucc_context_t *context, } ucc_status_t ucc_mem_map_import(ucc_context_h context, + ucc_mem_map_flags_t flags, ucc_mem_map_params_t *params, size_t *memh_size, ucc_mem_map_mem_h *memh) { ucc_context_t *ctx = (ucc_context_t *)context; ucc_status_t status = UCC_OK; ucc_config_names_array_t *tls = &ctx->all_tls; + ucc_mem_map_type_t type = (flags == UCC_MEM_MAP_IMPORT) + ? UCC_MEM_MAP_TYPE_GLOBAL + : UCC_MEM_MAP_TYPE_DPU_IMPORT; int i; ucc_mem_map_memh_t *local_memh; ucc_tl_lib_t *tl_lib; @@ -1152,18 +1156,15 @@ ucc_status_t ucc_mem_map_import(ucc_context_h context, return UCC_ERR_INVALID_PARAM; } - /* FIXME: it's legal for the user to export and immediately import a handle, - * is this an issue? */ local_memh = *memh; - - /* memh should have been used in exchanges or from a remote process, addresses, etc likely garbage. fix it */ + /* memh should have been used in exchanges or from a remote process, + addresses, etc. likely garbage. fix it */ local_memh->tl_h = (ucc_mem_map_tl_t *)ucc_calloc( ctx->n_tl_ctx, sizeof(ucc_mem_map_tl_t), "tl memh"); for (i = 0; i < ctx->n_tl_ctx; i++) { tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t); - /* FIXME: i don't think this will work properly for more than 1 TL */ status = tl_lib->iface->context.mem_map( - (const ucc_base_context_t *)ctx->tl_ctx[i], UCC_MEM_MAP_TYPE_GLOBAL, + (const ucc_base_context_t *)ctx->tl_ctx[i], type, params->segments[0].address, params->segments[0].len, local_memh, &local_memh->tl_h[i]); if (status < UCC_ERR_NOT_IMPLEMENTED) { @@ -1172,7 +1173,7 @@ ucc_status_t ucc_mem_map_import(ucc_context_h context, } strncpy(local_memh->tl_h[i].tl_name, tls->names[i], 8); } - local_memh->type = UCC_MEM_MAP_TYPE_GLOBAL; + local_memh->type = type; /* fix context as it will be incorrect on a different system */ local_memh->context = ctx; *memh_size = 0; @@ -1181,10 +1182,12 @@ ucc_status_t ucc_mem_map_import(ucc_context_h context, } ucc_status_t ucc_mem_map_export(ucc_context_h context, + ucc_mem_map_flags_t flags, ucc_mem_map_params_t *params, size_t *memh_size, ucc_mem_map_mem_h *memh) { ucc_context_t *ctx = (ucc_context_t *)context; + ucc_config_names_array_t *tls = &ctx->all_tls; size_t total_pack_size = 0; ucc_mem_map_memh_t *local_memh; ucc_mem_map_memh_t *exported_memh; @@ -1193,17 +1196,25 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context, ucc_tl_lib_t *tl_lib; size_t offset; int i; - ucc_config_names_array_t *tls = &ctx->all_tls; + int type; - local_memh = (ucc_mem_map_memh_t *)ucc_calloc(1, sizeof(ucc_mem_map_memh_t), - "local memh"); - if (!local_memh) { - ucc_error("failed to allocate a local memory handle"); - return UCC_ERR_NO_MEMORY; - } + if (flags == UCC_MEM_MAP_EXPORT) { + local_memh = (ucc_mem_map_memh_t *)ucc_calloc(1, sizeof(ucc_mem_map_memh_t), + "local memh"); + if (!local_memh) { + ucc_error("failed to allocate a local memory handle"); + return UCC_ERR_NO_MEMORY; + } - local_memh->tl_h = (ucc_mem_map_tl_t *)ucc_calloc( - ctx->n_tl_ctx, sizeof(ucc_mem_map_tl_t), "tl memh"); + local_memh->tl_h = (ucc_mem_map_tl_t *)ucc_calloc( + ctx->n_tl_ctx, sizeof(ucc_mem_map_tl_t), "tl memh"); + local_memh->address = params->segments[0].address; + local_memh->len = params->segments[0].len; + type = UCC_MEM_MAP_TYPE_LOCAL; + } else { + local_memh = *memh; + type = UCC_MEM_MAP_TYPE_DPU_EXPORT; + } packed_buffers = (void **)ucc_calloc(ctx->n_tl_ctx, sizeof(void *), "packed buffers"); @@ -1212,7 +1223,7 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context, tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t); /* always treat as a local mem handle */ status = tl_lib->iface->context.mem_map( - (const ucc_base_context_t *)ctx->tl_ctx[i], UCC_MEM_MAP_TYPE_LOCAL, + (const ucc_base_context_t *)ctx->tl_ctx[i], type, params->segments[0].address, params->segments[0].len, local_memh, &local_memh->tl_h[i]); if (status != UCC_OK) { @@ -1238,7 +1249,7 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context, &local_memh->tl_h[i], &packed_buffers[i]); if (status != UCC_OK) { if (status < UCC_ERR_NOT_IMPLEMENTED) { - ucc_error("failed to map memory"); + ucc_error("failed to pack memory handles"); goto failed_pack; } } @@ -1254,7 +1265,8 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context, "exported memh"); if (!exported_memh) { ucc_error("failed to allocate handle for exported buffers"); - return UCC_ERR_NO_MEMORY; + status = UCC_ERR_NO_MEMORY; + goto failed_pack; } /* copying */ @@ -1278,17 +1290,22 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context, // copy name information for look ups later strncpy(exported_memh->tl_h[i].tl_name, tls->names[i], 8); } - ucc_free(local_memh); - exported_memh->type = UCC_MEM_MAP_TYPE_LOCAL; + exported_memh->type = type; exported_memh->context = ctx; - exported_memh->address = params->segments[0].address; - exported_memh->len = params->segments[0].len; + exported_memh->address = local_memh->address; + exported_memh->len = local_memh->len; exported_memh->my_ctx_rank = ctx->rank; exported_memh->num_tls = ctx->n_tl_ctx; *memh = exported_memh; - *memh_size = total_pack_size; + *memh_size = sizeof(ucc_mem_map_memh_t) + offset; + ucc_free(local_memh); + ucc_free(packed_buffers); return UCC_OK; failed_pack: + for (int j = 0; j < i; j++) { + ucc_free(packed_buffers[j]); + } + i = ctx->n_tl_ctx; failed_mem_map: for (int j = 0; j < i; j++) { tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t); @@ -1296,25 +1313,135 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context, UCC_MEM_MAP_TYPE_LOCAL, &local_memh->tl_h[j]); } + ucc_free(local_memh); + ucc_free(packed_buffers); + *memh = NULL; + *memh_size = 0; return status; } +#if 0 +ucc_status_t ucc_mem_map_dpu_export(ucc_context_h context, + ucc_mem_map_params_t *params, size_t *memh_size, + ucc_mem_map_mem_h *memh) +{ + ucc_context_t *ctx = (ucc_context_t *)context; + size_t total_pack_size = 0; + ucc_mem_map_memh_t *local_memh = *memh; + ucc_config_names_array_t *tls = &ctx->all_tls; + ucc_mem_map_memh_t *exported_memh; + void **packed_buffers; + ucc_status_t status; + ucc_tl_lib_t *tl_lib; + size_t offset; + int i; + + packed_buffers = + (void **)ucc_calloc(ctx->n_tl_ctx, sizeof(void *), "packed buffers"); + + /* map all the memory */ + for (i = 0; i < ctx->n_tl_ctx; i++) { + tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t); + /* always treat as a local mem handle */ + status = tl_lib->iface->context.mem_map( + (const ucc_base_context_t *)ctx->tl_ctx[i], UCC_MEM_MAP_TYPE_DPU_EXPORT, + params->segments[0].address, params->segments[0].len, local_memh, + &local_memh->tl_h[i]); + if (status != UCC_OK) { + if (status < UCC_ERR_NOT_IMPLEMENTED) { + ucc_error("failed to map memory"); + goto failed_mem_map; + } + if (status == UCC_ERR_NOT_IMPLEMENTED || + status == UCC_ERR_NOT_SUPPORTED) { + /* either not implemented or not supported, set memh to null */ + local_memh->tl_h[i].packed_size = 0; + } + } + } + + /* now pack all the memories */ + for (i = 0; i < ctx->n_tl_ctx; i++) { + if (local_memh->tl_h[i].packed_size > 0) { + tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t); + /* tl should set packed_size, allocate buffer, pack memh */ + status = tl_lib->iface->context.memh_pack( + (const ucc_base_context_t *)ctx->tl_ctx[i], + &local_memh->tl_h[i], &packed_buffers[i]); + if (status != UCC_OK) { + if (status < UCC_ERR_NOT_IMPLEMENTED) { + ucc_error("failed to map memory"); + goto failed_pack; + } + } + total_pack_size += local_memh->tl_h[i].packed_size; + } + } + /* allocate exported memh, copy items over */ + exported_memh = (ucc_mem_map_memh_t *)ucc_calloc( + 1, sizeof(ucc_mem_map_memh_t) + total_pack_size + + 2 * sizeof(size_t) * ctx->n_tl_ctx, + "exported memh"); + if (!exported_memh) { + ucc_error("failed to allocate handle for exported buffers"); + return UCC_ERR_NO_MEMORY; + } +// exported_memh = local_memh; + + /* copying */ + exported_memh->tl_h = local_memh->tl_h; + for (i = 0, offset = 0; i < ctx->n_tl_ctx; i++) { + uint64_t tl_index = i; + if (local_memh->tl_h[i].packed_size == 0) { + continue; + } + memcpy(PTR_OFFSET(exported_memh->pack_buffer, offset), &tl_index, + sizeof(size_t)); + offset += sizeof(size_t); + memcpy(PTR_OFFSET(exported_memh->pack_buffer, offset), + &exported_memh->tl_h[i].packed_size, sizeof(size_t)); + offset += sizeof(size_t); + memcpy(PTR_OFFSET(exported_memh->pack_buffer, offset), + packed_buffers[i], exported_memh->tl_h[i].packed_size); + ucc_free(packed_buffers[i]); + offset += exported_memh->tl_h[i].packed_size; + // copy name information for look ups later + strncpy(exported_memh->tl_h[i].tl_name, tls->names[i], 8); + } + exported_memh->type = UCC_MEM_MAP_TYPE_DPU_EXPORT; + exported_memh->context = ctx; + exported_memh->address = local_memh->address; + exported_memh->len = local_memh->len; + exported_memh->my_ctx_rank = ctx->rank; + exported_memh->num_tls = ctx->n_tl_ctx; + *memh = exported_memh; + *memh_size = sizeof(ucc_mem_map_memh_t) + offset;//total_pack_size + 2 * sizeof(size_t) * ctx->n_tl_ctx; + ucc_free(local_memh); + return UCC_OK; +failed_pack: +failed_mem_map: + for (int j = 0; j < i; j++) { + tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t); + tl_lib->iface->context.mem_unmap((const ucc_base_context_t *)ctx, + UCC_MEM_MAP_TYPE_LOCAL, + &local_memh->tl_h[j]); + } + return status; +} +#endif ucc_status_t ucc_mem_map(ucc_context_h context, ucc_mem_map_flags_t flags, ucc_mem_map_params_t *params, size_t *memh_size, ucc_mem_map_mem_h *memh) { - if (params->n_segments > 1) { - ucc_error("UCC only supports one mapping per call"); - return UCC_ERR_INVALID_PARAM; - } - // check if flags is import / export - if (flags == UCC_MEM_MAP_IMPORT) { - // set map type to global - return ucc_mem_map_import(context, params, memh_size, memh); + if (flags == UCC_MEM_MAP_IMPORT || flags == UCC_MEM_MAP_IMPORT_OFFLOAD) { + return ucc_mem_map_import(context, flags, params, memh_size, memh); } else { - // set map type to local - return ucc_mem_map_export(context, params, memh_size, memh); + if (params->n_segments > 1) { + ucc_error("UCC only supports one mapping per call"); + return UCC_ERR_INVALID_PARAM; + } + return ucc_mem_map_export(context, flags, params, memh_size, memh); } } diff --git a/src/core/ucc_context.h b/src/core/ucc_context.h index 23a583c5bd..e70b46659b 100644 --- a/src/core/ucc_context.h +++ b/src/core/ucc_context.h @@ -98,7 +98,9 @@ typedef struct ucc_context_config { typedef enum { UCC_MEM_MAP_TYPE_LOCAL, - UCC_MEM_MAP_TYPE_GLOBAL + UCC_MEM_MAP_TYPE_GLOBAL, + UCC_MEM_MAP_TYPE_DPU_IMPORT, /* special case */ + UCC_MEM_MAP_TYPE_DPU_EXPORT } ucc_mem_map_type_t; typedef struct ucc_mem_map_tl_t { diff --git a/src/ucc/api/ucc.h b/src/ucc/api/ucc.h index 8e194c7dcf..e8342cbfe2 100644 --- a/src/ucc/api/ucc.h +++ b/src/ucc/api/ucc.h @@ -2257,8 +2257,10 @@ ucc_status_t ucc_collective_triggered_post(ucc_ee_h ee, ucc_ev_t *ee_event); typedef enum { UCC_MEM_MAP_EXPORT = 0, /*!< Indicate ucc_mem_map() should export memory handles from TLs used by context */ - UCC_MEM_MAP_IMPORT = 1 /*!< Indicate ucc_mem_map() should import + UCC_MEM_MAP_IMPORT = 1, /*!< Indicate ucc_mem_map() should import memory handles from user memory handle */ + UCC_MEM_MAP_EXPORT_OFFLOAD = 2, + UCC_MEM_MAP_IMPORT_OFFLOAD = 3 } ucc_mem_map_flags_t; /**