Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ferrol aderholdt committed Feb 15, 2025
1 parent 033942f commit 4f039a2
Show file tree
Hide file tree
Showing 5 changed files with 148 additions and 28 deletions.
5 changes: 2 additions & 3 deletions src/components/tl/ucp/alltoallv/alltoallv_onesided.c
Original file line number Diff line number Diff line change
Expand Up @@ -48,11 +48,10 @@ ucc_status_t ucc_tl_ucp_alltoallv_onesided_start(ucc_coll_task_t *ctask)

UCPCHECK_GOTO(ucc_tl_ucp_put_nb(PTR_OFFSET(src, sd_disp),
PTR_OFFSET(dest, dd_disp),
data_size, peer, src_memh[peer],
data_size, peer, *src_memh,
dst_memh[peer], team, task),
task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, src_memh[peer],
dst_memh[peer], team), task, out);
UCPCHECK_GOTO(ucc_tl_ucp_atomic_inc(pSync, peer, *src_memh, dst_memh[peer], team), task, out);
}
return ucc_progress_queue_enqueue(UCC_TL_CORE_CTX(team)->pq, &task->super);
out:
Expand Down
18 changes: 18 additions & 0 deletions src/components/tl/ucp/tl_ucp.h
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,24 @@ extern ucc_config_field_t ucc_tl_ucp_lib_config_table[];
/*
 * Fetch the rkey stored for segment _seg of rank _rank from the context's
 * flat rkeys array; the array is indexed with n_rinfo_segs entries per rank.
 * NOTE(review): _rank, _seg and the second use of _ctx are not parenthesized
 * in the index expression, so callers should pass simple expressions only.
 */
#define UCC_TL_UCP_REMOTE_RKEY(_ctx, _rank, _seg) \
((_ctx)->rkeys[_rank * _ctx->n_rinfo_segs + _seg])

/*
 * Number of size_t words at the very start of a TL section of a packed
 * memory-handle buffer.
 * NOTE(review): presumably these are the TL index and the section size
 * written by the core export path -- confirm against ucc_mem_map_export().
 */
#define UCC_TL_UCP_MEMH_TL_HEADERS 2

/*
 * Number of additional size_t words that follow the TL headers; the first of
 * them is read as the key size by UCC_TL_UCP_MEMH_TL_KEY_SIZE() below.
 * NOTE(review): the meaning of the second word is not visible here -- confirm.
 */
#define UCC_TL_UCP_MEMH_TL_PACKED_HEADERS 2

/* Size in bytes of the leading TL headers. */
#define UCC_TL_UCP_MEMH_TL_HEADER_SIZE \
(sizeof(size_t) * UCC_TL_UCP_MEMH_TL_HEADERS)

/* Size in bytes of the TL headers plus the packed headers that follow them. */
#define UCC_TL_UCP_MEMH_TL_PACKED_HEADER_SIZE \
(sizeof(size_t) * \
(UCC_TL_UCP_MEMH_TL_HEADERS + UCC_TL_UCP_MEMH_TL_PACKED_HEADERS))

/*
 * Read the key size: the size_t stored immediately after the TL headers of
 * the section that starts at `buffer`.
 */
#define UCC_TL_UCP_MEMH_TL_KEY_SIZE(buffer) \
(*(size_t *)(PTR_OFFSET(buffer, UCC_TL_UCP_MEMH_TL_HEADER_SIZE)))

/*
 * Address of the packed memory handle inside the section that starts at
 * `buffer`: it is located after all headers and after key-size bytes.
 * `buffer` is evaluated more than once by this macro.
 */
#define UCC_TL_UCP_MEMH_TL_PACKED_MEMH(buffer) \
(PTR_OFFSET(buffer, UCC_TL_UCP_MEMH_TL_PACKED_HEADER_SIZE + \
UCC_TL_UCP_MEMH_TL_KEY_SIZE(buffer)))

extern ucs_memory_type_t ucc_memtype_to_ucs[UCC_MEMORY_TYPE_LAST+1];

