Skip to content
This repository was archived by the owner on Nov 18, 2020. It is now read-only.

Implement OCKL wavefront intrinsics #72

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Manifest.toml
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ uuid = "9fa8497b-333b-5362-9e8d-4d0656e87820"

[[GPUCompiler]]
deps = ["DataStructures", "InteractiveUtils", "LLVM", "Libdl", "TimerOutputs", "UUIDs"]
git-tree-sha1 = "65f7395a1245635f0c2279649fdbef09a1b0aa7b"
repo-rev = "master"
git-tree-sha1 = "95aa07bfda5c80ccd57b038570ef79f663ce531f"
repo-rev = "jps/gcn-workaround-allocas"
repo-url = "https://github.com/JuliaGPU/GPUCompiler.jl.git"
uuid = "61eb1bfa-7361-4325-ad38-22787b887f55"
version = "0.4.0"
Expand Down
4 changes: 3 additions & 1 deletion src/device/gcn.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
include(joinpath("gcn", "intrinsics.jl"))
if Base.libllvm_version >= v"7.0"
include(joinpath("gcn", "math.jl"))
include(joinpath("gcn", "ocml.jl"))
include(joinpath("gcn", "ockl.jl"))
end
include(joinpath("gcn", "indexing.jl"))
include(joinpath("gcn", "assertion.jl"))
Expand Down
50 changes: 50 additions & 0 deletions src/device/gcn/intrinsics.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
"""
    _intr(::Val{fname}, out_arg, inp_args...)

Call the external function named `fname` with `inp_args` and return its result.

`out_arg` is the Julia return type (passed as a type, e.g. `Float32`); the types of
`inp_args` determine the parameter types of the declared function. The Julia types are
converted to `LLVMType`s, a wrapper function is obtained from `create_function`, a
declaration named `string(fname)` is added to the wrapper's module, and the wrapper body
simply forwards its parameters to that declaration. The finished wrapper is handed to
`call_function`.
"""
@generated function _intr(::Val{fname}, out_arg, inp_args...) where {fname}
    # Inside the generator `out_arg` is `Type{T}` and `inp_args` holds argument *types*.
    ret_jltype = out_arg.parameters[1]
    ret_lltype = convert(LLVMType, ret_jltype)
    arg_lltypes = LLVMType[convert(LLVMType, T) for T in inp_args]

    # wrapper function that will host the call
    wrapper, _ = create_function(ret_lltype, arg_lltypes)
    parent_mod = LLVM.parent(wrapper)

    # emit the wrapper body: a single call to the named function, then return its value
    Builder(JuliaContext()) do builder
        entry_block = BasicBlock(wrapper, "entry", JuliaContext())
        position!(builder, entry_block)

        callee_type = LLVM.FunctionType(ret_lltype, arg_lltypes)
        callee = LLVM.Function(parent_mod, string(fname), callee_type)
        result = call!(builder, callee, [parameters(wrapper)...])
        ret!(builder, result)
    end

    # expressions selecting each vararg at run time: inp_args[1], inp_args[2], ...
    arg_exprs = map(i -> :(inp_args[$i]), 1:length(inp_args))
    call_function(wrapper, ret_jltype, Tuple{inp_args...}, Expr(:tuple, arg_exprs...))
end

"""
    GCNIntrinsic

Description of one device-library function that is exposed as a Julia function.

The library symbol that gets called is `__<roclib>_<rocname>_<suffix>`.

# Fields
- `jlname`: name of the Julia function to define.
- `rocname`: name of the function inside the device library.
- `isbroken`: marks the intrinsic as known-broken; the test suite skips such entries.
- `isinverted`: when `true`, the generated wrapper returns `1-y` instead of `y`.
- `inp_args`: tuple of Julia input argument types.
- `out_arg`: Julia return type.
- `roclib`: device library part of the symbol name (e.g. `:ocml`, `:ockl`).
- `suffix`: type suffix part of the symbol name (e.g. `:f32`).
"""
struct GCNIntrinsic
    jlname::Symbol
    rocname::Symbol
    isbroken::Bool # please don't laugh...
    isinverted::Bool
    # FIXME: Input/output types should have addrspaces
    inp_args::Tuple
    out_arg::Type
    roclib::Symbol
    suffix::Symbol
