Skip to content

Commit

Permalink
TL/UCP: add support for ucc_mem_map/unmap
Browse files Browse the repository at this point in the history
TL/UCP: add support for communication with memmap

REVIEW: various fixes

TL/UCP: Enable multiple packed buffers
  • Loading branch information
ferrol aderholdt committed Jan 16, 2025
1 parent c2ebac7 commit 93547ea
Show file tree
Hide file tree
Showing 10 changed files with 361 additions and 30 deletions.
7 changes: 7 additions & 0 deletions src/components/tl/ucp/alltoall/alltoall.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,13 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_init(ucc_base_coll_args_t *coll_args,
goto out;
}
}
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
coll_args->args.src_memh.global_memh = NULL;
}
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) {
coll_args->args.dst_memh.global_memh = NULL;
}

task = ucc_tl_ucp_init_task(coll_args, team);
*task_h = &task->super;
task->super.post = ucc_tl_ucp_alltoall_onesided_start;
Expand Down
11 changes: 7 additions & 4 deletions src/components/tl/ucp/alltoall/alltoall_onesided.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,22 +24,25 @@ ucc_status_t ucc_tl_ucp_alltoall_onesided_start(ucc_coll_task_t *ctask)
ucc_rank_t gsize = UCC_TL_TEAM_SIZE(team);
ucc_rank_t start = (grank + 1) % gsize;
long * pSync = TASK_ARGS(task).global_work_buffer;
ucc_mem_map_mem_h *src_memh = TASK_ARGS(task).src_memh.global_memh;
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
ucc_rank_t peer;

ucc_tl_ucp_task_reset(task, UCC_INPROGRESS);
/* TODO: change when support for library-based work buffers is complete */
nelems = (nelems / gsize) * ucc_dt_size(TASK_ARGS(task).src.info.datatype);
dest = dest + grank * nelems;

UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + start * nelems),
(void *)dest, nelems, start, team, task),
(void *)dest, nelems, start, src_memh[start], dst_memh[start], team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, team), task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, start, src_memh[start], dst_memh[start], team), task, out);

for (peer = (start + 1) % gsize; peer != start; peer = (peer + 1) % gsize) {
UCPCHECK_GOTO(ucc_tl_ucp_put_nb((void *)(src + peer * nelems),
(void *)dest, nelems, peer, team, task),
(void *)dest, nelems, peer, src_memh[peer], dst_memh[peer], team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task,
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, src_memh[peer], dst_memh[peer], team), task,
out);
}

Expand Down
12 changes: 10 additions & 2 deletions src/components/tl/ucp/alltoallv/alltoallv_onesided.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)
ucc_aint_t *d_disp = TASK_ARGS(task).dst.info_v.displacements;
size_t sdt_size = ucc_dt_size(TASK_ARGS(task).src.info_v.datatype);
size_t rdt_size = ucc_dt_size(TASK_ARGS(task).dst.info_v.datatype);
ucc_mem_map_mem_h *src_memh = TASK_ARGS(task).src_memh.global_memh;
ucc_mem_map_mem_h *dst_memh = TASK_ARGS(task).dst_memh.global_memh;
ucc_rank_t peer;
size_t sd_disp, dd_disp, data_size;

Expand All @@ -46,9 +48,9 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)

UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
PTR_OFFSET(dest, dd_disp),
data_size, peer, team, task),
data_size, peer, src_memh[peer], dst_memh[peer], team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, team), task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, src_memh[peer], dst_memh[peer], team), task, out);
}
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
out:
Expand Down Expand Up @@ -93,6 +95,12 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_init(ucc_base_coll_args_t *coll_args,
goto out;
}
}
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_SRC_MEMH)) {
coll_args->args.src_memh.global_memh = NULL;
}
if (!(coll_args->args.mask & UCC_COLL_ARGS_FIELD_MEM_MAP_DST_MEMH)) {
coll_args->args.dst_memh.global_memh = NULL;
}

task = ucc_tl_ucp_init_task(coll_args, team);
*task_h = &task->super;
Expand Down
11 changes: 11 additions & 0 deletions src/components/tl/ucp/tl_ucp.c
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,17 @@ ucc_status_t ucc_tl_ucp_get_lib_properties(ucc_base_lib_properties_t *prop);
ucc_status_t ucc_tl_ucp_get_context_attr(const ucc_base_context_t *context,
ucc_base_ctx_attr_t *base_attr);

