Skip to content

[SYCL][CUDA][libclc] Add approx. tanhf built-in #5265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 11 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions clang/include/clang/Basic/BuiltinsNVPTX.def
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,12 @@ BUILTIN(__nvvm_sin_approx_f, "ff", "")
BUILTIN(__nvvm_cos_approx_ftz_f, "ff", "")
BUILTIN(__nvvm_cos_approx_f, "ff", "")

// Tanh

TARGET_BUILTIN(__nvvm_tanh_approx_f, "ff", "", AND(SM_75,PTX70))
TARGET_BUILTIN(__nvvm_tanh_approx_f16, "hh", "", AND(SM_75, PTX70))
TARGET_BUILTIN(__nvvm_tanh_approx_f16x2, "V2hV2h", "", AND(SM_75, PTX70))

// Fma

BUILTIN(__nvvm_fma_rn_ftz_f, "ffff", "")
Expand Down
3 changes: 3 additions & 0 deletions clang/include/clang/Basic/TargetOptions.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,9 @@ class TargetOptions {
/// \brief If enabled, use precise square root
bool NVVMCudaPrecSqrt = false;

/// \brief If enabled, use approximate tanh
bool NVVMCudaApproxTanhf = false;

/// \brief If enabled, allow AMDGPU unsafe floating point atomics.
bool AllowAMDGPUUnsafeFPAtomics = false;

Expand Down
5 changes: 4 additions & 1 deletion clang/include/clang/Driver/Options.td
Original file line number Diff line number Diff line change
Expand Up @@ -4726,7 +4726,10 @@ def fno_sycl_device_lib_EQ : CommaJoined<["-"], "fno-sycl-device-lib=">, Group<s
Values<"libc, libm-fp32, libm-fp64, all">, HelpText<"Control exclusion of "
"device libraries from device binary linkage. Valid arguments "
"are libc, libm-fp32, libm-fp64, all">;

defm nvvm_cuda_approx_tanh : BoolFOption<"sycl-cuda-approx-tanh",
TargetOpts<"NVVMCudaApproxTanhf">, DefaultFalse,
PosFlag<SetTrue, [CC1Option], "Use the built-in fast approximation of tanh function for devices having c.c.>=7.5">,
NegFlag<SetFalse>>;
//===----------------------------------------------------------------------===//
// FLangOption + CoreOption + NoXarchOption
//===----------------------------------------------------------------------===//
Expand Down
6 changes: 6 additions & 0 deletions clang/lib/CodeGen/CodeGenModule.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -782,6 +782,12 @@ void CodeGenModule::Release() {
getTarget().getTargetOpts().NVVMCudaPrecSqrt);
}

if (LangOpts.isSYCL() && getTriple().isNVPTX()) {
getModule().addModuleFlag(llvm::Module::Override,
"nvvm-reflect-approx-tanhf",
getTarget().getTargetOpts().NVVMCudaApproxTanhf);
}

if (LangOpts.EHAsynch)
getModule().addModuleFlag(llvm::Module::Warning, "eh-asynch", 1);

Expand Down
11 changes: 11 additions & 0 deletions clang/test/CodeGenCUDA/nvvm-reflect-approx-tanh.cu
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
// RUN: %clang_cc1 -fsycl-is-device -triple nvptx64-nvidia-cuda -emit-llvm -fsycl-cuda-approx-tanh %s -o -| FileCheck --check-prefix=CHECK-ON %s
// RUN: %clang_cc1 -fsycl-is-device -triple nvptx64-nvidia-cuda -emit-llvm %s -o -| FileCheck --check-prefix=CHECK-OFF %s

#include "Inputs/cuda.h"

// Check that the -fsycl-cuda-approx-tanh flag correctly sets the nvvm-reflect module flags.

extern "C" __device__ void foo() {}

// CHECK-ON: !{i32 4, !"nvvm-reflect-approx-tanhf", i32 1}
// CHECK-OFF: !{i32 4, !"nvvm-reflect-approx-tanhf", i32 0}
35 changes: 25 additions & 10 deletions libclc/generic/include/clcmacro.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,7 @@
#ifndef __CLC_MACRO_H
#define __CLC_MACRO_H