void ucc_tl_ucp_pre_register_mem(ucc_tl_ucp_team_t *team, void *addr,
Expand Down
129 changes: 117 additions & 12 deletions src/components/tl/ucp/tl_ucp_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -558,11 +558,13 @@ ucc_status_t ucc_tl_ucp_mem_map_memhbuf(ucc_tl_ucp_context_t *ctx,
{
ucp_mem_map_params_t mmap_params;
ucs_status_t status;
void *packed_memh;

*mh = NULL;
/* unpack here */
size_t *key_size = (size_t *)PTR_OFFSET(pack_buffer, sizeof(size_t) * 2);
void *packed_memh = PTR_OFFSET(pack_buffer, sizeof(size_t) * 4 + *key_size);
packed_memh = UCC_TL_UCP_MEMH_TL_PACKED_MEMH(pack_buffer);
// size_t *key_size = (size_t *)PTR_OFFSET(pack_buffer, sizeof(size_t) * 2);
// void *packed_memh = PTR_OFFSET(pack_buffer, sizeof(size_t) * 4 + *key_size);

mmap_params.field_mask = UCP_MEM_MAP_PARAM_FIELD_EXPORTED_MEMH_BUFFER;
mmap_params.exported_memh_buffer = packed_memh;
Expand All @@ -585,23 +587,113 @@ ucc_status_t ucc_tl_ucp_mem_map_memhbuf(ucc_tl_ucp_context_t *ctx,
return UCC_OK;
}

/**
 * Prepare the export data (memory handle and packed rkey) for a memory region.
 *
 * For UCC_MEM_MAP_TYPE_LOCAL the region [address, address + len) is
 * registered with UCP, the registration is recorded in m_data->rinfo, and an
 * exportable memory handle is packed into m_data->packed_memh. A failure to
 * pack the memory handle is not fatal: it is logged at debug level and the
 * packed handle is reported as empty.
 *
 * For any other type no new registration is made; the handle already stored
 * in m_data->rinfo.mem_h is reused.
 *
 * In both cases an rkey for the handle is packed into
 * m_data->rinfo.packed_key / packed_key_len.
 *
 * @param ctx     TL/UCP context whose UCP context performs the registration.
 * @param address Start of the region (used only for UCC_MEM_MAP_TYPE_LOCAL).
 * @param len     Length of the region in bytes.
 * @param type    Memory map type; UCC_MEM_MAP_TYPE_LOCAL triggers registration.
 * @param m_data  TL-private handle data that is read and filled in.
 *
 * @return UCC_OK on success, UCC_ERR_INVALID_PARAM when no memory handle is
 *         available to pack an rkey from, or the translated UCS error when
 *         ucp_mem_map() / ucp_rkey_pack() fail.
 */
ucc_status_t ucc_tl_ucp_mem_map_export(ucc_tl_ucp_context_t *ctx, void *address,
                                       size_t len,
                                       int type,
                                       ucc_tl_ucp_memh_data_t *m_data)
{
    ucp_mem_h              mh        = NULL;
    int                    mapped    = 0; /* set when this call created mh */
    ucp_mem_map_params_t   mmap_params;
    ucp_memh_pack_params_t pack_params;
    ucs_status_t           status;

    if (type == UCC_MEM_MAP_TYPE_LOCAL) {
        mmap_params.field_mask =
            UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
        mmap_params.address = address;
        mmap_params.length  = len;

        status = ucp_mem_map(ctx->worker.ucp_context, &mmap_params, &mh);
        if (UCS_OK != status) {
            tl_error(ctx->super.super.lib, "ucp_mem_map failed with error code: %d",
                     status);
            return ucs_status_to_ucc_status(status);
        }
        mapped                = 1;
        m_data->rinfo.mem_h   = mh;
        m_data->rinfo.va_base = address;
        m_data->rinfo.len     = len;

        /* The pack parameters must be initialized before use: UCP reads
         * field_mask to decide which fields are valid. Request an exportable
         * handle. */
        pack_params.field_mask = UCP_MEMH_PACK_PARAM_FIELD_FLAGS;
        pack_params.flags      = UCP_MEMH_PACK_FLAG_EXPORT;

        status = ucp_memh_pack(mh, &pack_params, &m_data->packed_memh,
                               &m_data->packed_memh_len);
        if (status != UCS_OK) {
            // we don't support memory pack, or it failed
            tl_debug(ctx->super.super.lib, "ucp_memh_pack() returned error %s",
                     ucs_status_string(status));
            m_data->packed_memh     = NULL;
            m_data->packed_memh_len = 0;
        }
    } else {
        /* Nothing is registered on this path; pack the rkey from the handle
         * that was stored in m_data earlier instead of an uninitialized one. */
        mh = m_data->rinfo.mem_h;
        if (mh == NULL) {
            tl_error(ctx->super.super.lib,
                     "no memory handle available to pack rkey");
            return UCC_ERR_INVALID_PARAM;
        }
    }

    // pack rkey
    status =
        ucp_rkey_pack(ctx->worker.ucp_context, mh, &m_data->rinfo.packed_key,
                      &m_data->rinfo.packed_key_len);
    if (status != UCS_OK) {
        tl_error(ctx->super.super.lib, "unable to pack rkey with error %s",
                 ucs_status_string(status));
        if (mapped) {
            /* Only undo a registration made by this call, and do not leave a
             * dangling handle behind in m_data.
             * NOTE(review): a successfully packed m_data->packed_memh is not
             * released here -- confirm who owns it on this error path. */
            ucp_mem_unmap(ctx->worker.ucp_context, mh);
            m_data->rinfo.mem_h = NULL;
        }
        return ucs_status_to_ucc_status(status);
    }

    return UCC_OK;
}

/**
 * Import a memory handle from the packed buffer carried by l_memh.
 *
 * The pack buffer is walked one TL slot at a time until the slot whose
 * address equals tl_h is reached. For every slot passed over, the two size_t
 * words at the current offset are inspected: when the first equals the slot
 * index, the second is added to the offset. The section found at the
 * resulting offset is then mapped through ucc_tl_ucp_mem_map_memhbuf() and
 * the outcome is recorded in m_data->rinfo.
 *
 * @param ctx     TL/UCP context used for the mapping.
 * @param address Base address recorded in m_data->rinfo.va_base.
 * @param len     Length recorded in m_data->rinfo.len.
 * @param m_data  TL-private handle data to fill in.
 * @param l_memh  Memory handle that owns the pack buffer and the TL slots.
 * @param tl_h    Address of this TL's slot inside l_memh->tl_h.
 *
 * @return UCC_OK on success, otherwise the status returned by
 *         ucc_tl_ucp_mem_map_memhbuf().
 */
ucc_status_t ucc_tl_ucp_mem_map_dpu_import(ucc_tl_ucp_context_t *ctx,
                                           void *address, size_t len,
                                           ucc_tl_ucp_memh_data_t *m_data,
                                           ucc_mem_map_memh_t *l_memh,
                                           void *tl_h)
{
    size_t       skip = 0;
    ucp_mem_h    imported;
    ucc_status_t rc;
    int          tl;

    for (tl = 0; tl < l_memh->num_tls; tl++) {
        const size_t *hdr;

        /* stop as soon as our own slot is reached */
        if ((void *)&l_memh->tl_h[tl] == tl_h) {
            break;
        }

        /* not our slot: step over its section of the buffer, if present */
        hdr = (const size_t *)PTR_OFFSET(l_memh->pack_buffer, skip);
        if (hdr[0] == tl) {
            skip += hdr[1];
        }
    }

    rc = ucc_tl_ucp_mem_map_memhbuf(ctx,
                                    PTR_OFFSET(l_memh->pack_buffer, skip),
                                    &imported);
    if (rc != UCC_OK) {
        tl_error(ctx->super.super.lib, "ucp_mem_map failed to map memh");
        return rc;
    }

    m_data->rinfo.len     = len;
    m_data->rinfo.va_base = address;
    m_data->rinfo.mem_h   = imported; /* used for xgvmi */

    return UCC_OK;
}

ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type,
void *address, size_t len, void *memh,
void *tl_h)
{
ucc_tl_ucp_context_t *ctx = ucc_derived_of(context, ucc_tl_ucp_context_t);
ucc_status_t ucc_status = UCC_OK;
ucc_mem_map_tl_t *p_memh = (ucc_mem_map_tl_t *)tl_h;
ucc_mem_map_tl_t *p_memh = (ucc_mem_map_tl_t *)tl_h;
ucc_tl_ucp_memh_data_t *m_data = (ucc_tl_ucp_memh_data_t *)p_memh->tl_data;
ucp_mem_h mh = NULL;
ucc_mem_map_memh_t *l_memh = (ucc_mem_map_memh_t *)memh;
size_t offset = 0;
ucp_mem_map_params_t mmap_params;
ucs_status_t status;
ucp_memh_pack_params_t pack_params;
// size_t offset = 0;
// ucp_mem_h mh = NULL;
ucc_status_t ucc_status = UCC_OK;
// ucp_mem_map_params_t mmap_params;
// ucs_status_t status;
// ucp_memh_pack_params_t pack_params;

/* technically, only need to do this on import */
if (type == UCC_MEM_MAP_TYPE_GLOBAL || type == UCC_MEM_MAP_TYPE_DPU_IMPORT || !m_data) {
if (type == UCC_MEM_MAP_TYPE_GLOBAL ||
type == UCC_MEM_MAP_TYPE_DPU_IMPORT || !m_data) {
/* either we are importing or m_data is null */
m_data = ucc_calloc(1, sizeof(ucc_tl_ucp_memh_data_t), "tl data");
if (!m_data) {
Expand All @@ -610,6 +702,19 @@ ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type,
}
p_memh->tl_data = m_data;
}

if (type == UCC_MEM_MAP_TYPE_LOCAL || type == UCC_MEM_MAP_TYPE_DPU_EXPORT) {
ucc_status = ucc_tl_ucp_mem_map_export(ctx, address, len, type, m_data);
} else if (type == UCC_MEM_MAP_TYPE_DPU_IMPORT) {
ucc_status = ucc_tl_ucp_mem_map_dpu_import(ctx, address, len, m_data, l_memh, tl_h);
}

p_memh->packed_size =
m_data->packed_memh_len + m_data->rinfo.packed_key_len;

return ucc_status;
}
#if 0
if (type == UCC_MEM_MAP_TYPE_LOCAL) { /* this is an export */
mmap_params.field_mask =
UCP_MEM_MAP_PARAM_FIELD_ADDRESS | UCP_MEM_MAP_PARAM_FIELD_LENGTH;
Expand Down Expand Up @@ -668,7 +773,6 @@ ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type,
tl_error(ctx->super.super.lib, "unable to pack rkey with error %s",
ucs_status_string(status));
}

}
if (type == UCC_MEM_MAP_TYPE_DPU_EXPORT) {
// pack rkey
Expand All @@ -683,10 +787,11 @@ ucc_status_t ucc_tl_ucp_mem_map(const ucc_base_context_t *context, int type,
l_memh->type = UCC_MEM_MAP_TYPE_DPU_EXPORT;
}
p_memh->packed_size =
m_data->packed_memh_len + m_data->rinfo.packed_key_len;
m_data->packed_memh_len + m_data->rinfo.packed_key_len;

return ucc_status;
}
#endif

ucc_status_t ucc_tl_ucp_mem_unmap(const ucc_base_context_t *context, int type,
void *memh)
Expand Down
1 change: 0 additions & 1 deletion src/components/tl/ucp/tl_ucp_sendrecv.h
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,6 @@ static inline ucc_status_t ucc_tl_ucp_atomic_inc(void * target,
ucs_status_ptr_t ucp_status;
ucc_status_t status;
ucp_ep_h ep;
// void *packed_memh;

status = ucc_tl_ucp_get_ep(team, dest_group_rank, &ep);
if (ucc_unlikely(UCC_OK != status)) {
Expand Down
23 changes: 11 additions & 12 deletions src/core/ucc_context.c
Original file line number Diff line number Diff line change
Expand Up @@ -1132,14 +1132,17 @@ ucc_status_t ucc_context_get_attr(ucc_context_t *context,
return status;
}

ucc_status_t ucc_mem_map_import(ucc_context_h context, ucc_mem_map_flags_t flags,
ucc_status_t ucc_mem_map_import(ucc_context_h context,
ucc_mem_map_flags_t flags,
ucc_mem_map_params_t *params, size_t *memh_size,
ucc_mem_map_mem_h *memh)
{
ucc_context_t *ctx = (ucc_context_t *)context;
ucc_status_t status = UCC_OK;
ucc_config_names_array_t *tls = &ctx->all_tls;
ucc_mem_map_type_t type = (flags == UCC_MEM_MAP_IMPORT) ? UCC_MEM_MAP_TYPE_GLOBAL : UCC_MEM_MAP_TYPE_DPU_IMPORT;
ucc_mem_map_type_t type = (flags == UCC_MEM_MAP_IMPORT)
? UCC_MEM_MAP_TYPE_GLOBAL
: UCC_MEM_MAP_TYPE_DPU_IMPORT;
int i;
ucc_mem_map_memh_t *local_memh;
ucc_tl_lib_t *tl_lib;
Expand All @@ -1153,15 +1156,12 @@ ucc_status_t ucc_mem_map_import(ucc_context_h context, ucc_mem_map_flags
return UCC_ERR_INVALID_PARAM;
}

/* FIXME: it's legal for the user to export and immediately import a handle,
* is this an issue? */
local_memh = *memh;
/* memh should have been used in exchanges or from a remote process, addresses, etc likely garbage. fix it */
local_memh->tl_h = (ucc_mem_map_tl_t *)ucc_calloc(
ctx->n_tl_ctx, sizeof(ucc_mem_map_tl_t), "tl memh");
for (i = 0; i < ctx->n_tl_ctx; i++) {
tl_lib = ucc_derived_of(ctx->tl_ctx[i]->super.lib, ucc_tl_lib_t);
/* FIXME: i don't think this will work properly for more than 1 TL */
status = tl_lib->iface->context.mem_map(
(const ucc_base_context_t *)ctx->tl_ctx[i], type,
params->segments[0].address, params->segments[0].len, local_memh,
Expand All @@ -1185,6 +1185,7 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context,
ucc_mem_map_mem_h *memh)
{
ucc_context_t *ctx = (ucc_context_t *)context;
ucc_config_names_array_t *tls = &ctx->all_tls;
size_t total_pack_size = 0;
ucc_mem_map_memh_t *local_memh;
ucc_mem_map_memh_t *exported_memh;
Expand All @@ -1193,7 +1194,6 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context,
ucc_tl_lib_t *tl_lib;
size_t offset;
int i;
ucc_config_names_array_t *tls = &ctx->all_tls;

local_memh = (ucc_mem_map_memh_t *)ucc_calloc(1, sizeof(ucc_mem_map_memh_t),
"local memh");
Expand Down Expand Up @@ -1286,8 +1286,7 @@ ucc_status_t ucc_mem_map_export(ucc_context_h context,
exported_memh->my_ctx_rank = ctx->rank;
exported_memh->num_tls = ctx->n_tl_ctx;
*memh = exported_memh;
*memh_size = sizeof(ucc_mem_map_memh_t) + total_pack_size +
2 * sizeof(size_t) * ctx->n_tl_ctx;
*memh_size = sizeof(ucc_mem_map_memh_t) + offset;// total_pack_size + 2 * sizeof(size_t) * ctx->n_tl_ctx;
return UCC_OK;
failed_pack:
failed_mem_map:
Expand Down Expand Up @@ -1390,13 +1389,13 @@ ucc_status_t ucc_mem_map_dpu_export(ucc_context_h context,
}
exported_memh->type = UCC_MEM_MAP_TYPE_DPU_EXPORT;
exported_memh->context = ctx;
exported_memh->address = local_memh->address; //params->segments[0].address;
exported_memh->len = local_memh->len; //params->segments[0].len;
exported_memh->address = local_memh->address;
exported_memh->len = local_memh->len;
exported_memh->my_ctx_rank = ctx->rank;
exported_memh->num_tls = ctx->n_tl_ctx;
*memh = exported_memh;
*memh_size = sizeof(ucc_mem_map_memh_t) + total_pack_size +
2 * sizeof(size_t) * ctx->n_tl_ctx;
*memh_size = sizeof(ucc_mem_map_memh_t) + offset;//total_pack_size + 2 * sizeof(size_t) * ctx->n_tl_ctx;
ucc_free(local_memh);
return UCC_OK;
failed_pack:
failed_mem_map:
Expand Down

0 comments on commit 4f039a2

Please sign in to comment.