Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 0cdf32c

Browse files
authored
Merge pull request #603 from JuliaGPU/ksh/xttests
Tests for cublas_xt
2 parents ba30d99 + 2b2089e commit 0cdf32c

File tree

1 file changed

+29
-44
lines changed

1 file changed

+29
-44
lines changed

test/blas.jl

Lines changed: 29 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -634,6 +634,13 @@ end # level 1 testset
634634
h_C = Array(dC)
635635
@test C h_C
636636
end
637+
@testset "xt_trsm" begin
638+
C = alpha*(A\B)
639+
dC = CuArrays.CUBLAS.xt_trsm('L','U','N','N',alpha,dA,dB)
640+
# move to host and compare
641+
h_C = Array(dC)
642+
@test C h_C
643+
end
637644
@testset "trsm" begin
638645
Br = rand(elty,m,n)
639646
Bl = rand(elty,n,m)
@@ -782,50 +789,6 @@ end # level 1 testset
782789
end
783790
A = rand(elty,m,k)
784791
d_A = CuArray(A)
785-
@testset "syrk!" begin
786-
# generate matrices
787-
d_C = CuArray(sA)
788-
# C = (alpha*A)*transpose(A) + beta*C
789-
CuArrays.CUBLAS.syrk!('U','N',alpha,d_A,beta,d_C)
790-
C = (alpha*A)*transpose(A) + beta*sA
791-
C = triu(C)
792-
# move to host and compare
793-
h_C = Array(d_C)
794-
h_C = triu(C)
795-
@test C h_C
796-
end
797-
@testset "xt_syrk!" begin
798-
# generate matrices
799-
d_C = CuArray(sA)
800-
# C = (alpha*A)*transpose(A) + beta*C
801-
CuArrays.CUBLAS.xt_syrk!('U','N',alpha,d_A,beta,d_C)
802-
C = (alpha*A)*transpose(A) + beta*sA
803-
C = triu(C)
804-
# move to host and compare
805-
h_C = Array(d_C)
806-
h_C = triu(C)
807-
@test C h_C
808-
end
809-
@testset "syrk" begin
810-
# C = A*transpose(A)
811-
d_C = CuArrays.CUBLAS.syrk('U','N',d_A)
812-
C = A*transpose(A)
813-
C = triu(C)
814-
# move to host and compare
815-
h_C = Array(d_C)
816-
h_C = triu(C)
817-
@test C h_C
818-
end
819-
@testset "xt_syrk" begin
820-
# C = A*transpose(A)
821-
d_C = CuArrays.CUBLAS.xt_syrk('U','N',d_A)
822-
C = A*transpose(A)
823-
C = triu(C)
824-
# move to host and compare
825-
h_C = Array(d_C)
826-
h_C = triu(C)
827-
@test C h_C
828-
end
829792
@testset "syrkx!" begin
830793
# generate matrices
831794
syrkx_A = rand(elty, n, k)
@@ -870,6 +833,17 @@ end # level 1 testset
870833
d_badC = CuArray(badC)
871834
@test_throws DimensionMismatch CuArrays.CUBLAS.xt_syrkx!('U','N',alpha,d_syrkx_A,d_syrkx_B,beta,d_badC)
872835
end
836+
@testset "xt_syrkx" begin
837+
# generate matrices
838+
syrkx_A = rand(elty, n, k)
839+
syrkx_B = rand(elty, n, k)
840+
d_syrkx_A = CuArray(syrkx_A)
841+
d_syrkx_B = CuArray(syrkx_B)
842+
d_syrkx_C = CuArrays.CUBLAS.xt_syrkx('U','N',d_syrkx_A,d_syrkx_B)
843+
final_C = syrkx_A*transpose(syrkx_B)
844+
# move to host and compare
845+
h_C = Array(d_syrkx_C)
846+
end
873847
@testset "syrk" begin
874848
# C = A*transpose(A)
875849
d_C = CuArrays.CUBLAS.syrk('U','N',d_A)
@@ -1005,6 +979,17 @@ end # level 1 testset
1005979
h_C = triu(h_C)
1006980
@test C h_C
1007981
end
982+
@testset "xt_her2k" begin
983+
# generate parameters
984+
C = C + C'
985+
C = (A*B') + (B*A')
986+
d_C = CuArrays.CUBLAS.xt_her2k('U','N',d_A,d_B)
987+
# move back to host and compare
988+
C = triu(C)
989+
h_C = Array(d_C)
990+
h_C = triu(h_C)
991+
@test C h_C
992+
end
1008993
@testset "her2k" begin
1009994
C = A*B' + B*A'
1010995
d_C = CuArrays.CUBLAS.her2k('U','N',d_A,d_B)

0 commit comments

Comments
 (0)