Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

TL/MLX5: knomial fan-in in a2a #1064

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/components/tl/mlx5/alltoall/alltoall.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ typedef struct ucc_tl_mlx5_alltoall_node {
struct mlx5dv_mkey *team_recv_mkey;
void *umr_entries_buf;
struct ibv_mr *umr_entries_mr;
int fanin_index;
int fanin_dist;
int fanin_max_dist;
} ucc_tl_mlx5_alltoall_node_t;

typedef struct alltoall_net_ctrl {
Expand Down
90 changes: 63 additions & 27 deletions src/components/tl/mlx5/alltoall/alltoall_coll.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,40 +91,68 @@ static ucc_status_t ucc_tl_mlx5_poll_cq(struct ibv_cq *cq, ucc_base_lib_t *lib)
static ucc_status_t ucc_tl_mlx5_node_fanin(ucc_tl_mlx5_team_t *team,
ucc_tl_mlx5_schedule_t *task)
{
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int seq_index = task->alltoall.seq_index;
int i;
ucc_tl_mlx5_alltoall_t *a2a = team->a2a;
int seq_index = task->alltoall.seq_index;
int npolls = UCC_TL_MLX5_TEAM_CTX(team)->cfg.npolls;
int radix = UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix;
int vrank = a2a->node.sbgp->group_rank - a2a->node.asr_rank;
int *dist = &a2a->node.fanin_dist;
int size = a2a->node.sbgp->group_size;
int seq_num = task->alltoall.seq_num;
int c_flag = 0;
int polls, peer, vpeer, pos, i;
ucc_tl_mlx5_alltoall_ctrl_t *ctrl_v;

if (a2a->node.sbgp->group_rank != a2a->node.asr_rank) {
ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num =
task->alltoall.seq_num;
} else {
for (i = 0; i < a2a->node.sbgp->group_size; i++) {
if (i == a2a->node.sbgp->group_rank) {
continue;
}
ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i);
if (ctrl_v->seq_num != task->alltoall.seq_num) {
return UCC_INPROGRESS;
}
}
for (i = 0; i < a2a->node.sbgp->group_size; i++) {
if (i == a2a->node.sbgp->group_rank) {
continue;
while (*dist <= a2a->node.fanin_max_dist) {
if (vrank % *dist == 0) {
pos = (vrank / *dist) % radix;
if (pos == 0) {
while (a2a->node.fanin_index < radix) {
vpeer = vrank + a2a->node.fanin_index * *dist;
if (vpeer >= size) {
a2a->node.fanin_index = radix;
break;
}
peer = (vpeer + a2a->node.asr_rank) % size;
ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, peer);
for (polls = 0; polls < npolls; polls++) {
if (ctrl_v->seq_num == seq_num) {
a2a->node.fanin_index++;
break;
}
}
if (polls == npolls) {
return UCC_INPROGRESS;
}
}
} else {
ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->seq_num = seq_num;
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(
task, "mlx5_alltoall_fanin_done", 0);
return UCC_OK;
}
ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i);
ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag |=
ctrl_v->mkey_cache_flag;
}
*dist *= radix;
a2a->node.fanin_index = 1;
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_step_done",
0);
}
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0);
for (i = 0; i < size; i++) {
ctrl_v = ucc_tl_mlx5_get_ctrl(a2a, seq_index, i);
ucc_assert(i == a2a->node.sbgp->group_rank ||
ctrl_v->seq_num == seq_num);
c_flag |= ctrl_v->mkey_cache_flag;
}
ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = c_flag;
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(
task, "mlx5_alltoall_retrieve_cache_flags_done", 0);
return UCC_OK;
}

/* Each rank registers sbuf and rbuf and place the registration data
in the shared memory location. Next, all rank in node nitify the
ASR the registration data is ready using SHM Fanin */
in the shared memory location. Next, all rank in node notify the
ASR that the registration data is ready using SHM Fanin */
static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task)
{
ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task);
Expand All @@ -137,7 +165,7 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task)
ucc_tl_mlx5_rcache_region_t *send_ptr;
ucc_tl_mlx5_rcache_region_t *recv_ptr;

UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_start", 0);
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_reg_start", 0);
tl_debug(UCC_TASK_LIB(task), "register memory buffers");
coll_task->status = UCC_INPROGRESS;
coll_task->super.status = UCC_INPROGRESS;
Expand Down Expand Up @@ -187,11 +215,20 @@ static ucc_status_t ucc_tl_mlx5_reg_fanin_start(ucc_coll_task_t *coll_task)
/* Start fanin */
ucc_tl_mlx5_get_my_ctrl(a2a, seq_index)->mkey_cache_flag = flag;
ucc_tl_mlx5_update_mkeys_entries(a2a, task, flag);
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_reg_done", 0);
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_start", 0);

