@@ -21,6 +21,13 @@ limitations under the License.
 #define EPSILON 1e-5
 #define BLOCK_SIZE 32
 #define TILE(SIZE, STRIDE) ((((SIZE)-1) / (STRIDE)) + 1)
+#ifdef __HIP_PLATFORM_AMD__
+#define __SHFL_DOWN(a, b) __shfl_down(a, b)
+#define __SHFL_XOR(a, b) __shfl_xor(a, b)
+#else
+#define __SHFL_DOWN(a, b) __shfl_down_sync(0xffffffff, a, b)
+#define __SHFL_XOR(a, b) __shfl_xor_sync(0xffffffff, a, b)
+#endif
 
 template <int warp_count, int load_count>
 __global__ void CovarianceReductionKernel(
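Note: on CUDA these macros expand to the *_sync intrinsics with a full-warp mask, while on AMD HIP they fall back to the maskless legacy shuffles. A minimal sketch of the macro in use (the kernel below is illustrative, not part of the PR, and assumes 32 active lanes per warp):

__global__ void WarpSumDemo(const float* in, float* out) {
    float v = in[threadIdx.x];
    // Tree reduction across one 32-lane warp; lane 0 ends up with the sum.
    // Expands to __shfl_down_sync on CUDA and __shfl_down on AMD HIP.
    for (int delta = 16; delta > 0; delta >>= 1) {
        v += __SHFL_DOWN(v, delta);
    }
    if (threadIdx.x == 0) {
        *out = v;
    }
}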
@@ -82,13 +89,11 @@ __global__ void CovarianceReductionKernel(
 
     for (int i = 0; i < MATRIX_COMPONENT_COUNT; i++) {
         float matrix_component = matrix[i];
-
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-        matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+        matrix_component += __SHFL_DOWN(matrix_component, 16);
+        matrix_component += __SHFL_DOWN(matrix_component, 8);
+        matrix_component += __SHFL_DOWN(matrix_component, 4);
+        matrix_component += __SHFL_DOWN(matrix_component, 2);
+        matrix_component += __SHFL_DOWN(matrix_component, 1);
         if (lane_index == 0) {
             s_matrix_component[warp_index] = matrix_component;
         }
@@ -97,23 +102,21 @@ __global__ void CovarianceReductionKernel(
 
         if (warp_index == 0) {
             matrix_component = s_matrix_component[lane_index];
-
             if (warp_count >= 32) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+                matrix_component += __SHFL_DOWN(matrix_component, 16);
             }
             if (warp_count >= 16) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+                matrix_component += __SHFL_DOWN(matrix_component, 8);
             }
             if (warp_count >= 8) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+                matrix_component += __SHFL_DOWN(matrix_component, 4);
             }
             if (warp_count >= 4) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+                matrix_component += __SHFL_DOWN(matrix_component, 2);
             }
             if (warp_count >= 2) {
-                matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+                matrix_component += __SHFL_DOWN(matrix_component, 1);
             }
-
             if (lane_index == 0) {
                 g_batch_matrices[matrix_offset + i] = matrix_component;
             }
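Taken together, the two hunks above are the usual two-phase block reduction: each warp folds its values into lane 0 with shuffles, lane 0 parks the warp's partial in shared memory, and warp 0 then folds the partials. Because warp_count is a template parameter, the warp_count >= N guards are resolved at compile time and unused shuffle steps are compiled out. A condensed sketch of the pattern, with illustrative names (s_partial would be a __shared__ array of warp_count floats):

template <int warp_count>
__device__ float BlockSum(float v, float* s_partial, int warp_index, int lane_index) {
    // Phase 1: every warp reduces its 32 lanes into lane 0.
    for (int delta = 16; delta > 0; delta >>= 1) {
        v += __SHFL_DOWN(v, delta);
    }
    if (lane_index == 0) {
        s_partial[warp_index] = v;
    }
    __syncthreads();
    // Phase 2: warp 0 reduces the per-warp partials.
    if (warp_index == 0) {
        v = lane_index < warp_count ? s_partial[lane_index] : 0.0f;
        for (int delta = warp_count / 2; delta > 0; delta >>= 1) {
            v += __SHFL_DOWN(v, delta);
        }
    }
    return v; // meaningful only on warp 0, lane 0
}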
@@ -156,13 +159,11 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
             matrix_component += g_batch_matrices[(matrix_offset + matrix_index) * GMM_COMPONENT_COUNT + index];
         }
     }
-
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
-    matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
-
+    matrix_component += __SHFL_DOWN(matrix_component, 16);
+    matrix_component += __SHFL_DOWN(matrix_component, 8);
+    matrix_component += __SHFL_DOWN(matrix_component, 4);
+    matrix_component += __SHFL_DOWN(matrix_component, 2);
+    matrix_component += __SHFL_DOWN(matrix_component, 1);
     if (lane_index == 0) {
         s_matrix_component[warp_index] = matrix_component;
     }
@@ -171,23 +172,21 @@ __global__ void CovarianceFinalizationKernel(const float* g_matrices, float* g_g
 
     if (warp_index == 0) {
         matrix_component = s_matrix_component[lane_index];
-
        if (warp_count >= 32) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 16);
+            matrix_component += __SHFL_DOWN(matrix_component, 16);
        }
        if (warp_count >= 16) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 8);
+            matrix_component += __SHFL_DOWN(matrix_component, 8);
        }
        if (warp_count >= 8) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 4);
+            matrix_component += __SHFL_DOWN(matrix_component, 4);
        }
        if (warp_count >= 4) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 2);
+            matrix_component += __SHFL_DOWN(matrix_component, 2);
        }
        if (warp_count >= 2) {
-            matrix_component += __shfl_down_sync(0xffffffff, matrix_component, 1);
+            matrix_component += __SHFL_DOWN(matrix_component, 1);
        }
-
        if (lane_index == 0) {
            float constant = i == 0 ? 0.0f : s_gmm[i] * s_gmm[j];
@@ -261,13 +260,11 @@ __global__ void GMMFindSplit(GMMSplit_t* gmmSplit, int gmmK, float* gmm) {
     }
 
     float max_value = eigenvalue;
-
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 16));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 8));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 4));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 2));
-    max_value = max(max_value, __shfl_xor_sync(0xffffffff, max_value, 1));
-
+    max_value = max(max_value, __SHFL_XOR(max_value, 16));
+    max_value = max(max_value, __SHFL_XOR(max_value, 8));
+    max_value = max(max_value, __SHFL_XOR(max_value, 4));
+    max_value = max(max_value, __SHFL_XOR(max_value, 2));
+    max_value = max(max_value, __SHFL_XOR(max_value, 1));
     if (max_value == eigenvalue) {
         GMMSplit_t split;
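The XOR butterfly matters here: after the five XOR steps every lane holds the warp-wide maximum (not just lane 0), so each lane can test max_value == eigenvalue locally to decide whether it found the winning eigenvalue. The same all-lanes property is relied on for the sum in GMMcommonTerm below. A sketch of the idea, with illustrative names:

__device__ bool IsWarpWinner(float my_score) {
    float best = my_score;
    // XOR butterfly: lanes exchange values pairwise at each step, so all
    // 32 lanes converge on the same maximum.
    for (int mask = 16; mask > 0; mask >>= 1) {
        best = max(best, __SHFL_XOR(best, mask));
    }
    return best == my_score; // true on the lane(s) holding the maximum
}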
@@ -347,12 +344,11 @@ __global__ void GMMcommonTerm(float* g_gmm) {
     float gmm_n = threadIdx.x < MIXTURE_SIZE ? g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT] : 0.0f;
 
     float sum = gmm_n;
-
-    sum += __shfl_xor_sync(0xffffffff, sum, 1);
-    sum += __shfl_xor_sync(0xffffffff, sum, 2);
-    sum += __shfl_xor_sync(0xffffffff, sum, 4);
-    sum += __shfl_xor_sync(0xffffffff, sum, 8);
-    sum += __shfl_xor_sync(0xffffffff, sum, 16);
+    sum += __SHFL_XOR(sum, 1);
+    sum += __SHFL_XOR(sum, 2);
+    sum += __SHFL_XOR(sum, 4);
+    sum += __SHFL_XOR(sum, 8);
+    sum += __SHFL_XOR(sum, 16);
 
     if (threadIdx.x < MIXTURE_SIZE) {
         float det = g_batch_gmm[gmm_index * GMM_COMPONENT_COUNT + MATRIX_COMPONENT_COUNT] + EPSILON;
@@ -446,13 +442,14 @@ void GMMInitialize(
     for (unsigned int k = MIXTURE_COUNT; k < gmm_N; k += MIXTURE_COUNT) {
         for (unsigned int i = 0; i < k; ++i) {
             CovarianceReductionKernel<WARPS, LOAD>
-                <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+                <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
         }
 
-        CovarianceFinalizationKernel<WARPS, false><<<{k, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+        CovarianceFinalizationKernel<WARPS, false><<<dim3(k, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-        GMMFindSplit<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm_split_scratch, k / MIXTURE_COUNT, gmm);
-        GMMDoSplit<<<{TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count}, BLOCK_SIZE>>>(
+        GMMFindSplit<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(
+            gmm_split_scratch, k / MIXTURE_COUNT, gmm);
+        GMMDoSplit<<<dim3(TILE(element_count, BLOCK_SIZE * DO_SPLIT_DEGENERACY), 1, batch_count), BLOCK_SIZE>>>(
             gmm_split_scratch, (k / MIXTURE_COUNT) << 4, image, alpha, element_count);
     }
 }
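Our reading of this hunk (not stated in the PR): a braced initializer such as {block_count, 1, batch_count} is accepted as a launch-configuration argument by nvcc but not by the HIP toolchain, which rewrites the triple-chevron launch and expects each configuration argument to be an ordinary expression, such as an explicit dim3(...) constructor call. On plain CUDA the change is behavior-neutral. A minimal sketch of the portable form (kernel and names illustrative):

__global__ void FillBatch(float* p) { p[blockIdx.z] = 1.0f; }

void LaunchFillBatch(float* d_p, unsigned int block_count, unsigned int batch_count) {
    // dim3(x, y, z) parses as one expression under both nvcc and hipcc,
    // whereas a bare braced list in the launch config does not port cleanly.
    FillBatch<<<dim3(block_count, 1, batch_count), BLOCK_SIZE>>>(d_p);
}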
@@ -472,12 +469,13 @@ void GMMUpdate(
 
     for (unsigned int i = 0; i < gmm_N; ++i) {
         CovarianceReductionKernel<WARPS, LOAD>
-            <<<{block_count, 1, batch_count}, BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
+            <<<dim3(block_count, 1, batch_count), BLOCK>>>(i, image, alpha, block_gmm_scratch, element_count);
     }
 
-    CovarianceFinalizationKernel<WARPS, true><<<{gmm_N, 1, batch_count}, BLOCK>>>(block_gmm_scratch, gmm, block_count);
+    CovarianceFinalizationKernel<WARPS, true>
+        <<<dim3(gmm_N, 1, batch_count), BLOCK>>>(block_gmm_scratch, gmm, block_count);
 
-    GMMcommonTerm<<<{1, 1, batch_count}, dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
+    GMMcommonTerm<<<dim3(1, 1, batch_count), dim3(BLOCK_SIZE, MIXTURE_COUNT)>>>(gmm);
 }
 
 void GMMDataTerm(