/*
 * Forward declarations of the TL/UCP memory-map entry points.
 *
 * ucc_tl_ucp_mem_map: maps (or imports, depending on `type`) the region
 * [address, address + len) for the given context; `tl_h` is the TL-level
 * handle that receives the TL/UCP private data and `memh` is the upper-level
 * memory handle.
 */
ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context,
int type,
void *address,
size_t len,
void *memh,
void *tl_h);

/*
 * ucc_tl_ucp_memh_pack: serializes the TL/UCP data attached to `memh` into a
 * newly allocated buffer returned through `pack_buffer`.
 */
ucc_status_t ucc_tl_ucp_memh_pack(const ucc_base_context_t *context, void *memh, void **pack_buffer);

/*
 * ucc_tl_ucp_mem_unmap: releases the resources attached to `memh`; which
 * resources are released depends on `type`.
 */
ucc_status_t ucc_tl_ucp_mem_unmap(const ucc_base_context_t *context, int type, void *memh);

ucc_config_field_t ucc_tl_ucp_lib_config_table[] = {
{"", "", NULL, ucc_offsetof(ucc_tl_ucp_lib_config_t, super),
UCC_CONFIG_TYPE_TABLE(ucc_tl_lib_config_table)},
Expand Down
7 changes: 7 additions & 0 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,13 @@ typedef struct ucc_tl_ucp_remote_info {
size_t packed_key_len;
} ucc_tl_ucp_remote_info_t;

/*
 * TL/UCP private data attached to a memory-map handle (stored in the
 * `tl_data` field of the TL-level handle).
 */
typedef struct ucc_tl_ucp_memh_data {
/* Region description: UCP memory handle, base address, length and the
 * packed rkey together with its length. */
ucc_tl_ucp_remote_info_t rinfo;
/* Exported memory handle buffer produced by ucp_memh_pack(); NULL when the
 * export is unsupported or failed. */
void *packed_memh;
/* Size in bytes of packed_memh (0 when packed_memh is NULL). */
size_t packed_memh_len;
/* Remote key; destroyed with ucp_rkey_destroy() on unmap when non-NULL.
 * NOTE(review): where it is unpacked is not visible here — confirm. */
ucp_rkey_h rkey;
} ucc_tl_ucp_memh_data_t;

typedef struct ucc_tl_ucp_worker {
ucp_context_h ucp_context;
ucp_worker_h ucp_worker;
Expand Down
172 changes: 172 additions & 0 deletions src/components/tl/ucp/tl_ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -555,6 +555,178 @@ static void ucc_tl_ucp_ctx_remote_pack_data(ucc_tl_ucp_context_t *ctx,
}
}

/**
 * Map a local memory region, or import a region exported by a peer.
 *
 * UCC_MEM_MAP_TYPE_GLOBAL (import): a fresh ucc_tl_ucp_memh_data_t is always
 * allocated and stored in tl_h->tl_data. On aarch64 builds the exported
 * memory handle found in memh->pack_buffer (after the two size_t header
 * fields and the packed rkey) is imported with ucp_mem_map(). An
 * UCS_ERR_UNREACHABLE result is not treated as an error.
 *
 * Any other type (local): the region [address, address + len) is registered
 * with UCP unless tl_h->tl_data already carries a memory handle, then the
 * memory handle is exported with ucp_memh_pack() (best effort) and the rkey
 * is packed with ucp_rkey_pack(). tl_h->packed_size is set to the sum of the
 * two packed lengths.
 *
 * @param context  TL/UCP base context.
 * @param type     UCC_MEM_MAP_TYPE_GLOBAL for import, otherwise local map.
 * @param address  Base address of the region.
 * @param len      Length of the region in bytes.
 * @param memh     Upper-level handle (ucc_mem_map_memh_t *); read only on the
 *                 aarch64 import path.
 * @param tl_h     TL-level handle (ucc_mem_map_tl_t *) that owns tl_data.
 *
 * @return UCC_OK on success, UCC_ERR_NO_MEMORY if the private data cannot be
 *         allocated, or the translated UCS status when ucp_mem_map() or
 *         ucp_rkey_pack() fails. On failure tl_h->tl_data stays attached so
 *         the caller can still release it.
 */
ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context,
                                int type,
                                void *address,
                                size_t len,
                                void *memh,
                                void *tl_h)
{
    ucc_tl_ucp_context_t   *ctx        = ucc_derived_of(context, ucc_tl_ucp_context_t);
    ucc_status_t            ucc_status = UCC_OK;
    ucc_mem_map_tl_t       *p_memh     = (ucc_mem_map_tl_t *)tl_h;
    ucc_tl_ucp_memh_data_t *m_data     = (ucc_tl_ucp_memh_data_t *)p_memh->tl_data;
    ucp_mem_map_params_t    mmap_params;
    ucp_mem_h               mh         = NULL;
    ucs_status_t            status;
    ucp_memh_pack_params_t  pack_params;

    // basically an import here
    if (type == UCC_MEM_MAP_TYPE_GLOBAL) {
        // m_data is lost in the exchange, make a new one
        m_data = ucc_calloc(1, sizeof(ucc_tl_ucp_memh_data_t), "tl data");
        if (!m_data) {
            tl_error(ctx->super.super.lib, "failed to allocate tl data");
            return UCC_ERR_NO_MEMORY;
        }
        p_memh->tl_data = m_data;
#if defined(__aarch64__)
        ucc_mem_map_memh_t * l_memh = (ucc_mem_map_memh_t *)memh;
        /* unpack here: skip both size_t header fields and the packed rkey */
        size_t *key_size = (size_t *)l_memh->pack_buffer;
        void *packed_memh = PTR_OFFSET(l_memh->pack_buffer, sizeof(size_t) * 2 + *key_size);
        mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
        mmap_params.exported_memh_buffer = packed_memh;

        status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
        if (UCS_OK != status) {
            if (status != UCS_ERR_UNREACHABLE) {
                tl_error(ctx->super.super.lib,
                         "ucp_mem_map failed with error code: %s", ucs_status_string(status));
                ucc_status = ucs_status_to_ucc_status(status);
                return ucc_status;
            } else {
                tl_debug(ctx->super.super.lib,
                         "ucp_mem_map could not map exported memory handle");
                ucc_status = UCC_OK; // this is still OK
            }
        } else {
            m_data->rinfo.mem_h = mh;
            // the rest of the data is garbage. fix it
            m_data->rinfo.va_base = address;
            m_data->rinfo.len = len;
        }
#endif
    } else {
        if (!m_data) {
            m_data = ucc_calloc(1, sizeof(ucc_tl_ucp_memh_data_t), "tl data");
            if (!m_data) {
                tl_error(ctx->super.super.lib, "failed to allocate TL/UCP specific data");
                return UCC_ERR_NO_MEMORY;
            }
            p_memh->tl_data = m_data;
        }
        // local, we need to map it here
        if (m_data->rinfo.mem_h == NULL) {
            mmap_params.field_mask =
                UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
            mmap_params.address = address;
            mmap_params.length = len;

            status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
            if (UCS_OK != status) {
                tl_error(ctx->super.super.lib,
                         "ucp_mem_map failed with error code: %d", status);
                /* Without a valid handle there is nothing to export or pack. */
                return ucs_status_to_ucc_status(status);
            }
            m_data->rinfo.mem_h = mh;
        } else {
            /* Already registered: reuse the handle instead of clobbering it
             * with NULL and packing a NULL handle below. */
            mh = m_data->rinfo.mem_h;
        }
        m_data->rinfo.va_base = address;
        m_data->rinfo.len = len;

        // export memh
        pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
        pack_params.flags = UCP_MEMH_PACK_FLAG_EXPORT;

        status = ucp_memh_pack(mh, &pack_params, &m_data->packed_memh, &m_data->packed_memh_len);
        if (status != UCS_OK) {
            // we don't support memory pack, or it failed
            tl_debug(ctx->super.super.lib,
                     "ucp_memh_pack() returned error %s", ucs_status_string(status));
            m_data->packed_memh = NULL;
            m_data->packed_memh_len = 0;
        }
        // pack rkey
        status = ucp_rkey_pack(ctx->worker.ucp_context, mh, &m_data->rinfo.packed_key, &m_data->rinfo.packed_key_len);
        if (status != UCS_OK) {
            tl_error(ctx->super.super.lib,
                     "unable to pack rkey with error %s", ucs_status_string(status));
            /* Leave a well-defined (empty) key and report the failure rather
             * than returning UCC_OK with an unusable handle. */
            m_data->rinfo.packed_key = NULL;
            m_data->rinfo.packed_key_len = 0;
            p_memh->packed_size = m_data->packed_memh_len;
            return ucs_status_to_ucc_status(status);
        }
        p_memh->packed_size = m_data->packed_memh_len + m_data->rinfo.packed_key_len;
    }

    return ucc_status;
}

