-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcoll_vhx_bcast.c
93 lines (63 loc) · 3.43 KB
/
coll_vhx_bcast.c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#include "ompi_config.h"
#include "mpi.h"
#include "ompi/constants.h"
#include "ompi/datatype/ompi_datatype.h"
#include "ompi/communicator/communicator.h"
#include "opal/util/show_help.h"
#include "opal/util/minmax.h"
#include "coll_vhx.h"
//the root becomes the leader in all hierarchy groups it participates
int set_root_leader(int root, mca_coll_base_module_t * module, ompi_communicator_t * ompi_comm, int do_cico) {
int rank = ompi_comm_rank(ompi_comm);
int comm_size = ompi_comm_size(ompi_comm);
vhx_module_t * vhx_module = (vhx_module_t * ) module;
int hier_size = vhx_module -> hierarchy_size;
for (int i = 0; i < hier_size; i++) {
vhx_hier_group_t * hier_group = & (vhx_module -> hier_groups[i]);
if (hier_group -> members_bitmap[root] == 1 && hier_group -> real_members_bitmap[root] == 0) {
hier_group -> real_members_bitmap[root] = 1;
hier_group -> leader = root;
}
}
}
int mca_coll_vhx_bcast(void * buf, int count, ompi_datatype_t * datatype, int root,
ompi_communicator_t * ompi_comm, mca_coll_base_module_t * module) {
int rank = ompi_comm_rank(ompi_comm);
int comm_size = ompi_comm_size(ompi_comm);
vhx_module_t * vhx_module = (vhx_module_t * ) module;
if (((mca_coll_vhx_module_t * ) module) -> initialized == false) {
int ret = vhx_init(module, ompi_comm);
if (ret != 0) return OMPI_ERROR;
}
int hier_size = vhx_module -> hierarchy_size;
int pvt_seq = (vhx_module -> hier_groups[0].shared_ctrl_vars[rank].coll_seq) + 1;
size_t dtype_size, bytes_total;
ompi_datatype_type_size(datatype, & dtype_size);
bytes_total = count * dtype_size;
bool do_cico = (bytes_total <= OMPI_vhx_CICO_MAX);
do_cico = 1; //TODO, XPMEM support temporarily broken due to recent "refactoring" after discovering erroneous execution, will get fixed soon
set_root_leader(root, vhx_module, ompi_comm, do_cico);
if (do_cico && root == rank)
memcpy((char * )(vhx_module -> cico_buffer), buf, bytes_total);
for (int i = hier_size - 1; i >= 0; i--) {
vhx_hier_group_t * hier_group = & (vhx_module -> hier_groups[i]);
(hier_group -> shared_ctrl_vars[rank].coll_seq) = pvt_seq;
if (rank == hier_group -> leader) {
opal_atomic_wmb();
vhx_module -> hier_groups[i].shared_ctrl_vars[rank].coll_ack = pvt_seq; // we need the root's ack counter to be equal to pvt for consistency
if (i == 0 && rank != root)
memcpy(buf, (do_cico)?(char * )(vhx_module -> cico_buffer):hier_group->shared_ctrl_vars[hier_group->leader].sbuf_vaddr, bytes_total);
for (int i = 0; i < hier_group -> real_size; i++) //Waiting for the acks of the hierarchy group
WAIT_FLAG( & (hier_group -> shared_ctrl_vars[hier_group -> real_members[i]].coll_ack), pvt_seq);
} else if (hier_group -> real_members_bitmap[rank] == 1) {
WAIT_FLAG( & (vhx_module -> hier_groups[i].shared_ctrl_vars[hier_group -> leader].coll_seq), pvt_seq);
if (do_cico)
memcpy((i == 0) ? buf : (char * )(vhx_module -> cico_buffer), (char * )(vhx_module -> neighbour_cico_buffers[hier_group -> leader]), bytes_total); // on the bottom we need to write to buf even in cico scenarios
else
memcpy( buf, (char * )(hier_group->shared_ctrl_vars[hier_group->leader].sbuf_vaddr), bytes_total);
opal_atomic_wmb();
hier_group -> shared_ctrl_vars[rank].coll_ack = pvt_seq;
}
}
return OMPI_SUCCESS;
}