Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 2093291

Browse files
authored
Merge pull request #599 from JuliaGPU/tb/simplify_findfirst
Simplify multidimensional find kernel.
2 parents c3a68db + 137bd69 commit 2093291

File tree

1 file changed

+22
-41
lines changed

1 file changed

+22
-41
lines changed

src/indexing.jl

Lines changed: 22 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -126,65 +126,46 @@ end
126126
Base.findfirst(xs::CuArray{Bool}) = findfirst(identity, xs)
127127

128128
function Base.findfirst(vals::CuArray, xs::CuArray)
129-
# figure out which dimension was reduced
130-
@assert ndims(vals) == ndims(xs)
131-
dims = [i for i in 1:ndims(xs) if size(xs,i)!=1 && size(vals,i)==1]
132-
@assert length(dims) == 1
133-
dim = dims[1]
134-
135-
136129
## find the first matching element
137130

131+
# NOTE: this kernel performs global atomic operations for the sake of simplicity.
132+
# if this turns out to be a bottleneck, we will need to cache in local memory.
133+
# that requires the dimension-under-reduction to be iterated in first order.
134+
# this can be done by splitting the iteration domain eagerly; see the
135+
# accumulate kernel for an example, or git history from before this comment.
136+
138137
indices = fill(typemax(Int), size(vals))
139138

140-
# iteration domain across the main dimension
141-
Rdim = CartesianIndices((size(xs, dim),))
139+
function kernel(xs, vals, indices)
140+
i = (blockIdx().x-1) * blockDim().x + threadIdx().x
142141

143-
# iteration domain for the other dimensions
144-
Rpre = CartesianIndices(size(xs)[1:dim-1])
145-
Rpost = CartesianIndices(size(xs)[dim+1:end])
146-
Rother = CartesianIndices((length(Rpre), length(Rpost)))
142+
R = CartesianIndices(xs)
147143

148-
function kernel(xs, vals, indices, Rdim, Rpre, Rpost, Rother)
149-
# iterate the main dimension using threads and the first block dimension
150-
i = (blockIdx().x-1) * blockDim().x + threadIdx().x
151-
# iterate the other dimensions using the remaining block dimensions
152-
j = (blockIdx().z-1) * gridDim().y + blockIdx().y
153-
154-
if i <= length(Rdim) && j <= length(Rother)
155-
I = Rother[j]
156-
Ipre = Rpre[I[1]]
157-
Ipost = Rpost[I[2]]
158-
159-
@inbounds if xs[Ipre, i, Ipost] == vals[Ipre, 1, Ipost]
160-
full_index = LinearIndices(xs)[Ipre, i, Ipost] # atomic_min only works with integers
161-
reduced_index = LinearIndices(indices)[Ipre, 1, Ipost] # FIXME: @atomic doesn't handle array ref with CartesianIndices
162-
CUDAnative.@atomic indices[reduced_index] = min(indices[reduced_index], full_index)
144+
if i <= length(R)
145+
I = R[i]
146+
Jmax = last(CartesianIndices(vals))
147+
J = min(I, Jmax)
148+
149+
@inbounds if xs[I] == vals[J]
150+
I′ = LinearIndices(xs)[I] # atomic_min only works with integers
151+
J′ = LinearIndices(indices)[J] # FIXME: @atomic doesn't handle array ref with CartesianIndices
152+
CUDAnative.@atomic indices[J′] = min(indices[J′], I′)
163153
end
164154
end
165155

166156
return
167157
end
168158

169159
function configurator(kernel)
170-
# what's a good launch configuration for this kernel?
171160
config = launch_configuration(kernel.fun)
172161

173-
# blocks to cover the main dimension
174-
threads = min(length(Rdim), config.threads)
175-
blocks_dim = cld(length(Rdim), threads)
176-
# NOTE: the grid X dimension is virtually unconstrained
177-
178-
# blocks to cover the remaining dimensions
179-
dev = CUDAdrv.device(kernel.fun.mod.ctx)
180-
max_other_blocks = attribute(dev, CUDAdrv.DEVICE_ATTRIBUTE_MAX_GRID_DIM_Y)
181-
blocks_other = (min(length(Rother), max_other_blocks),
182-
cld(length(Rother), max_other_blocks))
162+
threads = min(length(xs), config.threads)
163+
blocks = cld(length(xs), threads)
183164

184-
return (threads=threads, blocks=(blocks_dim, blocks_other...))
165+
return (threads=threads, blocks=blocks)
185166
end
186167

187-
@cuda config=configurator kernel(xs, vals, indices, Rdim, Rpre, Rpost, Rother)
168+
@cuda config=configurator kernel(xs, vals, indices)
188169

189170

190171
## convert the linear indices to an appropriate type

0 commit comments

Comments
 (0)