Skip to content
This repository was archived by the owner on Nov 18, 2020. It is now read-only.

Commit f3459d2

Browse files
authored
Merge pull request #26 from JuliaGPU/jps/runtime-abstraction
Abstract runtime support
2 parents 8070d83 + 7f3ed5e commit f3459d2

13 files changed

+162
-69
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ HSARuntime = "2c364e2c-59fb-59c3-96f3-194112e690e0"
1010
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
1111
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
1212
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
13+
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
1314
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
1415

1516
[compat]

src/AMDGPUnative.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,16 @@ using Adapt
77
using TimerOutputs
88
using DataStructures
99
using Libdl
10+
using Requires
11+
12+
@enum DeviceRuntime HSA OCL
13+
const RUNTIME = Ref{DeviceRuntime}(HSA)
14+
#=
15+
if get(ENV, "AMDGPUNATIVE_OPENCL", "") != ""
16+
RUNTIME[] = OCL
17+
end
18+
=#
19+
include("runtime.jl")
1020

1121
const configured = HSARuntime.configured
1222

@@ -29,6 +39,7 @@ include("reflection.jl")
2939

3040
function __init__()
3141
check_deps()
42+
@require OpenCL="08131aa3-fb12-5dee-8b74-c09406e224a2" include("opencl.jl")
3243
__init_compiler__()
3344
end
3445

src/compiler/common.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ struct CompilerJob
44
# core invocation
55
f::Base.Callable
66
tt::DataType
7-
agent::HSAAgent
7+
device::RuntimeDevice
88
kernel::Bool
99

1010
# optional properties
@@ -14,10 +14,10 @@ struct CompilerJob
1414
maxregs::Union{Nothing,Integer}
1515
name::Union{Nothing,String}
1616

