Skip to content

Commit c7b7b02

Browse files
authored
[gloo] Enable using c10::Half for gloo cuda
Differential Revision: D75909352 Pull Request resolved: #449
1 parent cc44198 commit c7b7b02

8 files changed

+17
-4
lines changed

gloo/cuda.cu

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,7 @@ static inline int cudaGetBlocks(const int N) {
283283
#define DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(T, Funcname, op) \
284284
__global__ void _Kernel_##T##_##Funcname( \
285285
T* dst, const T* src, const int n) { \
286-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
286+
for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
287287
i += blockDim.x * gridDim.x) { \
288288
dst[i] = dst[i] op src[i]; \
289289
} \
@@ -301,7 +301,7 @@ static inline int cudaGetBlocks(const int N) {
301301
#define DELEGATE_HALF_PRECISION_CUDA_BINARY_OPERATOR(Funcname, op) \
302302
__global__ void _Kernel_half_##Funcname( \
303303
half* dst, const half* src, const int n) { \
304-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
304+
for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
305305
i += blockDim.x * gridDim.x) { \
306306
float r = __half2float(dst[i]) op __half2float(src[i]); \
307307
dst[i] = __float2half(r); \
@@ -337,7 +337,7 @@ DELEGATE_HALF_PRECISION_CUDA_BINARY_OPERATOR(cudaProduct, *);
337337
#define DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(T, Funcname, op) \
338338
__global__ void _Kernel_##T##_##Funcname( \
339339
T* dst, const T* src, const int n) { \
340-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
340+
for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
341341
i += blockDim.x * gridDim.x) { \
342342
if (src[i] op dst[i]) { \
343343
dst[i] = src[i]; \
@@ -357,7 +357,7 @@ DELEGATE_HALF_PRECISION_CUDA_BINARY_OPERATOR(cudaProduct, *);
357357
#define DELEGATE_HALF_PRECISION_CUDA_BINARY_COMPARE(Funcname, op) \
358358
__global__ void _Kernel_half_##Funcname( \
359359
half* dst, const half* src, const int n) { \
360-
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
360+
for (auto i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
361361
i += blockDim.x * gridDim.x) { \
362362
if (__half2float(src[i]) op __half2float(dst[i])) { \
363363
dst[i] = src[i]; \
@@ -398,6 +398,12 @@ DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaSum, +);
398398
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(BFloat16, cudaProduct, *);
399399
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMin, <);
400400
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(BFloat16, cudaMax, >);
401+
using Half = c10::Half;
402+
INSTANTIATE_COPY_ASYNC(Half);
403+
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(Half, cudaSum, +);
404+
DELEGATE_SIMPLE_CUDA_BINARY_OPERATOR(Half, cudaProduct, *);
405+
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(Half, cudaMin, <);
406+
DELEGATE_SIMPLE_CUDA_BINARY_COMPARE(Half, cudaMax, >);
401407
#endif
402408
403409
} // namespace gloo

gloo/cuda_allreduce_bcube.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,7 @@ INSTANTIATE_TEMPLATE(float16);
516516

517517
#if GLOO_USE_TORCH_DTYPES
518518
INSTANTIATE_TEMPLATE(c10::BFloat16);
519+
INSTANTIATE_TEMPLATE(c10::Half);
519520
#endif
520521

521522
} // namespace gloo

gloo/cuda_allreduce_halving_doubling.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -659,6 +659,7 @@ INSTANTIATE_TEMPLATE(float16);
659659

660660
#if GLOO_USE_TORCH_DTYPES
661661
INSTANTIATE_TEMPLATE(c10::BFloat16);
662+
INSTANTIATE_TEMPLATE(c10::Half);
662663
#endif
663664

664665
} // namespace gloo

gloo/cuda_allreduce_local.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ INSTANTIATE_TEMPLATE(float16);
7878

7979
#if GLOO_USE_TORCH_DTYPES
8080
INSTANTIATE_TEMPLATE(c10::BFloat16);
81+
INSTANTIATE_TEMPLATE(c10::Half);
8182
#endif
8283

8384
} // namespace gloo

gloo/cuda_allreduce_ring.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,7 @@ INSTANTIATE_TEMPLATE(float16);
190190

191191
#if GLOO_USE_TORCH_DTYPES
192192
INSTANTIATE_TEMPLATE(c10::BFloat16);
193+
INSTANTIATE_TEMPLATE(c10::Half);
193194
#endif
194195

195196
} // namespace gloo

gloo/cuda_allreduce_ring_chunked.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -367,6 +367,7 @@ INSTANTIATE_TEMPLATE(float16);
367367

368368
#if GLOO_USE_TORCH_DTYPES
369369
INSTANTIATE_TEMPLATE(c10::BFloat16);
370+
INSTANTIATE_TEMPLATE(c10::Half);
370371
#endif
371372

372373
} // namespace gloo

gloo/cuda_broadcast_one_to_all.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ INSTANTIATE_TEMPLATE(float16);
199199

200200
#if GLOO_USE_TORCH_DTYPES
201201
INSTANTIATE_TEMPLATE(c10::BFloat16);
202+
INSTANTIATE_TEMPLATE(c10::Half);
202203
#endif
203204

204205
} // namespace gloo

gloo/cuda_private.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222

2323
#if GLOO_USE_TORCH_DTYPES
2424
#include <c10/util/BFloat16.h>
25+
#include <c10/util/Half.h>
2526
#endif
2627

2728
namespace gloo {

0 commit comments

Comments
 (0)