Skip to content
This repository was archived by the owner on Mar 12, 2021. It is now read-only.

Commit 60b1d8e

Browse files
authored
Merge pull request #533 from JuliaGPU/jh/updatetensortests
reenable complex tensor contraction tests
2 parents c3400ed + 09c52fe commit 60b1d8e

File tree

1 file changed

+60
-77
lines changed

1 file changed

+60
-77
lines changed

test/tensor.jl

Lines changed: 60 additions & 77 deletions
Original file line numberDiff line numberDiff line change
@@ -379,9 +379,9 @@ end
379379
#(Float64, ComplexF64, ComplexF64),
380380
#(ComplexF64, Float64, ComplexF64),
381381
# (ComplexF16, ComplexF16, ComplexF16), # does not work
382-
# (ComplexF32, ComplexF32, ComplexF32), # works for some
382+
(ComplexF32, ComplexF32, ComplexF32), # works for some
383383
# (ComplexF32, ComplexF32, ComplexF64), # does not work
384-
#(ComplexF64, ComplexF64, ComplexF64) # works for some
384+
(ComplexF64, ComplexF64, ComplexF64) # works for some
385385
)
386386

387387
@testset for NoA=1:3, NoB=1:3, Nc=1:3
@@ -420,70 +420,65 @@ end
420420
mB = reshape(permutedims(B, ipB), (lc, loB))
421421
C = zeros(eltyC, (dimsC...,))
422422
dC = CuArray(C)
423-
if !(NoA == 1 && NoB == 3 && Nc == 1) # broken for some reason
424-
# simple case
425-
opA = CUTENSOR.CUTENSOR_OP_IDENTITY
426-
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
427-
opC = CUTENSOR.CUTENSOR_OP_IDENTITY
428-
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
429-
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
430-
0, dC, indsC, opC, opOut)
431-
C = collect(dC)
432-
mC = reshape(permutedims(C, ipC), (loA, loB))
433-
@test mC mA * mB
434-
435-
# with non-trivial α
436-
α = rand(eltyC)
437-
dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,
438-
0, dC, indsC, opC, opOut)
439-
C = collect(dC)
440-
mC = reshape(permutedims(C, ipC), (loA, loB))
441-
@test mC α * mA * mB
442-
443-
# with non-trivial β
444-
C = rand(eltyC, (dimsC...,))
445-
dC = CuArray(C)
446-
α = rand(eltyC)
447-
β = rand(eltyC)
448-
copyto!(dC, C)
449-
dD = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,
450-
β, dC, indsC, opC, opOut)
451-
D = collect(dD)
452-
mC = reshape(permutedims(C, ipC), (loA, loB))
453-
mD = reshape(permutedims(D, ipC), (loA, loB))
454-
@test mD α * mA * mB + β * mC
455-
456-
# with CuTensor objects
457-
ctA = CuTensor(dA, indsA)
458-
ctB = CuTensor(dB, indsB)
459-
ctC = CuTensor(dC, indsC)
460-
ctC = LinearAlgebra.mul!(ctC, ctA, ctB)
461-
C2, C2inds = collect(ctC)
462-
mC = reshape(permutedims(C2, ipC), (loA, loB))
463-
@test mC mA * mB
464-
ctC = ctA * ctB
465-
C2, C2inds = collect(ctC)
466-
pC2 = convert.(Int, indexin(convert.(Char, C2inds), [indsoA; indsoB]))
467-
mC = reshape(permutedims(C2, invperm(pC2)), (loA, loB))
468-
@test mC mA * mB
469-
470-
# with conjugation flag for complex arguments
423+
424+
# simple case
425+
opA = CUTENSOR.CUTENSOR_OP_IDENTITY
426+
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
427+
opC = CUTENSOR.CUTENSOR_OP_IDENTITY
428+
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
429+
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
430+
0, dC, indsC, opC, opOut)
431+
C = collect(dC)
432+
mC = reshape(permutedims(C, ipC), (loA, loB))
433+
@test mC mA * mB
434+
435+
# with non-trivial α
436+
α = rand(eltyC)
437+
dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,
438+
0, dC, indsC, opC, opOut)
439+
C = collect(dC)
440+
mC = reshape(permutedims(C, ipC), (loA, loB))
441+
@test mC α * mA * mB
442+
443+
# with non-trivial β
444+
C = rand(eltyC, (dimsC...,))
445+
dC = CuArray(C)
446+
α = rand(eltyC)
447+
β = rand(eltyC)
448+
copyto!(dC, C)
449+
dD = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,
450+
β, dC, indsC, opC, opOut)
451+
D = collect(dD)
452+
mC = reshape(permutedims(C, ipC), (loA, loB))
453+
mD = reshape(permutedims(D, ipC), (loA, loB))
454+
@test mD α * mA * mB + β * mC
455+
456+
# with CuTensor objects
457+
ctA = CuTensor(dA, indsA)
458+
ctB = CuTensor(dB, indsB)
459+
ctC = CuTensor(dC, indsC)
460+
ctC = LinearAlgebra.mul!(ctC, ctA, ctB)
461+
C2, C2inds = collect(ctC)
462+
mC = reshape(permutedims(C2, ipC), (loA, loB))
463+
@test mC mA * mB
464+
ctC = ctA * ctB
465+
C2, C2inds = collect(ctC)
466+
pC2 = convert.(Int, indexin(convert.(Char, C2inds), [indsoA; indsoB]))
467+
mC = reshape(permutedims(C2, invperm(pC2)), (loA, loB))
468+
@test mC mA * mB
469+
470+
# with conjugation flag for complex arguments
471+
if !((NoA, NoB, Nc) in ((1,1,3), (1,2,3), (3,1,2)))
472+
# not supported for these specific cases for unknown reason
471473
if eltyA <: Complex
472-
# # not supported yet
473-
#opA = CUTENSOR.CUTENSOR_OP_CONJ
474-
#opB = CUTENSOR.CUTENSOR_OP_IDENTITY
475-
#opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
476-
#dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
477-
# 0, dC, indsC, opC, opOut)
478-
#C = collect(dC)
479-
#mC = reshape(permutedims(C, ipC), (loA, loB))
480-
#@test mC ≈ conj(mA) * mB
481-
# # not supported yet
482-
# opOut = CUTENSOR.CUTENSOR_OP_CONJ
483-
# dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyC), dC, indsC, opC, opOut)
484-
# C = collect(dC)
485-
# mC = reshape(permutedims(C, ipC), (loA, loB))
486-
# @test mC ≈ conj(α * conj(mA) * mB)
474+
opA = CUTENSOR.CUTENSOR_OP_CONJ
475+
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
476+
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
477+
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
478+
0, dC, indsC, opC, opOut)
479+
C = collect(dC)
480+
mC = reshape(permutedims(C, ipC), (loA, loB))
481+
@test mC conj(mA) * mB
487482
end
488483
if eltyB <: Complex
489484
opA = CUTENSOR.CUTENSOR_OP_IDENTITY
@@ -494,12 +489,6 @@ end
494489
C = collect(dC)
495490
mC = reshape(permutedims(C, ipC), (loA, loB))
496491
@test mC mA*conj(mB)
497-
# # not supported yet
498-
# opOut = CUTENSOR.CUTENSOR_OP_CONJ
499-
# dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyC), dC, indsC, opC, opOut)
500-
# C = collect(dC)
501-
# mC = reshape(permutedims(C, ipC), (loA, loB))
502-
# @test mC ≈ conj(α * mA * conj(mB))
503492
end
504493
if eltyA <: Complex && eltyB <: Complex
505494
opA = CUTENSOR.CUTENSOR_OP_CONJ
@@ -510,12 +499,6 @@ end
510499
C = collect(dC)
511500
mC = reshape(permutedims(C, ipC), (loA, loB))
512501
@test mC conj(mA)*conj(mB)
513-
# # not supported yet
514-
# opOut = CUTENSOR.CUTENSOR_OP_CONJ
515-
# dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyC), dC, indsC, opC, opOut)
516-
# C = collect(dC)
517-
# mC = reshape(permutedims(C, ipC), (loA, loB))
518-
# @test mC ≈ conj(α * conj(mA) * conj(mB))
519502
end
520503
end
521504
end

0 commit comments

Comments
 (0)