ucc_status_t ucc_tl_ucp_mem_unmap(const ucc_base_context_t *context, int type, void *memh)
{
ucc_tl_ucp_context_t *ctx = ucc_derived_of(context, ucc_tl_ucp_context_t);
ucc_mem_map_tl_t * p_memh = (ucc_mem_map_tl_t *) memh;
ucc_tl_ucp_memh_data_t *data;
ucs_status_t status;

if (!p_memh) {
return UCC_OK;
}

data = (ucc_tl_ucp_memh_data_t *)p_memh->tl_data;

if (type == UCC_MEM_MAP_TYPE_LOCAL) {
status = ucp_mem_unmap(ctx->worker.ucp_context, data->rinfo.mem_h);
if (status != UCS_OK) {
tl_error(ctx->super.super.lib, "ucp_mem_unmap failed with error code %d", status);
return ucs_status_to_ucc_status(status);
}
} else if (type == UCC_MEM_MAP_TYPE_GLOBAL) {
// need to free rkeys (data->rkey) , packed memh (data->packed_memh)
if (data->packed_memh) {
ucc_free(data->packed_memh);
}
if (data->rinfo.packed_key) {
ucp_rkey_buffer_release(data->rinfo.packed_key);
}
if (data->rkey) {
ucp_rkey_destroy(data->rkey);
}
} else {
ucc_error("Unknown type entered");
return UCC_ERR_INVALID_PARAM;
}
return UCC_OK;
}

