
Commit 16f3bef

Merge pull request #544 from JuliaGPU/tb/threads
Rework library handles for multithreading.

2 parents 50c1a3d + 8635a2f
20 files changed: +372 −157 lines

Manifest.toml
Lines changed: 9 additions & 3 deletions

@@ -22,19 +22,25 @@ version = "0.2.0"
 
 [[CUDAapi]]
 deps = ["Libdl", "Logging"]
-git-tree-sha1 = "6eee47385c81ed3b3f716b745697869c712c2df3"
+git-tree-sha1 = "ca1c7f639c5f6326919ee2834fa0dffb5002ff60"
+repo-rev = "master"
+repo-url = "https://github.com/JuliaGPU/CUDAapi.jl.git"
 uuid = "3895d2a7-ec45-59b8-82bb-cfc6a382f9b3"
 version = "2.0.0"
 
 [[CUDAdrv]]
 deps = ["CEnum", "CUDAapi", "Printf"]
-git-tree-sha1 = "0f39fddace3324707469ace7fbcbc7b28d5cf921"
+git-tree-sha1 = "5c2cf00a78503e1f71409cecf3d64508fb33f17f"
+repo-rev = "master"
+repo-url = "https://github.com/JuliaGPU/CUDAdrv.jl.git"
 uuid = "c5f51814-7f29-56b8-a69c-e4d8f6be1fde"
 version = "4.0.4"
 
 [[CUDAnative]]
 deps = ["Adapt", "CEnum", "CUDAapi", "CUDAdrv", "DataStructures", "InteractiveUtils", "LLVM", "Libdl", "Printf", "TimerOutputs"]
-git-tree-sha1 = "a67b38619d1fa131027bac1c4a81f0012254d1fd"
+git-tree-sha1 = "8b1a585344fee94bdb95ac44653fd057d74e32e6"
+repo-rev = "master"
+repo-url = "https://github.com/JuliaGPU/CUDAnative.jl.git"
 uuid = "be33ccc6-a3ff-5ff2-a52e-74243cff1e17"
 version = "2.6.0"
 

src/CuArrays.jl
Lines changed: 0 additions & 26 deletions

@@ -32,10 +32,6 @@ include("linalg.jl")
 
 include("gpuarray_interface.jl")
 
-# many libraries need to be initialized per-device (per-context, really, but we assume users
-# of CuArrays and/or CUDAnative only use a single context), so keep track of the active one.
-const active_context = Ref{CuContext}()
-
 include("blas/CUBLAS.jl")
 include("sparse/CUSPARSE.jl")
 include("solver/CUSOLVER.jl")

@@ -112,28 +108,6 @@ function __init__()
     # package integrations
     @require ForwardDiff="f6369f11-7733-5829-9624-2563aa707210" include("forwarddiff.jl")
 
-    # update the active context when we switch devices
-    callback = (::CuDevice, ctx::CuContext) -> begin
-        active_context[] = ctx
-
-        # wipe the active handles
-        CUBLAS._handle[] = C_NULL
-        CUBLAS._xt_handle[] = C_NULL
-        CUSOLVER._dense_handle[] = C_NULL
-        CUSOLVER._sparse_handle[] = C_NULL
-        CUSPARSE._handle[] = C_NULL
-        CURAND._generator[] = nothing
-        CUDNN._handle[] = C_NULL
-        CUTENSOR._handle[] = nothing
-    end
-    push!(CUDAnative.device!_listeners, callback)
-
-    # a device might be active already
-    existing_ctx = CUDAdrv.CuCurrentContext()
-    if existing_ctx !== nothing
-        active_context[] = existing_ctx
-    end
-
     __init_memory__()
 
     __initialized__[] = true
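
The global `active_context` and the handle-wiping device listener are gone: each library submodule now owns per-thread, per-context handle state, as the CUBLAS diff below shows in full. A minimal, self-contained sketch of that pattern, using toy stand-in types (`ToyContext`, `ToyHandle`, and `toy_context` are illustrative, not the package API):

    # Toy stand-ins for CuContext and a library handle.
    struct ToyContext end
    struct ToyHandle
        id::Int
    end

    toy_context() = ToyContext()   # stand-in for CUDAnative's context()

    const created_handles = IdDict{ToyContext,ToyHandle}()     # one handle per context
    const active_handles = Vector{Union{Nothing,ToyHandle}}()  # one slot per thread

    function toy_handle()
        tid = Threads.threadid()
        if active_handles[tid] === nothing          # fast path: slot already filled?
            ctx = toy_context()
            active_handles[tid] = get!(created_handles, ctx) do
                ToyHandle(tid)                      # stand-in for cublasCreate_v2() etc.
            end
        end
        active_handles[tid]
    end

    # __init__-style setup: one lazily-filled slot per thread
    resize!(active_handles, Threads.nthreads())
    fill!(active_handles, nothing)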

src/blas/CUBLAS.jl
Lines changed: 37 additions & 20 deletions

@@ -5,10 +5,10 @@ using CUDAapi
 using CUDAdrv
 using CUDAdrv: CUstream
 
-import CUDAnative
+using CUDAnative
 
 using ..CuArrays
-using ..CuArrays: active_context, unsafe_free!
+using ..CuArrays: unsafe_free!
 using LinearAlgebra
 
 using CEnum

@@ -27,45 +27,62 @@ include("wrappers.jl")
 # high-level integrations
 include("linalg.jl")
 
-const _handles = Dict{CuContext,cublasHandle_t}()
-const _xt_handles = Dict{CuContext,cublasXtHandle_t}()
-const _handle = Ref{cublasHandle_t}(C_NULL)
-const _xt_handle = Ref{cublasXtHandle_t}(C_NULL)
+const created_handles = IdDict{CuContext,cublasHandle_t}()
+const created_xt_handles = IdDict{CuContext,cublasXtHandle_t}()
+const active_handles = Vector{Union{Nothing,cublasHandle_t}}()
+const active_xt_handles = Vector{Union{Nothing,cublasXtHandle_t}}()
 
 function handle()
-    if _handle[] == C_NULL
-        CUDAnative.maybe_initialize("CUBLAS")
-        _handle[] = get!(_handles, active_context[]) do
-            context = active_context[]
+    tid = Threads.threadid()
+    if @inbounds active_handles[tid] === nothing
+        ctx = context()
+        active_handles[tid] = get!(created_handles, ctx) do
             handle = cublasCreate_v2()
+            atexit(()->CUDAdrv.isvalid(ctx) && cublasDestroy_v2(handle))
 
             # enable tensor math mode if our device supports it, and fast math is enabled
-            dev = CUDAdrv.device(context)
+            dev = CUDAdrv.device()
             if Base.JLOptions().fast_math == 1 && CUDAdrv.capability(dev) >= v"7.0" && version() >= v"9"
                 cublasSetMathMode(CUBLAS_TENSOR_OP_MATH, handle)
             end
 
-            atexit(()->CUDAdrv.isvalid(context) && cublasDestroy_v2(handle))
             handle
         end
     end
-
-    return _handle[]
+    @inbounds active_handles[tid]
 end
 
 function xt_handle()
-    if _xt_handle[] == C_NULL
-        @assert isassigned(active_context) # some other call should have initialized CUDA
-        _xt_handle[] = get!(_xt_handles, active_context[]) do
-            context = active_context[]
+    tid = Threads.threadid()
+    if @inbounds active_xt_handles[tid] === nothing
+        ctx = context()
+        active_xt_handles[tid] = get!(created_xt_handles, ctx) do
             handle = cublasXtCreate()
+            atexit(()->CUDAdrv.isvalid(ctx) && cublasXtDestroy(handle))
+
+            # select the devices
+            # TODO: this is weird, since we typically use a single device per thread/context
             devs = convert.(Cint, CUDAdrv.devices())
             cublasXtDeviceSelect(handle, length(devs), devs)
-            atexit(()->CUDAdrv.isvalid(context) && cublasXtDestroy(handle))
+
             handle
         end
     end
-    return _xt_handle[]
+    @inbounds active_xt_handles[tid]
+end
+
+function __init__()
+    resize!(active_handles, Threads.nthreads())
+    fill!(active_handles, nothing)
+
+    resize!(active_xt_handles, Threads.nthreads())
+    fill!(active_xt_handles, nothing)
+
+    CUDAnative.atcontextswitch() do tid, ctx
+        # we don't eagerly initialize handles, but do so lazily when requested
+        active_handles[tid] = nothing
+        active_xt_handles[tid] = nothing
+    end
 end
 
 end
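
A hedged usage sketch (not part of the commit; assumes a functional CUDA setup and Julia started with multiple threads): each thread fills its own slot on first use, and threads sharing a context share the underlying cuBLAS handle.

    using CuArrays

    Threads.@threads for i in 1:Threads.nthreads()
        h = CuArrays.CUBLAS.handle()   # lazily created and cached for this thread
        @assert h != C_NULL
    end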

src/blas/error.jl
Lines changed: 28 additions & 7 deletions

@@ -37,13 +37,34 @@ function status_message(status)
     end
 end
 
-macro check(blas_func)
+
+## API call wrapper
+
+# API calls that are allowed without a functional context
+const preinit_apicalls = Set{Symbol}([
+    :cublasGetVersion,
+    :cublasGetProperty,
+    :cublasGetCudartVersion
+])
+
+# outlined functionality to avoid GC frame allocation
+@noinline function throw_api_error(res)
+    throw(CuError(res))
+end
+
+macro check(ex)
+    fun = Symbol(decode_ccall_function(ex))
+    init = if !in(fun, preinit_apicalls)
+        :(CUDAnative.maybe_initialize())
+    end
     quote
-        local err::cublasStatus_t
-        err = $(esc(blas_func::Expr))
-        if err != CUBLAS_STATUS_SUCCESS
-            throw(CUBLASError(err))
+        $init
+
+        res = $(esc(ex))
+        if res != CUBLAS_STATUS_SUCCESS
+            throw_api_error(res)
         end
-        err
+
+        return
     end
-end
+end
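
The new `@check` decides at macro-expansion time whether the wrapped call needs a live context: it recovers the callee's name with `decode_ccall_function` and emits `CUDAnative.maybe_initialize()` only for functions outside `preinit_apicalls`. A runnable toy mimicking that decision (`@toycheck`, `get_version`, and `create_handle` are hypothetical names, not the package's macro):

    const preinit = Set{Symbol}([:get_version])

    macro toycheck(ex)
        fun = ex.args[1]   # callee name of a `fun(args...)` expression
        init = in(fun, preinit) ? nothing : :(println("initializing CUDA first"))
        quote
            $init
            res = $(esc(ex))
            res != 0 && error("API call failed: $res")
            nothing
        end
    end

    get_version() = 0
    create_handle() = 0

    @toycheck get_version()     # no initialization emitted
    @toycheck create_handle()   # prints "initializing CUDA first"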

src/dnn/CUDNN.jl
Lines changed: 19 additions & 10 deletions

@@ -6,12 +6,12 @@ using CUDAapi: libraryPropertyType
 using CUDAdrv
 using CUDAdrv: CUstream
 
-import CUDAnative
+using CUDAnative
 
 using CEnum
 
 using ..CuArrays
-using ..CuArrays: active_context, @argout, @workspace
+using ..CuArrays: @argout, @workspace
 import ..CuArrays.unsafe_free!
 
 import NNlib

@@ -41,21 +41,30 @@ include("nnlib.jl")
 
 include("compat.jl")
 
-const _handles = Dict{CuContext,cudnnHandle_t}()
-const _handle = Ref{cudnnHandle_t}(C_NULL)
+const created_handles = IdDict{CuContext,cudnnHandle_t}()
+const active_handles = Vector{Union{Nothing,cudnnHandle_t}}()
 
 function handle()
-    if _handle[] == C_NULL
-        CUDAnative.maybe_initialize("CUDNN")
-        _handle[] = get!(_handles, active_context[]) do
-            context = active_context[]
+    tid = Threads.threadid()
+    if @inbounds active_handles[tid] === nothing
+        ctx = context()
+        active_handles[tid] = get!(created_handles, ctx) do
             handle = cudnnCreate()
-            atexit(()->CUDAdrv.isvalid(context) && cudnnDestroy(handle))
+            atexit(()->CUDAdrv.isvalid(ctx) && cudnnDestroy(handle))
             handle
         end
     end
+    @inbounds active_handles[tid]
+end
+
+function __init__()
+    resize!(active_handles, Threads.nthreads())
+    fill!(active_handles, nothing)
 
-    return _handle[]
+    CUDAnative.atcontextswitch() do tid, ctx
+        # we don't eagerly initialize handles, but do so lazily when requested
+        active_handles[tid] = nothing
+    end
 end
 
 end
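
The `atcontextswitch` hook only clears the current thread's slot; the handle itself stays in `created_handles`, so returning to a context should reuse its handle instead of recreating it. A hedged sketch of that flow (assumes two GPUs and that switching back yields the same context object; not from the commit):

    using CuArrays, CUDAnative

    h0 = CuArrays.CUDNN.handle()   # fills this thread's slot for device 0's context
    device!(1)                     # hook fires, slot reset to nothing
    h1 = CuArrays.CUDNN.handle()   # fresh handle for device 1's context
    device!(0)                     # slot reset again...
    @assert CuArrays.CUDNN.handle() == h0   # ...but device 0's handle is reused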

src/dnn/error.jl
Lines changed: 28 additions & 6 deletions

@@ -11,13 +11,35 @@ function CUDNNError(status::cudnnStatus_t)
     return CUDNNError(status, msg)
 end
 
-macro check(dnn_func)
+
+## API call wrapper
+
+# API calls that are allowed without a functional context
+const preinit_apicalls = Set{Symbol}([
+    :cudnnGetVersion,
+    :cudnnGetProperty,
+    :cudnnGetCudartVersion,
+    :cudnnGetErrorString,
+])
+
+# outlined functionality to avoid GC frame allocation
+@noinline function throw_api_error(res)
+    throw(CUDNNError(res))
+end
+
+macro check(ex)
+    fun = Symbol(decode_ccall_function(ex))
+    init = if !in(fun, preinit_apicalls)
+        :(CUDAnative.maybe_initialize())
+    end
     quote
-        local err::cudnnStatus_t
-        err = $(esc(dnn_func))
-        if err != CUDNN_STATUS_SUCCESS
-            throw(CUDNNError(err))
+        $init
+
+        res = $(esc(ex))
+        if res != CUDNN_STATUS_SUCCESS
+            throw_api_error(res)
        end
-        err
+
+        return
     end
 end
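
`throw_api_error` is `@noinline` so the allocating error path is outlined from its callers, keeping the success path of every checked API call free of a GC frame. A generic illustration of the idea, not tied to cuDNN:

    # outlined error path: callers of check() keep a lean fast path
    @noinline raise(res) = throw(ErrorException("API call failed with status $res"))

    check(res) = (res != 0 && raise(res); nothing)

    check(0)   # fast path: no error object, no GC frame in the caller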

src/dnn/filter.jl
Lines changed: 1 addition & 1 deletion

@@ -10,7 +10,7 @@ Base.unsafe_convert(::Type{cudnnFilterDescriptor_t}, fd::FilterDesc) = fd.ptr
 
 function createFilterDesc()
     d = Ref{cudnnFilterDescriptor_t}()
-    @check cudnnCreateFilterDescriptor(d)
+    cudnnCreateFilterDescriptor(d)
     return d[]
 end
 

src/fft/CUFFT.jl
Lines changed: 2 additions & 0 deletions

@@ -8,6 +8,8 @@ import ..CuArrays: unsafe_free!
 using CUDAdrv
 using CUDAdrv: CUstream
 
+using CUDAnative
+
 using CEnum
 
 const libcufft = Ref("libcufft")

src/fft/error.jl
Lines changed: 26 additions & 6 deletions

@@ -51,13 +51,33 @@ function status_message(status)
     end
 end
 
-macro check(fft_func)
+
+## API call wrapper
+
+# API calls that are allowed without a functional context
+const preinit_apicalls = Set{Symbol}([
+    :cufftGetVersion,
+    :cufftGetProperty,
+])
+
+# outlined functionality to avoid GC frame allocation
+@noinline function throw_api_error(res)
+    throw(CUFFTError(res))
+end
+
+macro check(ex)
+    fun = Symbol(decode_ccall_function(ex))
+    init = if !in(fun, preinit_apicalls)
+        :(CUDAnative.maybe_initialize())
+    end
     quote
-        local err::cufftResult
-        err = $(esc(fft_func::Expr))
-        if err != CUFFT_SUCCESS
-            throw(CUFFTError(err))
+        $init
+
+        res = $(esc(ex))
+        if res != CUFFT_SUCCESS
+            throw_api_error(res)
         end
-        err
+
+        return
     end
 end

src/memory.jl
Lines changed: 1 addition & 3 deletions

@@ -280,9 +280,7 @@ synchronized right before and after executing `ex` to exclude any external effec
 macro time(ex)
     quote
         # @time might surround an application, so be sure to initialize CUDA before that
-        # FIXME: this should be done in CUDAdrv (`synchronize(ctx=CuCurrentOrNewContext()`)
-        # but the CUDA initialization mechanics are part of CUDAnative.jl
-        CUDAnative.maybe_initialize("@time")
+        CUDAnative.maybe_initialize()
 
         # coarse synchronization to exclude effects from previously-executed code
        CUDAdrv.synchronize()
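
With initialization folded into `maybe_initialize()` (no caller label needed anymore), the macro is used as before. A hedged example (assumes a working GPU; not part of the commit):

    using CuArrays

    xs = cu(rand(Float32, 1024, 1024))
    CuArrays.@time sum(xs)   # initializes CUDA if needed, synchronizes, then times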