a2a->node.fanin_index = 1;
a2a->node.fanin_dist = 1;
for (a2a->node.fanin_max_dist = 1;
a2a->node.fanin_max_dist < a2a->node.sbgp->group_size;
a2a->node.fanin_max_dist *=
UCC_TL_MLX5_TEAM_LIB(team)->cfg.fanin_kn_radix) {
}

if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) {
tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete");
coll_task->status = UCC_OK;
UCC_TL_MLX5_PROFILE_REQUEST_EVENT(task, "mlx5_alltoall_fanin_done", 0);
return ucc_task_complete(coll_task);
}

Expand All @@ -204,7 +241,6 @@ void ucc_tl_mlx5_reg_fanin_progress(ucc_coll_task_t *coll_task)
ucc_tl_mlx5_schedule_t *task = TASK_SCHEDULE(coll_task);
ucc_tl_mlx5_team_t *team = SCHEDULE_TEAM(task);

ucc_assert(team->a2a->node.sbgp->group_rank == team->a2a->node.asr_rank);
if (UCC_OK == ucc_tl_mlx5_node_fanin(team, task)) {
tl_debug(UCC_TL_MLX5_TEAM_LIB(team), "fanin complete");
coll_task->status = UCC_OK;
Expand Down
9 changes: 9 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.c
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,10 @@ static ucc_config_field_t ucc_tl_mlx5_lib_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_lib_config_t, mcast_conf.one_sided_reliability_enable),
UCC_CONFIG_TYPE_BOOL},

{"FANIN_KN_RADIX", "4", "Radix of the knomial tree fanin algorithm",
ucc_offsetof(ucc_tl_mlx5_lib_config_t, fanin_kn_radix),
UCC_CONFIG_TYPE_UINT},

{NULL}};

static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
Expand All @@ -126,6 +130,11 @@ static ucc_config_field_t ucc_tl_mlx5_context_config_table[] = {
ucc_offsetof(ucc_tl_mlx5_context_config_t, mcast_ctx_conf.ib_dev_name),
UCC_CONFIG_TYPE_STRING},

{"FANIN_NPOLLS", "1000",
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

move this to knomial fan-in

"Number of shared memory polling before returning UCC_INPROGRESS during "
"internode FANIN",
ucc_offsetof(ucc_tl_mlx5_context_config_t, npolls), UCC_CONFIG_TYPE_UINT},

{NULL}};

UCC_CLASS_DEFINE_NEW_FUNC(ucc_tl_mlx5_lib_t, ucc_base_lib_t,
Expand Down
2 changes: 2 additions & 0 deletions src/components/tl/mlx5/tl_mlx5.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ typedef struct ucc_tl_mlx5_lib_config {
int dm_host;
ucc_tl_mlx5_ib_qp_conf_t qp_conf;
ucc_tl_mlx5_mcast_coll_comm_init_spec_t mcast_conf;
int fanin_kn_radix;
} ucc_tl_mlx5_lib_config_t;

/*
 * Context-level configuration for the MLX5 TL. The fields are filled from
 * ucc_tl_mlx5_context_config_table, which addresses them with ucc_offsetof();
 * field names and types therefore have to stay in sync with that table.
 */
typedef struct ucc_tl_mlx5_context_config {
/* Embedded base config; presumably the common TL context settings this
 * struct extends - TODO confirm against ucc_tl_context_config_t. */
ucc_tl_context_config_t super;
/* Array of names; presumably the device list selected by the user -
 * the table entry that fills it is not visible here, verify. */
ucs_config_names_array_t devices;
/* Multicast context parameters; the config table writes at least its
 * ib_dev_name member as a string option. */
ucc_tl_mlx5_mcast_ctx_params_t mcast_ctx_conf;
/* Set by the "FANIN_NPOLLS" option (default "1000"). The node fan-in reads
 * it as the maximum number of times a peer's control seq_num is checked
 * before the call gives up and returns UCC_INPROGRESS.
 * NOTE(review): declared int but registered as UCC_CONFIG_TYPE_UINT in the
 * config table - confirm the parser's storage type matches. */
int npolls;
} ucc_tl_mlx5_context_config_t;

typedef struct ucc_tl_mlx5_lib {
Expand Down
Loading