end

# Default `suffix`: the `fntypes` entry of the first input type. Fails with a clear
# message when there is no input type to derive it from.
function _default_suffix(inp_args)
    isempty(inp_args) && throw(ArgumentError(
        "`inp_args` is empty: pass `suffix` explicitly or give at least one input type"))
    return fntypes[first(inp_args)]
end

"""
    GCNIntrinsic(jlname, rocname=jlname; isbroken=false, isinverted=false,
                 inp_args=(), out_arg=(), roclib=:ocml, suffix=fntypes[first(inp_args)])

Keyword convenience constructor.

`out_arg` has no usable default and must be given as a type. `suffix` is derived from the
first entry of `inp_args` unless passed explicitly, so either `inp_args` must be non-empty
or `suffix` must be supplied.

# Throws
- `ArgumentError` if `out_arg` is not a type, or if `suffix` is omitted while `inp_args`
  is empty.
"""
function GCNIntrinsic(jlname, rocname=jlname; isbroken=false, isinverted=false,
                      inp_args=(), out_arg=(), roclib=:ocml,
                      suffix=_default_suffix(inp_args))
    # The `()` placeholder default can never be stored in the `out_arg::Type` field;
    # report that directly instead of surfacing a conversion MethodError.
    out_arg isa Type || throw(ArgumentError(
        "`out_arg` must be a type, got $(repr(out_arg)); it has no usable default"))
    return GCNIntrinsic(jlname, rocname, isbroken, isinverted, inp_args, out_arg,
                        roclib, suffix)
end

"""
    generate_intrinsic(intr)

Define (via `@eval`) an `@inline` method named `intr.jlname` that takes one argument per
entry of `intr.inp_args`, each annotated with the corresponding type, and forwards them to
`_intr` for the library symbol `__<roclib>_<rocname>_<suffix>` with return type
`intr.out_arg`. When `intr.isinverted` is set, the method returns `1-y` instead of the raw
result `y`.
"""
function generate_intrinsic(intr)
    # one freshly named, type-annotated argument per input type
    typed_args = [:($(gensym())::$T) for T in intr.inp_args]
    libname = Symbol("__", intr.roclib, "_", intr.rocname, "_", intr.suffix)
    result_expr = intr.isinverted ? :(1-y) : :y
    @eval @inline function $(intr.jlname)($(typed_args...))
        y = _intr($(Val(libname)), $(intr.out_arg), $(typed_args...))
        return $result_expr
    end
end
102 changes: 0 additions & 102 deletions src/device/gcn/math.jl

This file was deleted.

19 changes: 19 additions & 0 deletions src/device/gcn/ockl.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
# Registry of OCKL intrinsics; every entry is turned into a Julia method below.
const OCKL_INTRINSICS = GCNIntrinsic[]

# Each `(ops, types)` group lists the operations and the element types they are
# registered for. `wfscan_*` takes an additional `Bool` argument
# (presumably the inclusive/exclusive flag -- TODO confirm against the OCKL docs).
# TODO: Float16 is broken due to being i16 in Julia
for kind in (:wfred, :wfscan),
    (ops, jltypes) in (
        ((:add, :max, :min), (Float32, Float64, Int32, Int64, UInt32, UInt64)),
        ((:and, :or, :xor), (Int32, Int64, UInt32, UInt64)),
    ),
    op in ops,
    jltype in jltypes

    arg_types = kind == :wfscan ? (jltype, Bool) : (jltype,)
    push!(OCKL_INTRINSICS,
          GCNIntrinsic(Symbol(kind, "_", op);
                       roclib=:ockl, inp_args=arg_types, out_arg=jltype))
end

foreach(generate_intrinsic, OCKL_INTRINSICS)
52 changes: 52 additions & 0 deletions src/device/gcn/ocml.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Registry of OCML intrinsics; entries are appended here and turned into Julia methods
# once the list is complete.
const OCML_INTRINSICS = GCNIntrinsic[]