/**
 * Serialize the TL/UCP data of a memory-map handle into a new buffer.
 *
 * Buffer layout:
 *
 *   packed_key_size | packed_memh_size | packed_key | packed_memh
 *
 * where both sizes are size_t values. A missing payload (NULL pointer, e.g.
 * when exporting the memory handle failed) is encoded with size 0.
 *
 * @param context      TL/UCP base context (unused).
 * @param memh         TL-level handle (ucc_mem_map_tl_t *); its packed_size
 *                     is updated to the total buffer size.
 * @param pack_buffer  Receives the buffer allocated with ucc_malloc(); left
 *                     untouched when the handle carries no TL/UCP data.
 *
 * @return UCC_OK on success (including the no-data case) or
 *         UCC_ERR_NO_MEMORY if the buffer cannot be allocated.
 */
ucc_status_t ucc_tl_ucp_memh_pack(const ucc_base_context_t *context, void *memh, void **pack_buffer)
{
    ucc_mem_map_tl_t       *p_memh = (ucc_mem_map_tl_t *)memh;
    ucc_tl_ucp_memh_data_t *data   = p_memh->tl_data;
    const size_t            hdr_len = sizeof(size_t) * 2;
    void                   *packed_buffer;
    size_t                 *key_size;
    size_t                 *memh_size;
    size_t                  key_len;
    size_t                  memh_len;
    size_t                  total_len;

    if (!data) {
        return UCC_OK;
    }

    /* Never advertise bytes that cannot be copied: a NULL payload has size 0. */
    key_len   = data->rinfo.packed_key ? data->rinfo.packed_key_len : 0;
    memh_len  = data->packed_memh ? data->packed_memh_len : 0;
    total_len = hdr_len + key_len + memh_len;

    packed_buffer = ucc_malloc(total_len, "packed buffer");
    if (!packed_buffer) {
        ucc_error("failed to allocate a packed buffer of size %zu", total_len);
        return UCC_ERR_NO_MEMORY;
    }

    key_size   = packed_buffer;
    *key_size  = key_len;
    memh_size  = PTR_OFFSET(packed_buffer, sizeof(size_t));
    *memh_size = memh_len;

    /* memcpy() with a NULL source is undefined even for length 0. */
    if (key_len > 0) {
        memcpy(PTR_OFFSET(packed_buffer, hdr_len), data->rinfo.packed_key, key_len);
    }
    if (memh_len > 0) {
        memcpy(PTR_OFFSET(packed_buffer, hdr_len + key_len), data->packed_memh, memh_len);
    }

    p_memh->packed_size = total_len;
    *pack_buffer        = packed_buffer;
    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_get_context_attr(const ucc_base_context_t *context,
ucc_base_ctx_attr_t *attr)
{
Expand Down
Loading

0 comments on commit 93547ea

Please sign in to comment.