#define _CLC_UNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x) { \
return (RET_TYPE##2)(FUNCTION(x.x), FUNCTION(x.y)); \
} \
\
#define _CLC_UNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE) \
DECLSPEC RET_TYPE##3 FUNCTION(ARG1_TYPE##3 x) { \
return (RET_TYPE##3)(FUNCTION(x.x), FUNCTION(x.y), FUNCTION(x.z)); \
} \
Expand All @@ -30,12 +26,14 @@
return (RET_TYPE##16)(FUNCTION(x.lo), FUNCTION(x.hi)); \
}

#define _CLC_BINARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y) { \
return (RET_TYPE##2)(FUNCTION(x.x, y.x), FUNCTION(x.y, y.y)); \
#define _CLC_UNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x) { \
return (RET_TYPE##2)(FUNCTION(x.x), FUNCTION(x.y)); \
} \
\
_CLC_UNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE)

#define _CLC_BINARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE) \
DECLSPEC RET_TYPE##3 FUNCTION(ARG1_TYPE##3 x, ARG2_TYPE##3 y) { \
return (RET_TYPE##3)(FUNCTION(x.x, y.x), FUNCTION(x.y, y.y), \
FUNCTION(x.z, y.z)); \
Expand All @@ -53,6 +51,14 @@
return (RET_TYPE##16)(FUNCTION(x.lo, y.lo), FUNCTION(x.hi, y.hi)); \
}

#define _CLC_BINARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y) { \
return (RET_TYPE##2)(FUNCTION(x.x, y.x), FUNCTION(x.y, y.y)); \
} \
_CLC_BINARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE)

#define _CLC_V_S_V_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE x, ARG2_TYPE##2 y) { \
Expand Down Expand Up @@ -107,6 +113,15 @@
FUNCTION(x.hi, y.hi, z.hi)); \
}

#define _CLC_TERNARY_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE, ARG3_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE##2 x, ARG2_TYPE##2 y, \
ARG3_TYPE##2 z) { \
return (RET_TYPE##2)(FUNCTION(x.x, y.x, z.x), FUNCTION(x.y, y.y, z.y)); \
} \
_CLC_TERNARY_VECTORIZE_HAVE2(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE, ARG3_TYPE)

#define _CLC_V_S_S_V_VECTORIZE(DECLSPEC, RET_TYPE, FUNCTION, ARG1_TYPE, \
ARG2_TYPE, ARG3_TYPE) \
DECLSPEC RET_TYPE##2 FUNCTION(ARG1_TYPE x, ARG2_TYPE y, ARG3_TYPE##2 z) { \
Expand Down
46 changes: 42 additions & 4 deletions libclc/ptx-nvidiacl/libspirv/math/tanh.cl
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,45 @@
#include "../../include/libdevice.h"
#include <clcmacro.h>

#define __CLC_FUNCTION __spirv_ocl_tanh
#define __CLC_BUILTIN __nv_tanh
#define __CLC_BUILTIN_F __CLC_XCONCAT(__CLC_BUILTIN, f)
#include <math/unary_builtin.inc>
extern int __clc_nvvm_reflect_arch();
extern int __clc_nvvm_reflect_approx_tanh();

#define __USE_TANH_APPROX \
(__clc_nvvm_reflect_approx_tanh() && (__clc_nvvm_reflect_arch() >= 750))

#ifdef cl_khr_fp64

#pragma OPENCL EXTENSION cl_khr_fp64 : enable

_CLC_DEF _CLC_OVERLOAD double __spirv_ocl_tanh(double x) {
return __nv_tanh(x);
}

_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, double, __spirv_ocl_tanh, double)

#endif

_CLC_DEF _CLC_OVERLOAD float __spirv_ocl_tanh(float x) {
return (__USE_TANH_APPROX) ? __nvvm_tanh_approx_f(x) : __nv_tanhf(x);
}

_CLC_UNARY_VECTORIZE(_CLC_OVERLOAD _CLC_DEF, float, __spirv_ocl_tanh, float)

#ifdef cl_khr_fp16

#pragma OPENCL EXTENSION cl_khr_fp16 : enable

_CLC_DEF _CLC_OVERLOAD half __spirv_ocl_tanh(half x) {
return (__USE_TANH_APPROX) ? __nvvm_tanh_approx_f16(x) : __nv_tanhf(x);
}

