Commit b60f69e

Add support for AMD GPU (#6161)
Fixes # .

### Description

Add support for AMD GPUs:

~1. Add docker file "Dockerfile.amd" for creating a docker image that supports AMD GPUs. The file is based on "Dockerfile" for NVIDIA GPUs.~
2. In monai/_extensions/gmm/gmm_cuda.cu, replace __shfl_down_sync and __shfl_xor_sync with __shfl_down and __shfl_xor when building for AMD GPUs.

### Types of changes

- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Integration tests passed locally by running `./runtests.sh -f -u --net --coverage`.
- [ ] Quick tests passed locally by running `./runtests.sh --quick --unittests --disttests`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated, tested `make html` command in the `docs/` folder.

---------

Signed-off-by: Yaoming Mu <[email protected]>
Signed-off-by: monai-bot <[email protected]>
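For reviewers unfamiliar with the HIP/CUDA split, here is a minimal sketch of the portability pattern this change introduces. The macro block is taken from the diff below; the `warp_sum` kernel is a hypothetical usage example, not code from this PR, and assumes a 32-lane warp.

```cuda
// From the diff: HIP for AMD GPUs exposes __shfl_down/__shfl_xor without the
// explicit participation mask that the CUDA *_sync intrinsics require.
#ifdef __HIP_PLATFORM_AMD__
#define __SHFL_DOWN(a, b) __shfl_down(a, b)
#define __SHFL_XOR(a, b) __shfl_xor(a, b)
#else
#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
#endif

// Hypothetical example (not part of this PR): tree reduction over one
// 32-lane warp. After the five shuffles, lane 0 holds the sum of all inputs.
__global__ void warp_sum(const float* in, float* out) {
    float val = in[threadIdx.x];
    val += __SHFL_DOWN(val, 16);
    val += __SHFL_DOWN(val, 8);
    val += __SHFL_DOWN(val, 4);
    val += __SHFL_DOWN(val, 2);
    val += __SHFL_DOWN(val, 1);
    if (threadIdx.x == 0) {
        *out = val;
    }
}
```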
1 parent ef2bd45 · commit b60f69e

1 file changed

monai/_extensions/gmm/gmm_cuda.cu

Lines changed: 46 additions & 48 deletions
```diff
@@ -21,6 +21,13 @@ limitations under the License.
 #define EPSILON 1e-5
 #define BLOCK_SIZE 32
 #define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)
+#ifdef __HIP_PLATFORM_AMD__
+#define __SHFL_DOWN(a, b) __shfl_down(a, b)
+#define __SHFL_XOR(a, b) __shfl_xor(a, b)
+#else
+#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
+#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
+#endif
 
 template <int warp_count, int load_count>
 __global__ void CovarianceReductionKernel(
```
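A brief note on why the macro is needed at all: since CUDA 9, the `*_sync` shuffle intrinsics take an explicit lane-participation mask (here `0xffffffff`, i.e. all 32 lanes), while HIP's `__shfl_down`/`__shfl_xor` take none. As a sketch, this hypothetical helper shows what one `__SHFL_DOWN` call expands to on each platform:

```cuda
// Sketch: what __SHFL_DOWN(v, 16) expands to on each platform.
__device__ float shfl_down_16(float v) {
#ifdef __HIP_PLATFORM_AMD__
    return __shfl_down(v, 16);                  // HIP: no mask parameter
#else
    return __shfl_down_sync(0xffffffff, v, 16); // CUDA 9+: full-warp mask
#endif
}
```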
```diff
@@ -82,13 +89,11 @@ __global__ void CovarianceReductionKernel(
 
     for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {
        float matrix_component = matrix[i];
-
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+        matrix_component += __SHFL_DOWN(matrix_component, 16);
+        matrix_component += __SHFL_DOWN(matrix_component, 8);
+        matrix_component += __SHFL_DOWN(matrix_component, 4);
+        matrix_component += __SHFL_DOWN(matrix_component, 2);
+        matrix_component += __SHFL_DOWN(matrix_component, 1);
         if (lane_index == 0) {
             s_matrix_component[warp_index] = matrix_component;
         }
```
```diff
@@ -97,23 +102,21 @@ __global__ void CovarianceReductionKernel(
 
         if (warp_index == 0) {
             matrix_component = s_matrix_component[lane_index];
-
             if (warp_count >= 32) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+                matrix_component += __SHFL_DOWN(matrix_component, 16);
             }
             if (warp_count >= 16) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+                matrix_component += __SHFL_DOWN(matrix_component, 8);
             }
             if (warp_count >= 8) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+                matrix_component += __SHFL_DOWN(matrix_component, 4);
             }
             if (warp_count >= 4) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+                matrix_component += __SHFL_DOWN(matrix_component, 2);
             }
             if (warp_count >= 2) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+                matrix_component += __SHFL_DOWN(matrix_component, 1);
             }
-
             if (lane_index == 0) {
                 g_batch_matrices[matrix_offset + i] = matrix_component;
             }
```
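The two hunks above are the standard two-stage block reduction: each warp reduces its values with down-shuffles, lane 0 of each warp parks its partial sum in shared memory, and warp 0 then reduces the per-warp partials. A condensed, hypothetical sketch of that shape (names like `block_sum` and `s_partial` are illustrative, not from the file; it assumes the `__SHFL_DOWN` macro defined above and `warp_count <= 32`):

```cuda
// Hypothetical sketch of the warp-then-block reduction pattern used by
// CovarianceReductionKernel.
template <int warp_count>
__device__ float block_sum(float v, int lane_index, int warp_index) {
    __shared__ float s_partial[warp_count];

    // Stage 1: reduce within each 32-lane warp; lane 0 holds the warp sum.
    for (int offset = 16; offset > 0; offset >>= 1) {
        v += __SHFL_DOWN(v, offset);
    }
    if (lane_index == 0) {
        s_partial[warp_index] = v;
    }
    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp partials; the compile-time
    // warp_count guard lets the compiler drop shuffle steps that cannot
    // contribute, mirroring the if (warp_count >= N) ladder in the diff.
    if (warp_index == 0) {
        v = lane_index < warp_count ? s_partial[lane_index] : 0.0f;
        for (int offset = 16; offset > 0; offset >>= 1) {
            if (warp_count >= 2 * offset) {
                v += __SHFL_DOWN(v, offset);
            }
        }
    }
    return v; // meaningful in warp 0, lane 0
}
```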
```diff
@@ -156,13 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
             matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
         }
     }
-
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+    matrix_component += __SHFL_DOWN(matrix_component, 16);
+    matrix_component += __SHFL_DOWN(matrix_component, 8);
+    matrix_component += __SHFL_DOWN(matrix_component, 4);
+    matrix_component += __SHFL_DOWN(matrix_component, 2);
+    matrix_component += __SHFL_DOWN(matrix_component, 1);
     if (lane_index == 0) {
         s_matrix_component[warp_index] = matrix_component;
     }
```
```diff
@@ -171,23 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
 
     if (warp_index == 0) {
         matrix_component = s_matrix_component[lane_index];
-
         if (warp_count >= 32) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+            matrix_component += __SHFL_DOWN(matrix_component, 16);
         }
         if (warp_count >= 16) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+            matrix_component += __SHFL_DOWN(matrix_component, 8);
         }
         if (warp_count >= 8) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+            matrix_component += __SHFL_DOWN(matrix_component, 4);
         }
         if (warp_count >= 4) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+            matrix_component += __SHFL_DOWN(matrix_component, 2);
         }
         if (warp_count >= 2) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+            matrix_component += __SHFL_DOWN(matrix_component, 1);
         }
-
        if (lane_index == 0) {
            float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];
 
```
```diff
@@ -261,13 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
     }
 
     float max_value = eigenvalue;
-
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
-
+    max_value = max(max_value, __SHFL_XOR(max_value, 16));
+    max_value = max(max_value, __SHFL_XOR(max_value, 8));
+    max_value = max(max_value, __SHFL_XOR(max_value, 4));
+    max_value = max(max_value, __SHFL_XOR(max_value, 2));
+    max_value = max(max_value, __SHFL_XOR(max_value, 1));
     if (max_value == eigenvalue) {
         GMMSplit_t split;
 
```
```diff
@@ -347,12 +344,11 @@ __global__ void GMMcommonTerm(float* g_gmm) {
     float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;
 
     float sum = gmm_n;
-
-    sum += __shfl_xor_sync(0xffffffff, sum, 1);
-    sum += __shfl_xor_sync(0xffffffff, sum, 2);
-    sum += __shfl_xor_sync(0xffffffff, sum, 4);
-    sum += __shfl_xor_sync(0xffffffff, sum, 8);
-    sum += __shfl_xor_sync(0xffffffff, sum, 16);
+    sum += __SHFL_XOR(sum, 1);
+    sum += __SHFL_XOR(sum, 2);
+    sum += __SHFL_XOR(sum, 4);
+    sum += __SHFL_XOR(sum, 8);
+    sum += __SHFL_XOR(sum, 16);
 
     if (threadIdx.x < MIXTURE_SIZE) {
         float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
```
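`GMMFindSplit` and `GMMcommonTerm` use the XOR (butterfly) shuffle rather than the down-shuffle: after the five steps, every lane holds the combined result, so each of the `MIXTURE_SIZE` threads can consume it directly without a broadcast from lane 0. A hypothetical sketch of that property, assuming the `__SHFL_XOR` macro from the diff:

```cuda
// Hypothetical example: butterfly all-reduce across one 32-lane warp.
// Unlike the down-shuffle tree, every lane ends up with the full sum.
__device__ float warp_all_sum(float v) {
    for (int mask = 1; mask < 32; mask <<= 1) {
        v += __SHFL_XOR(v, mask);
    }
    return v; // identical value in all 32 lanes
}
```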
```diff
@@ -446,13 +442,14 @@ void GMMInitialize(
     for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
         for (unsigned int i = 0; i < k; ++i) {
             CovarianceReductionKernel<WARPS, LOAD>
-                <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+                <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
         }
 
-        CovarianceFinalizationKernel<WARPS, false><<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+        CovarianceFinalizationKernel<WARPS, false><<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-        GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
-        GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(
+        GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(
+            gmm_split_scratch, k / MIXTURE_COUNT, gmm);
+        GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(
             gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
     }
 }
```
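The launch-configuration edits in this hunk and the one below appear to be the other half of the portability fix: the old code passed a braced initializer list `{x, 1, batch_count}` inside the triple-chevron launch syntax and relied on an implicit conversion to `dim3`, which the HIP toolchain does not accept (this rationale is inferred from the diff, not stated in the commit message). Constructing the `dim3` explicitly is portable to both compilers. A minimal sketch, using a hypothetical kernel `my_kernel`:

```cuda
// Hypothetical kernel used only to illustrate the launch syntax.
__global__ void my_kernel(int n) {}

void launch(unsigned int block_count, unsigned int batch_count) {
    // Portable: build the grid dimensions as an explicit dim3.
    my_kernel<<<dim3(block_count, 1, batch_count), 256>>>(0);

    // Not portable to hipcc: braced initializer list in the chevrons.
    // my_kernel<<<{block_count, 1, batch_count}, 256>>>(0);
}
```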
```diff
@@ -472,12 +469,13 @@ void GMMUpdate(
 
     for (unsigned int i = 0; i < gmm_N; ++i) {
         CovarianceReductionKernel<WARPS, LOAD>
-            <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+            <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
     }
 
-    CovarianceFinalizationKernel<WARPS, true><<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+    CovarianceFinalizationKernel<WARPS, true>
+        <<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-    GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
+    GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
 }
 
 void GMMDataTerm(
```