17-
CompilerJob(f, tt, agent, kernel; name=nothing,
17+
CompilerJob(f, tt, device, kernel; name=nothing,
1818
minthreads=nothing, maxthreads=nothing,
1919
blocks_per_sm=nothing, maxregs=nothing) =
20-
new(f, tt, agent, kernel, minthreads, maxthreads, blocks_per_sm,
20+
new(f, tt, device, kernel, minthreads, maxthreads, blocks_per_sm,
2121
maxregs, name)
2222
end
2323

src/compiler/driver.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,11 @@
44
const compile_hook = Ref{Union{Nothing,Function}}(nothing)
55

66
"""
7-
compile(target::Symbol, agent::HSAAgent, f, tt, kernel=true;
7+
compile(target::Symbol, device::RuntimeDevice, f, tt, kernel=true;
88
libraries=true, optimize=true, strip=false, strict=true, ...)
9-
Compile a function `f` invoked with types `tt` for agent `agent` to one of the
10-
following formats as specified by the `target` argument: `:julia` for Julia
11-
IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
9+
Compile a function `f` invoked with types `tt` for device `device` to one of
10+
the following formats as specified by the `target` argument: `:julia` for
11+
Julia IR, `:llvm` for LLVM IR, `:gcn` for GCN assembly, and `:roc` for linked
1212
objects. If the `kernel` flag is set, specialized code generation and
1313
optimization for kernel functions is enabled.
1414
The following keyword arguments are supported:
@@ -18,11 +18,11 @@ The following keyword arguments are supported:
1818
- `strict`: perform code validation either as early or as late as possible
1919
Other keyword arguments can be found in the documentation of [`rocfunction`](@ref).
2020
"""
21-
compile(target::Symbol, agent::HSAAgent, @nospecialize(f::Core.Function),
21+
compile(target::Symbol, device::RuntimeDevice, @nospecialize(f::Core.Function),
2222
@nospecialize(tt), kernel::Bool=true; libraries::Bool=true,
2323
optimize::Bool=true, strip::Bool=false, strict::Bool=true, kwargs...) =
2424

25-
compile(target, CompilerJob(f, tt, agent, kernel; kwargs...);
25+
compile(target, CompilerJob(f, tt, device, kernel; kwargs...);
2626
libraries=libraries, optimize=optimize, strip=strip,
2727
strict=strict)
2828

@@ -88,7 +88,7 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
8888
# always preload the runtime, and do so early; it cannot be part of any timing block
8989
# because it recurses into the compiler
9090
if libraries
91-
runtime = load_runtime(job.agent)
91+
runtime = load_runtime(job.device)
9292
runtime_fns = LLVM.name.(defs(runtime))
9393
end
9494

@@ -105,7 +105,7 @@ function codegen(target::Symbol, job::CompilerJob; libraries::Bool=true,
105105
end
106106
=#
107107
# FIXME: Load this only when needed
108-
device_libs = load_device_libs(job.agent)
108+
device_libs = load_device_libs(job.device)
109109
for lib in device_libs
110110
if need_library(ir, lib)
111111
@timeit to[] "device library" link_device_lib!(job, ir, lib)

src/compiler/mcgen.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
# machine code generation
22

3-
function machine(agent::HSAAgent, triple::String)
3+
function machine(device::RuntimeDevice, triple::String)
44
InitializeAMDGPUTarget()
55
InitializeAMDGPUTargetInfo()
66
t = Target(triple)
77

88
InitializeAMDGPUTargetMC()
9-
cpu = get_first_isa(agent) # TODO: Make this configurable
9+
cpu = default_isa(device) # TODO: Make this configurable
1010
feat = ""
1111
tm = TargetMachine(t, triple, cpu, feat)
1212
asm_verbosity!(tm, true)
@@ -80,7 +80,7 @@ end
8080

8181
function mcgen(job::CompilerJob, mod::LLVM.Module, f::LLVM.Function;
8282
output_format=LLVM.API.LLVMObjectFile)
83-
tm = machine(job.agent, triple(mod))
83+
tm = machine(job.device, triple(mod))
8484

8585
InitializeAMDGPUAsmPrinter()
8686
return String(emit(tm, mod, output_format))

src/compiler/optim.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
# LLVM IR optimization
22

33
function optimize!(job::CompilerJob, mod::LLVM.Module, entry::LLVM.Function)
4-
tm = AMDGPUnative.machine(job.agent, triple(mod))
4+
tm = AMDGPUnative.machine(job.device, triple(mod))
55

66
if job.kernel
77
entry = promote_kernel!(job, mod, entry)

src/compiler/rtlib.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,10 @@ const libcache = Dict{String, LLVM.Module}()
3030

3131
# ROCm device library
3232

33-
function load_device_libs(agent)
33+
function load_device_libs(device)
3434
device_libs_path === nothing && return
3535

36-
isa_short = replace(get_first_isa(agent), "gfx"=>"")
36+
isa_short = replace(default_isa(device), "gfx"=>"")
3737
device_libs = LLVM.Module[]
3838
bitcode_files = (
3939
"hc.amdgcn.bc",
@@ -132,20 +132,20 @@ end
132132

133133
## functionality to build the runtime library
134134

135-
function emit_function!(mod, agent, f, types, name)
135+
function emit_function!(mod, device, f, types, name)
136136
tt = Base.to_tuple_type(types)
137-
new_mod, entry = codegen(:llvm, CompilerJob(f, tt, agent, #=kernel=# false);
137+
new_mod, entry = codegen(:llvm, CompilerJob(f, tt, device, #=kernel=# false);
138138
libraries=false, strict=false)
139139
LLVM.name!(entry, name)
140140
link!(mod, new_mod)
141141
end
142142

143-
function build_runtime(agent)
143+
function build_runtime(device)
144144
mod = LLVM.Module("AMDGPUnative run-time library", JuliaContext())
145145

146146
for method in values(Runtime.methods)
147147
try
148-
emit_function!(mod, agent, method.def, method.types, method.llvm_name)
148+
emit_function!(mod, device, method.def, method.types, method.llvm_name)
149149
catch err
150150
@warn method
151151
end
@@ -154,8 +154,8 @@ function build_runtime(agent)
154154
mod
155155
end
156156

157-
function load_runtime(agent::HSAAgent)
158-
isa = get_first_isa(agent)
157+
function load_runtime(device::RuntimeDevice)
158+
isa = default_isa(device)
159159
name = "amdgpunative.$isa.bc"
160160
path = joinpath(@__DIR__, "..", "..", "deps", "runtime", name)
161161
mkpath(dirname(path))
@@ -167,7 +167,7 @@ function load_runtime(agent::HSAAgent)
167167
end
168168
else
169169
@info "Building the AMDGPUnative run-time library for your $isa device, this might take a while..."
170-
lib = build_runtime(agent)
170+
lib = build_runtime(device)
171171
open(path, "w") do io
172172
write(io, lib)
173173
end

src/execution.jl

Lines changed: 31 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,17 @@
33
export @roc, rocconvert, rocfunction
44

55
struct Kernel{F,TT}
6-
agent::HSAAgent
6+
device::RuntimeDevice
77
mod::ROCModule
88
fun::ROCFunction
99
end
1010

1111
# `split_kwargs()` segregates keyword arguments passed to `@roc` into those
1212
# affecting the compiler, kernel execution, or both.
1313
function split_kwargs(kwargs)
14-
compiler_kws = [:agent, :queue, :name]
15-
call_kws = [:groupsize, :gridsize, :agent, :queue]
14+
# TODO: Alias groupsize and gridsize as threads and blocks, respectively
15+
compiler_kws = [:device, :agent, :queue, :name]
16+
call_kws = [:groupsize, :gridsize, :device, :agent, :queue]
1617
compiler_kwargs = []
1718
call_kwargs = []
1819
for kwarg in kwargs
@@ -60,6 +61,23 @@ function assign_args!(code, args)
6061
return vars, var_exprs
6162
end
6263

64+
function extract_device(;device=nothing, agent=nothing, kwargs...)
65+
if device !== nothing
66+
return device
67+
elseif agent !== nothing
68+
return agent
69+
else
70+
return default_device()
71+
end
72+
end
73+
function extract_queue(device; queue=nothing, kwargs...)
74+
if queue !== nothing
75+
return queue
76+
else
77+
return default_queue(device)
78+
end
79+
end
80+
6381
# fast lookup of global world age
6482
world_age() = ccall(:jl_get_tls_world_age, UInt, ())
6583

@@ -125,11 +143,11 @@ macro roc(ex...)
125143
GC.@preserve $(vars...) begin
126144
local kernel_args = map(rocconvert, ($(var_exprs...),))
127145
local kernel_tt = Tuple{Core.Typeof.(kernel_args)...}
128-
local agent = get_default_agent()
129-
local kernel = rocfunction(agent, $(esc(f)), kernel_tt;
146+
local device = extract_device(; $(map(esc, call_kwargs)...))
147+
local kernel = rocfunction(device, $(esc(f)), kernel_tt;
130148
$(map(esc, compiler_kwargs)...))
131-
local queue = get_default_queue(agent)
132-
local signal = HSASignal()
149+
local queue = extract_queue(device; $(map(esc, call_kwargs)...))
150+
local signal = create_event()
133151
kernel(queue, signal, kernel_args...; $(map(esc, call_kwargs)...))
134152
wait(signal)
135153
end
@@ -188,7 +206,7 @@ The output of this function is automatically cached, i.e. you can simply call
188206
generated automatically, when the function changes, or when different types or
189207
keyword arguments are provided.
190208
"""
191-
@generated function rocfunction(agent::HSAAgent, f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
209+
@generated function rocfunction(device::RuntimeDevice, f::Core.Function, tt::Type=Tuple{}; name=nothing, kwargs...)
192210
tt = Base.to_tuple_type(tt.parameters[1])
193211
sig = Base.signature_type(f, tt)
194212
t = Tuple(tt.parameters)
@@ -217,8 +235,8 @@ keyword arguments are provided.
217235

218236
# compile the function
219237
if !haskey(compilecache, key)
220-
fun, mod = compile(:roc, agent, f, tt; name=name, kwargs...)
221-
kernel = Kernel{f,tt}(agent, mod, fun)
238+
fun, mod = compile(:roc, device, f, tt; name=name, kwargs...)
239+
kernel = Kernel{f,tt}(device, mod, fun)
222240
compilecache[key] = kernel
223241
end
224242

@@ -227,10 +245,10 @@ keyword arguments are provided.
227245
end
228246

229247
rocfunction(f::Core.Function, tt::Type=Tuple{}; kwargs...) =
230-
rocfunction(get_default_agent(), f, tt; kwargs...)
248+
rocfunction(default_device(), f, tt; kwargs...)
231249

232-
@generated function call(kernel::Kernel{F,TT}, queue::HSAQueue,
233-
signal::HSASignal, args...; call_kwargs...) where {F,TT}
250+
@generated function call(kernel::Kernel{F,TT}, queue::RuntimeQueue,
251+
signal::RuntimeEvent, args...; call_kwargs...) where {F,TT}
234252

235253
sig = Base.signature_type(F, TT)
236254
args = (:F, (:( args[$i] ) for i in 1:length(args))...)

src/execution_utils.jl

Lines changed: 14 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ function Base.getindex(dims::ROCDim3, idx::Int)
5151
end
5252

5353
"""
54-
launch(queue::HSAQueue, signal::HSASignal, f::ROCFunction,
54+
launch(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction,
5555
groupsize::ROCDim, gridsize::ROCDim, args...)
5656
5757
Low-level call to launch a ROC function `f` on the GPU, using `groupsize` and
@@ -63,7 +63,7 @@ copied to the internal kernel parameter buffer, or a pointer to device memory.
6363
6464
This is a low-level call, preferably use [`roccall`](@ref) instead.
6565
"""
66-
@inline function launch(queue::HSAQueue, signal::HSASignal, f::ROCFunction,
66+
@inline function launch(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction,
6767
groupsize::ROCDim, gridsize::ROCDim, args...)
6868
groupsize = ROCDim3(groupsize)
6969
gridsize = ROCDim3(gridsize)
@@ -77,7 +77,7 @@ end
7777

7878
# we need a generated function to get an args array,
7979
# without having to inspect the types at runtime
80-
@generated function _launch(queue::HSAQueue, signal::HSASignal, f::ROCFunction,
80+
@generated function _launch(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction,
8181
groupsize::ROCDim3, gridsize::ROCDim3,
8282
args::NTuple{N,Any}) where N
8383

@@ -101,31 +101,21 @@ end
101101
GC.@preserve $(arg_refs...) begin
102102
kernelParams = [$(arg_ptrs...)]
103103

104-
# link with ld.lld
105-
ld_path = HSARuntime.ld_lld_path
106-
@assert ld_path != "" "ld.lld was not found; cannot link kernel"
107-
# TODO: Do this more idiomatically
108-
io = open("/tmp/amdgpu-dump.o", "w")
109-
write(io, f.mod.data)
110-
close(io)
111-
run(`$ld_path -shared -o /tmp/amdgpu.exe /tmp/amdgpu-dump.o`)
112-
io = open("/tmp/amdgpu.exe", "r")
113-
data = read(io)
114-
close(io)
115-
116-
# generate executable and kernel instance
117-
exe = HSAExecutable(queue.agent, data, f.entry)
118-
kern = HSAKernelInstance(queue.agent, exe, f.entry, args)
119-
HSARuntime.launch!(queue, kern, signal;
120-
workgroup_size=groupsize, grid_size=gridsize)
104+
# create executable and kernel instance
105+
exe = create_executable(get_device(queue), f)
106+
kern = create_kernel(get_device(queue), exe, f.entry, args)
107+
108+
# launch kernel
109+
launch_kernel(queue, kern, signal;
110+
groupsize=groupsize, gridsize=gridsize)
121111
end
122112
end).args)
123113

124114
return ex
125115
end
126116

127117
"""
128-
roccall(queue::HSAQueue, signal::HSASignal, f::ROCFunction, types, values...;
118+
roccall(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction, types, values...;
129119
groupsize::ROCDim, gridsize::ROCDim)
130120
131121
`ccall`-like interface for launching a ROC function `f` on a GPU.
@@ -151,14 +141,14 @@ being slightly faster.
151141
"""
152142
roccall
153143

154-
@inline function roccall(queue::HSAQueue, signal::HSASignal, f::ROCFunction, types::NTuple{N,DataType}, values::Vararg{Any,N};
144+
@inline function roccall(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction, types::NTuple{N,DataType}, values::Vararg{Any,N};
155145
kwargs...) where N
156146
# this cannot be inferred properly (because types only contains `DataType`s),
157147
# which results in the call `@generated _roccall` getting expanded upon first use
158148
_roccall(queue, signal, f, Tuple{types...}, values; kwargs...)
159149
end
160150

161-
@inline function roccall(queue::HSAQueue, signal::HSASignal, f::ROCFunction, tt::Type, values::Vararg{Any,N};
151+
@inline function roccall(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction, tt::Type, values::Vararg{Any,N};
162152
kwargs...) where N
163153
# in this case, the type of `tt` is `Tuple{<:DataType,...}`,
164154
# which means the generated function can be expanded earlier
@@ -167,7 +157,7 @@ end
167157

168158
# we need a generated function to get a tuple of converted arguments (using unsafe_convert),
169159
# without having to inspect the types at runtime
170-
@generated function _roccall(queue::HSAQueue, signal::HSASignal, f::ROCFunction, tt::Type, args::NTuple{N,Any};
160+
@generated function _roccall(queue::RuntimeQueue, signal::RuntimeEvent, f::ROCFunction, tt::Type, args::NTuple{N,Any};
171161
groupsize::ROCDim=1, gridsize::ROCDim=1) where N
172162

173163
# the type of `tt` is Type{Tuple{<:DataType...}}

src/opencl.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# OpenCL runtime interface to AMDGPUnative
2+
3+
include("opencl/args.jl")

src/opencl/args.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# OpenCL argument utilities
2+
3+
# Argument accessors
4+
5+
# __agocl_global_offset_x - OpenCL Global Offset X
6+
# __agocl_global_offset_y - OpenCL Global Offset Y
7+
# __agocl_global_offset_z - OpenCL Global Offset Z
8+
# __agocl_printf_addr - OpenCL address of printf buffer
9+
# __agocl_queue_addr - OpenCL address of virtual queue used by enqueue_kernel
10+
# __agocl_aqlwrap_addr - OpenCL address of AqlWrap struct used by enqueue_kernel
11+
# __agocl_multigrid - Pointer argument used for Multi-grid synchronization
12+

0 commit comments

Comments
 (0)