_CLC_DEF _CLC_OVERLOAD half2 __spirv_ocl_tanh(half2 x) {
return (__USE_TANH_APPROX) ? __nvvm_tanh_approx_f16x2(x)
: (half2)(__nv_tanhf(x.x), __nv_tanhf(x.y));
}

_CLC_UNARY_VECTORIZE_HAVE2(_CLC_OVERLOAD _CLC_DEF, half, __spirv_ocl_tanh, half)

#endif

#undef __USE_TANH_APPROX
7 changes: 7 additions & 0 deletions libclc/ptx-nvidiacl/libspirv/reflect.ll
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,10 @@ define i32 @__clc_nvvm_reflect_arch() alwaysinline {
%reflect = call i32 @__nvvm_reflect(i8* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([12 x i8], [12 x i8] addrspace(1)* @str, i32 0, i32 0) to i8*))
ret i32 %reflect
}

@str_approx_tanh = private addrspace(1) constant [20 x i8] c"__CUDA_APPROX_TANHF\00"

define i32 @__clc_nvvm_reflect_approx_tanh() alwaysinline {
%reflect = call i32 @__nvvm_reflect(i8* addrspacecast (i8 addrspace(1)* getelementptr inbounds ([20 x i8], [20 x i8] addrspace(1)* @str_approx_tanh, i32 0, i32 0) to i8*))
ret i32 %reflect
}
11 changes: 11 additions & 0 deletions llvm/include/llvm/IR/IntrinsicsNVVM.td
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,17 @@ let TargetPrefix = "nvvm" in {
def int_nvvm_cos_approx_f : GCCBuiltin<"__nvvm_cos_approx_f">,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;

//
// Tanh
//

def int_nvvm_tanh_approx_f : GCCBuiltin<"__nvvm_tanh_approx_f">,
DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty], [IntrNoMem]>;
def int_nvvm_tanh_approx_f16 : GCCBuiltin<"__nvvm_tanh_approx_f16">,
DefaultAttrsIntrinsic<[llvm_half_ty], [llvm_half_ty], [IntrNoMem]>;
def int_nvvm_tanh_approx_f16x2 : GCCBuiltin<"__nvvm_tanh_approx_f16x2">,
DefaultAttrsIntrinsic<[llvm_v2f16_ty], [llvm_v2f16_ty], [IntrNoMem]>;

//
// Fma
//
Expand Down
11 changes: 11 additions & 0 deletions llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
Original file line number Diff line number Diff line change
Expand Up @@ -784,6 +784,17 @@ def INT_NVVM_COS_APPROX_FTZ_F : F_MATH_1<"cos.approx.ftz.f32 \t$dst, $src0;",
def INT_NVVM_COS_APPROX_F : F_MATH_1<"cos.approx.f32 \t$dst, $src0;",
Float32Regs, Float32Regs, int_nvvm_cos_approx_f>;

//
// Tanh
//

def INT_NVVM_TANH_APPROX_F : F_MATH_1<"tanh.approx.f32 \t$dst, $src0;",
Float32Regs, Float32Regs, int_nvvm_tanh_approx_f>;
def INT_NVVM_TANH_APPROX_F16 : F_MATH_1<"tanh.approx.f16 \t$dst, $src0;",
Float16Regs, Float16Regs, int_nvvm_tanh_approx_f16>;
def INT_NVVM_TANH_APPROX_F16X2 : F_MATH_1<"tanh.approx.f16x2 \t$dst, $src0;",
Float16x2Regs, Float16x2Regs, int_nvvm_tanh_approx_f16x2>;

//
// Fma
//
Expand Down
6 changes: 6 additions & 0 deletions llvm/lib/Target/NVPTX/NVVMReflect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ static bool runNVVMReflect(Function &F, unsigned SmVersion) {
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
F.getParent()->getModuleFlag("nvvm-reflect-prec-sqrt")))
ReflectVal = Flag->getSExtValue();
} else if (ReflectArg == "__CUDA_APPROX_TANHF") {
// Try to pull __CUDA_APPROX_TANHF from the nvvm-reflect-approx-tanhf
// module flag.
if (auto *Flag = mdconst::extract_or_null<ConstantInt>(
F.getParent()->getModuleFlag("nvvm-reflect-approx-tanhf")))
ReflectVal = Flag->getSExtValue();
}
Call->replaceAllUsesWith(ConstantInt::get(Call->getType(), ReflectVal));
ToRemove.push_back(Call);
Expand Down