# TODO: Float16 is broken due to being i16 in Julia
for jltype in (Float32, Float64)
    # Functions whose Julia name equals the library name, registered as taking and
    # returning a single `jltype`.
    # NOTE(review): some of these (e.g. atan2, pow, nextafter) look like two-argument
    # functions, and ilogb/isnormal like integer-returning ones -- confirm the
    # signatures against the OCML headers.
    for name in (
            :sin, :cos, :tan, :asin, :acos, :atan, :atan2,
            :sinh, :cosh, :tanh, :asinh, :acosh, :atanh,
            :sinpi, :cospi, :tanpi, :sincospi,
            :asinpi, :acospi, :atanpi, :atan2pi,
            :sqrt, :rsqrt, :cbrt, :rcbrt, :recip,
            :log, :log2, :log10, :log1p, :logb, :ilogb,
            :exp, :exp2, :exp10, :expm1,
            :erf, :erfinv, :erfc, :erfcinv, :erfcx,
            # TODO: :brev, :clz, :ffs, :byte_perm, :popc,
            :isnormal, :nearbyint, :nextafter,
            :pow, :pown, :powr,
            :tgamma, :j0, :j1, :y0, :y1,
        )
        push!(OCML_INTRINSICS, GCNIntrinsic(name; inp_args=(jltype,), out_arg=jltype))
    end

    # `<stem>_fast` in Julia maps onto the library's `native_<stem>`.
    for stem in (:sin, :cos, :sqrt, :rsqrt, :recip,
                 :log, :log2, :log10, :exp, :exp2, :exp10)
        push!(OCML_INTRINSICS,
              GCNIntrinsic(Symbol(stem, "_fast"), Symbol("native_", stem);
                           inp_args=(jltype,), out_arg=jltype))
    end

    # `abs` is called `fabs` in the library.
    push!(OCML_INTRINSICS, GCNIntrinsic(:abs, :fabs; inp_args=(jltype,), out_arg=jltype))
    # TODO: abs(::Union{Int32,Int64})

    # FIXME: Multi-argument functions (:sincos, :frexp, :ldexp, :copysign) are not
    # registered yet.
end

# Floating-point classification functions, registered for Float32 only and returning Int32.
# TODO: Float64 is broken for some reason, try to re-enable on a newer LLVM
for predicate in (:isfinite, :isinf, :isnan, :signbit)
    push!(OCML_INTRINSICS, GCNIntrinsic(predicate; inp_args=(Float32,), out_arg=Int32))
end

# Define one Julia method per registered OCML intrinsic.
foreach(generate_intrinsic, OCML_INTRINSICS)
4 changes: 4 additions & 0 deletions src/device/tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ const llvmtypes = Dict{Type,Symbol}(
Int16 => :i16,
Int32 => :i32,
Int64 => :i64,
UInt32 => :i32,
UInt64 => :i64,
Float32 => :float,
Float64 => :double,
)
Expand All @@ -29,6 +31,8 @@ const fntypes = Dict{Type,Symbol}(
Int16 => :i16,
Int32 => :i32,
Int64 => :i64,
UInt32 => :u32,
UInt64 => :u64,
Float16 => :f16,
Float32 => :f32,
Float64 => :f64
Expand Down
4 changes: 2 additions & 2 deletions test/device/math.jl → test/device/ocml.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
@testset "Math Intrinsics" begin
for intr in AMDGPUnative.MATH_INTRINSICS
@testset "OCML Intrinsics" begin
for intr in AMDGPUnative.OCML_INTRINSICS
jlintr = intr.jlname
if intr.isbroken || !(isdefined(Base, jlintr) || isdefined(SpecialFunctions, jlintr))
@test_skip "$jlintr()"
Expand Down
3 changes: 2 additions & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ if AMDGPUnative.configured
include("device/output.jl")
include("device/globals.jl")
if Base.libllvm_version >= v"7.0"
include("device/math.jl")
include("device/ocml.jl")
include("device/ockl.jl")
else
@warn "Testing with LLVM 6; some tests will be disabled!"
@test_skip "Math Intrinsics"
Expand Down