Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit c3a68db

Browse files
authored
Merge pull request #597 from JuliaGPU/tb/findfirst_oob
Fix OOB in find_first kernel.
2 parents dd4cc1e + 7994d8f commit c3a68db

File tree

2 files changed

+16
-2
lines changed

2 files changed

+16
-2
lines changed

src/indexing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -156,9 +156,9 @@ function Base.findfirst(vals::CuArray, xs::CuArray)
156156
Ipre = Rpre[I[1]]
157157
Ipost = Rpost[I[2]]
158158

159-
@inbounds if xs[Ipre, i, Ipost] == vals[Ipre, Ipost]
159+
@inbounds if xs[Ipre, i, Ipost] == vals[Ipre, 1, Ipost]
160160
full_index = LinearIndices(xs)[Ipre, i, Ipost] # atomic_min only works with integers
161-
reduced_index = LinearIndices(indices)[Ipre, Ipost] # FIXME: @atomic doesn't handle array ref with CartesianIndices
161+
reduced_index = LinearIndices(indices)[Ipre, 1, Ipost] # FIXME: @atomic doesn't handle array ref with CartesianIndices
162162
CUDAnative.@atomic indices[reduced_index] = min(indices[reduced_index], full_index)
163163
end
164164
end

test/base.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -366,6 +366,13 @@ end
366366
let x = rand(Float32, 10, 10)
367367
@test findmax(x) == findmax(CuArray(x))
368368
@test findmax(x; dims=1) == Array.(findmax(CuArray(x); dims=1))
369+
@test findmax(x; dims=2) == Array.(findmax(CuArray(x); dims=2))
370+
end
371+
let x = rand(Float32, 10, 10, 10)
372+
@test findmax(x) == findmax(CuArray(x))
373+
@test findmax(x; dims=1) == Array.(findmax(CuArray(x); dims=1))
374+
@test findmax(x; dims=2) == Array.(findmax(CuArray(x); dims=2))
375+
@test findmax(x; dims=3) == Array.(findmax(CuArray(x); dims=3))
369376
end
370377

371378
let x = rand(Float32, 100)
@@ -375,6 +382,13 @@ end
375382
let x = rand(Float32, 10, 10)
376383
@test findmin(x) == findmin(CuArray(x))
377384
@test findmin(x; dims=1) == Array.(findmin(CuArray(x); dims=1))
385+
@test findmin(x; dims=2) == Array.(findmin(CuArray(x); dims=2))
386+
end
387+
let x = rand(Float32, 10, 10, 10)
388+
@test findmin(x) == findmin(CuArray(x))
389+
@test findmin(x; dims=1) == Array.(findmin(CuArray(x); dims=1))
390+
@test findmin(x; dims=2) == Array.(findmin(CuArray(x); dims=2))
391+
@test findmin(x; dims=3) == Array.(findmin(CuArray(x); dims=3))
378392
end
379393
end
380394

0 commit comments

Comments
 (0)