
Directed rounding #2576

Draft: wants to merge 16 commits into master
3 changes: 3 additions & 0 deletions docs/make.jl
@@ -60,6 +60,9 @@ function main()
"development/troubleshooting.md",
"development/debugging.md",
],
"Hacking" => Any[
"hacking/exposing_new_intrinsics.md",
],
"API reference" => Any[
"api/essentials.md",
"api/array.md",
49 changes: 49 additions & 0 deletions docs/src/hacking/exposing_new_intrinsics.jl
@@ -0,0 +1,49 @@
# # Introduction

# **Adding new GPU intrinsics**

# In this tutorial we will expose GPU intrinsics that allow directed rounding in the
# fused multiply-add (fma) floating-point operation.
# We start by identifying the intrinsic we want to expose; to do so, we read the PTX
# (Parallel Thread Execution) documentation at
# [PTX - Floating Point Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/#floating-point-instructions).
# Table 32 presents a summary of the floating-point instructions, from which we can
# construct the intrinsic string. The FMA instruction for Float32 is listed as
# `{mad,fma}.rnd.f32`, where `rnd` can assume the values `.rnd = { .rn, .rz, .rm, .rp }`:
# `rn` rounds to nearest, `rz` toward zero, `rm` toward minus infinity, and `rp` toward
# plus infinity.
# When building the intrinsic name for the call, we replace the type suffix `.f64` with
# `.d` and `.f32` with `.f`.
# Therefore, to call the `fma` for `f64` rounded toward plus infinity, we call the
# intrinsic `llvm.nvvm.fma.rp.d`.
# Note that this is only possible if LLVM supports the intrinsic; the intrinsics exposed
# by LLVM can be found by searching the
# [LLVM repository](https://github.com/llvm/llvm-project). In other cases you would need
# `@asmcall` and inline PTX assembly.
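
# To illustrate the naming scheme, a small helper (the name `nvvm_intrinsic` is ours,
# purely for illustration, and not part of CUDA.jl) can assemble the intrinsic string:

nvvm_intrinsic(op, rnd, T) = "llvm.nvvm.$op.$rnd." * (T === Float64 ? "d" : "f")
nvvm_intrinsic("fma", "rp", Float64) # returns "llvm.nvvm.fma.rp.d"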

using CUDA

# We now define a function calling the intrinsic, and extend `fma` with a
# rounding-mode argument (only the Float64 variant is defined at this point):
fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
fma(x::Float64, y::Float64, z::Float64, ::RoundingMode{:Up}) = fma_rp(x, y, z)

# We inspect the generated PTX code:
CUDA.code_ptx(fma_rp, Tuple{Float64,Float64,Float64})
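
# The output should contain a line similar to the following (register numbers will
# vary; this is an illustrative sketch, not verbatim compiler output):
#
#     fma.rp.f64 %fd4, %fd1, %fd2, %fd3;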

# The PTX code indeed contains a call to the instruction `fma.rp.f64`. We now add this
# function to `src/device/intrinsics/math.jl`.

function test_fma!(out, x, y)
    I = threadIdx().x
    z = 2.0^(-(I + 53)) # a perturbation smaller than half a ulp of 1.0

    out[I]      = fma(x, y, z, RoundNearest)
    out[I + 4]  = fma(x, y, z, RoundToZero)
    out[I + 8]  = fma(x, y, z, RoundUp)
    out[I + 12] = fma(x, y, z, RoundDown)

    return
end

# The first four entries of the output are rounded to nearest, entries 5 to 8 are
# rounded toward zero, entries 9 to 12 are rounded up, and entries 13 to 16 are
# rounded down.

out_d = CuArray(zeros(16))
@cuda threads = 4 test_fma!(out_d, 1.0, 1.0)
out_h = Array(out_d)

out_d = CuArray(zeros(16))
@cuda threads = 4 test_fma!(out_d, -1.0, 1.0)
out_h = Array(out_d)
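
# As a sketch of the expected behaviour (reasoned from IEEE 754, not verified output):
# with `x = y = 1.0` the exact result `1.0 + z` lies strictly below the midpoint
# between `1.0` and `nextfloat(1.0)`, so rounding to nearest, toward zero, and down
# should all return `1.0`, while rounding up should return
# `nextfloat(1.0) == 1.0 + 2.0^-52`. With `x = -1.0` the picture mirrors: rounding to
# nearest (ties to even) and down should return `-1.0`, while rounding up and toward
# zero should return `-1.0 + 2.0^-53`, the next representable value toward zero.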

63 changes: 61 additions & 2 deletions src/device/intrinsics/math.jl
@@ -390,18 +390,77 @@ end
@device_function normcdfinv(x::Float64) = ccall("extern __nv_normcdfinv", llvmcall, Cdouble, (Cdouble,), x)
@device_function normcdfinv(x::Float32) = ccall("extern __nv_normcdfinvf", llvmcall, Cfloat, (Cfloat,), x)



Review comment (Member) on lines -393 to -394: Unrelated change.

#
# Unsorted
#

@device_override Base.hypot(x::Float64, y::Float64) = ccall("extern __nv_hypot", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)
@device_override Base.hypot(x::Float32, y::Float32) = ccall("extern __nv_hypotf", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)


for type in [:f, :d], round in [:rn, :rz, :rm, :rp], op in [:add, :mul, :div]
    inp_type = type == :f ? :Float32 : :Float64
    c_type   = type == :f ? :Cfloat  : :Cdouble

    func_name = Symbol("$(op)_$(round)")
    intrinsic_name = "llvm.nvvm.$(op).$(round).$(type)"

    @eval @device_function $func_name(x::$inp_type, y::$inp_type) =
        ccall($intrinsic_name, llvmcall, $c_type, ($c_type, $c_type), x, y)
end
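
# One iteration of the loop above expands to, e.g. (a hand-written sketch of the
# generated definition, not a literal copy of the expansion):
#
#   @device_function add_rp(x::Float64, y::Float64) =
#       ccall("llvm.nvvm.add.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble), x, y)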

@device_function sub_rn(x, y) = add_rn(x, -y)
@device_function sub_rz(x, y) = add_rz(x, -y)
@device_function sub_rm(x, y) = add_rm(x, -y)
@device_function sub_rp(x, y) = add_rp(x, -y)
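
# Negation is exact in IEEE 754, so x - y under any rounding mode equals x + (-y)
# under the same mode; the subtractions can therefore reuse the rounded additions.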

@device_function add(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = add_rn(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = add_rz(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = add_rm(x, y)
@device_function add(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = add_rp(x, y)

@device_function sub(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = sub_rn(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = sub_rz(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = sub_rm(x, y)
@device_function sub(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = sub_rp(x, y)

@device_function mul(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = mul_rn(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = mul_rz(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = mul_rm(x, y)
@device_function mul(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = mul_rp(x, y)

@device_function div(x::T, y::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = div_rn(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = div_rz(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = div_rm(x, y)
@device_function div(x::T, y::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = div_rp(x, y)
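
# A device-side usage sketch (hypothetical kernel snippet; `a` and `b` are assumed
# to be Float64 values):
#
#   c = add(a, b, RoundUp)      # a + b, rounded toward +Inf
#   q = div(a, b, RoundToZero)  # a / b, rounded toward zero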



@device_override Base.fma(x::Float64, y::Float64, z::Float64) = ccall("extern __nv_fma", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_override Base.fma(x::Float32, y::Float32, z::Float32) = ccall("extern __nv_fmaf", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.fma(x::Float16, y::Float16, z::Float16) = ccall("llvm.fma.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)
@device_function fma_rn(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rn.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rn(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rn.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rz(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rz.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rz(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rz.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rm(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rm.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rm(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rm.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_function fma_rp(x::Float64, y::Float64, z::Float64) = ccall("llvm.nvvm.fma.rp.d", llvmcall, Cdouble, (Cdouble, Cdouble, Cdouble), x, y, z)
@device_function fma_rp(x::Float32, y::Float32, z::Float32) = ccall("llvm.nvvm.fma.rp.f", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)

@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Nearest}) where {T <: Union{Float32, Float64}} = fma_rn(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:ToZero}) where {T <: Union{Float32, Float64}} = fma_rz(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Down}) where {T <: Union{Float32, Float64}} = fma_rm(x, y, z)
@device_override Base.fma(x::T, y::T, z::T, ::RoundingMode{:Up}) where {T <: Union{Float32, Float64}} = fma_rp(x, y, z)

@device_function sad(x::Int32, y::Int32, z::Int32) = ccall("extern __nv_sad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z)
@device_function sad(x::UInt32, y::UInt32, z::UInt32) = convert(UInt32, ccall("extern __nv_usad", llvmcall, Int32, (Int32, Int32, Int32), x, y, z))
8 changes: 6 additions & 2 deletions src/device/intrinsics/wmma.jl
@@ -27,7 +27,9 @@ const map_ptx_to_jl_frag = Dict(
"f32" => Float32
)

# Maps matrix & PTX types to fragment sizes
# Maps matrix & PTX types to fragment sizes; information retrieved from
# https://docs.nvidia.com/cuda/parallel-thread-execution/index.html?highlight=wmma#matrix-fragments-for-wmma

const map_frag_sizes = Dict(
# A
"a.u8.m16n16k16" => 2,
@@ -491,7 +493,9 @@ julia> config = WMMA.Config{16, 16, 16, Float32}
CUDA.WMMA.Config{16, 16, 16, Float32}
```
"""
struct Config{M, N, K, d_type} end
struct ConfigRounding{M, N, K, d_type, rounding} end

Config{M, N, K, d_type} = ConfigRounding{M, N, K, d_type, RoundNearest}
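
# With this alias, existing code such as `WMMA.Config{16, 16, 16, Float32}` keeps
# working and defaults to round-to-nearest, while a rounding-aware configuration can
# presumably be spelled as `WMMA.ConfigRounding{16, 16, 16, Float32, RoundUp}`.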

# ---------
# Constants