Skip to content

Commit 49a339c

Browse files
tam724maleadt
andauthored
Reflect change in rotate! from LinearAlgebra.jl (#603)
Co-authored-by: Tim Besard <[email protected]>
1 parent e61d13c commit 49a339c

File tree

3 files changed

+38
-3
lines changed

3 files changed

+38
-3
lines changed

src/host/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -687,8 +687,8 @@ function LinearAlgebra.rotate!(x::AbstractGPUArray, y::AbstractGPUArray, c::Numb
687687
i = @index(Global, Linear)
688688
@inbounds xi = x[i]
689689
@inbounds yi = y[i]
690-
@inbounds x[i] = c * xi + s * yi
691-
@inbounds y[i] = -conj(s) * xi + c * yi
690+
@inbounds x[i] = s*yi + c *xi
691+
@inbounds y[i] = c*yi - conj(s)*xi
692692
end
693693
rotate_kernel!(get_backend(x))(x, y, c, s; ndrange = size(x))
694694
return x, y

test/testsuite.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ function compare(f, AT::Type{<:AbstractGPUArray}, xs...; kwargs...)
4646
end
4747

4848
function compare(f, AT::Type{<:Array}, xs...; kwargs...)
49-
# no need to actually run this tests: we have nothing to compoare against,
49+
# no need to actually run this tests: we have nothing to compare against,
5050
# and we'll run it on a CPU array anyhow when comparing to a GPU array.
5151
#
5252
# this method exists so that we can at least run the test suite with Array,
@@ -67,6 +67,8 @@ isrealtype(T) = T <: Real
6767
iscomplextype(T) = T <: Complex
6868
isrealfloattype(T) = T <: AbstractFloat
6969
isfloattype(T) = T <: AbstractFloat || T <: Complex{<:AbstractFloat}
70+
NaN_T(T::Type{<:AbstractFloat}) = T(NaN)
71+
NaN_T(T::Type{<:Complex{<:AbstractFloat}}) = T(NaN, NaN)
7072

7173
# list of tests
7274
const tests = Dict()

test/testsuite/linalg.jl

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -391,3 +391,36 @@ end
391391
@test isrealfloattype(typeof(opnorm(AT(mat), p)))
392392
end
393393
end
394+
395+
@testsuite "linalg/NaN_false" (AT, eltypes)->begin
396+
eltypes = filter(T -> isfloattype(T), eltypes) # only floats have NaN
397+
if AT <: AbstractGPUArray
398+
@testset "rmul! / lmul!" for T in eltypes
399+
y = invoke(rmul!, Tuple{AbstractGPUArray, Number}, adapt(AT, fill(NaN_T(T), 3)), false)
400+
@test !any(isnan, collect(y))
401+
y = invoke(lmul!, Tuple{Number, AbstractGPUArray}, false, adapt(AT, fill(NaN_T(T), 3)))
402+
@test !any(isnan, collect(y))
403+
end
404+
405+
@testset "axp{b}y!" for T in eltypes
406+
y = invoke(axpby!, Tuple{Number, AbstractGPUArray, Number, AbstractGPUArray}, false, adapt(AT, fill(NaN_T(T), 3)), false, adapt(AT, fill(NaN_T(T), 3)))
407+
@test !any(isnan, collect(y))
408+
y = invoke(axpy!, Tuple{Number, AbstractGPUArray, AbstractGPUArray}, false, adapt(AT, fill(NaN_T(T), 3)), adapt(AT, rand(T, 3)))
409+
@test !any(isnan, collect(y))
410+
end
411+
412+
@testset "rotate! / reflect!" for T in eltypes
413+
x, y = invoke(rotate!, Tuple{AbstractGPUArray, AbstractGPUArray, Number, Number}, adapt(AT, fill(NaN_T(T), 3)), adapt(AT, fill(NaN_T(T), 3)), false, false)
414+
@test !any(isnan, collect(x))
415+
@test !any(isnan, collect(y))
416+
x, y = invoke(reflect!, Tuple{AbstractGPUArray, AbstractGPUArray, Number, Number}, adapt(AT, fill(NaN_T(T), 3)), adapt(AT, fill(NaN_T(T), 3)), false, false)
417+
@test !any(isnan, collect(x))
418+
@test !any(isnan, collect(y))
419+
end
420+
421+
@testset "generic_matmatmul!" for T in eltypes
422+
y = invoke(GPUArrays.generic_matmatmul!, Tuple{AbstractArray, AbstractArray, AbstractArray, Number, Number}, adapt(AT, fill(NaN_T(T), 3, 3)), adapt(AT, fill(NaN_T(T), 3, 3)), adapt(AT, fill(NaN_T(T), 3, 3)), false, false)
423+
@test !any(isnan, collect(y))
424+
end
425+
end
426+
end

0 commit comments

Comments
 (0)