|
379 | 379 | #(Float64, ComplexF64, ComplexF64),
|
380 | 380 | #(ComplexF64, Float64, ComplexF64),
|
381 | 381 | # (ComplexF16, ComplexF16, ComplexF16), # does not work
|
382 |
| - # (ComplexF32, ComplexF32, ComplexF32), # works for some |
| 382 | + (ComplexF32, ComplexF32, ComplexF32), # works for some |
383 | 383 | # (ComplexF32, ComplexF32, ComplexF64), # does not work
|
384 |
| - #(ComplexF64, ComplexF64, ComplexF64) # works for some |
| 384 | + (ComplexF64, ComplexF64, ComplexF64) # works for some |
385 | 385 | )
|
386 | 386 |
|
387 | 387 | @testset for NoA=1:3, NoB=1:3, Nc=1:3
|
@@ -420,70 +420,65 @@ end
|
420 | 420 | mB = reshape(permutedims(B, ipB), (lc, loB))
|
421 | 421 | C = zeros(eltyC, (dimsC...,))
|
422 | 422 | dC = CuArray(C)
|
423 |
| - if !(NoA == 1 && NoB == 3 && Nc == 1) # broken for some reason |
424 |
| - # simple case |
425 |
| - opA = CUTENSOR.CUTENSOR_OP_IDENTITY |
426 |
| - opB = CUTENSOR.CUTENSOR_OP_IDENTITY |
427 |
| - opC = CUTENSOR.CUTENSOR_OP_IDENTITY |
428 |
| - opOut = CUTENSOR.CUTENSOR_OP_IDENTITY |
429 |
| - dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB, |
430 |
| - 0, dC, indsC, opC, opOut) |
431 |
| - C = collect(dC) |
432 |
| - mC = reshape(permutedims(C, ipC), (loA, loB)) |
433 |
| - @test mC ≈ mA * mB |
434 |
| - |
435 |
| - # with non-trivial α |
436 |
| - α = rand(eltyC) |
437 |
| - dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, |
438 |
| - 0, dC, indsC, opC, opOut) |
439 |
| - C = collect(dC) |
440 |
| - mC = reshape(permutedims(C, ipC), (loA, loB)) |
441 |
| - @test mC ≈ α * mA * mB |
442 |
| - |
443 |
| - # with non-trivial β |
444 |
| - C = rand(eltyC, (dimsC...,)) |
445 |
| - dC = CuArray(C) |
446 |
| - α = rand(eltyC) |
447 |
| - β = rand(eltyC) |
448 |
| - copyto!(dC, C) |
449 |
| - dD = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, |
450 |
| - β, dC, indsC, opC, opOut) |
451 |
| - D = collect(dD) |
452 |
| - mC = reshape(permutedims(C, ipC), (loA, loB)) |
453 |
| - mD = reshape(permutedims(D, ipC), (loA, loB)) |
454 |
| - @test mD ≈ α * mA * mB + β * mC |
455 |
| - |
456 |
| - # with CuTensor objects |
457 |
| - ctA = CuTensor(dA, indsA) |
458 |
| - ctB = CuTensor(dB, indsB) |
459 |
| - ctC = CuTensor(dC, indsC) |
460 |
| - ctC = LinearAlgebra.mul!(ctC, ctA, ctB) |
461 |
| - C2, C2inds = collect(ctC) |
462 |
| - mC = reshape(permutedims(C2, ipC), (loA, loB)) |
463 |
| - @test mC ≈ mA * mB |
464 |
| - ctC = ctA * ctB |
465 |
| - C2, C2inds = collect(ctC) |
466 |
| - pC2 = convert.(Int, indexin(convert.(Char, C2inds), [indsoA; indsoB])) |
467 |
| - mC = reshape(permutedims(C2, invperm(pC2)), (loA, loB)) |
468 |
| - @test mC ≈ mA * mB |
469 |
| - |
470 |
| - # with conjugation flag for complex arguments |
| 423 | + |
| 424 | + # simple case |
| 425 | + opA = CUTENSOR.CUTENSOR_OP_IDENTITY |
| 426 | + opB = CUTENSOR.CUTENSOR_OP_IDENTITY |
| 427 | + opC = CUTENSOR.CUTENSOR_OP_IDENTITY |
| 428 | + opOut = CUTENSOR.CUTENSOR_OP_IDENTITY |
| 429 | + dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB, |
| 430 | + 0, dC, indsC, opC, opOut) |
| 431 | + C = collect(dC) |
| 432 | + mC = reshape(permutedims(C, ipC), (loA, loB)) |
| 433 | + @test mC ≈ mA * mB |
| 434 | + |
| 435 | + # with non-trivial α |
| 436 | + α = rand(eltyC) |
| 437 | + dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, |
| 438 | + 0, dC, indsC, opC, opOut) |
| 439 | + C = collect(dC) |
| 440 | + mC = reshape(permutedims(C, ipC), (loA, loB)) |
| 441 | + @test mC ≈ α * mA * mB |
| 442 | + |
| 443 | + # with non-trivial β |
| 444 | + C = rand(eltyC, (dimsC...,)) |
| 445 | + dC = CuArray(C) |
| 446 | + α = rand(eltyC) |
| 447 | + β = rand(eltyC) |
| 448 | + copyto!(dC, C) |
| 449 | + dD = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, |
| 450 | + β, dC, indsC, opC, opOut) |
| 451 | + D = collect(dD) |
| 452 | + mC = reshape(permutedims(C, ipC), (loA, loB)) |
| 453 | + mD = reshape(permutedims(D, ipC), (loA, loB)) |
| 454 | + @test mD ≈ α * mA * mB + β * mC |
| 455 | + |
| 456 | + # with CuTensor objects |
| 457 | + ctA = CuTensor(dA, indsA) |
| 458 | + ctB = CuTensor(dB, indsB) |
| 459 | + ctC = CuTensor(dC, indsC) |
| 460 | + ctC = LinearAlgebra.mul!(ctC, ctA, ctB) |
| 461 | + C2, C2inds = collect(ctC) |
| 462 | + mC = reshape(permutedims(C2, ipC), (loA, loB)) |
| 463 | + @test mC ≈ mA * mB |
| 464 | + ctC = ctA * ctB |
| 465 | + C2, C2inds = collect(ctC) |
| 466 | + pC2 = convert.(Int, indexin(convert.(Char, C2inds), [indsoA; indsoB])) |
| 467 | + mC = reshape(permutedims(C2, invperm(pC2)), (loA, loB)) |
| 468 | + @test mC ≈ mA * mB |
| 469 | + |
| 470 | + # with conjugation flag for complex arguments |
| 471 | + if !((NoA, NoB, Nc) in ((1,1,3), (1,2,3), (3,1,2))) |
| 472 | + # not supported for these specific cases for unknown reason |
471 | 473 | if eltyA <: Complex
|
472 |
| - # # not supported yet |
473 |
| - #opA = CUTENSOR.CUTENSOR_OP_CONJ |
474 |
| - #opB = CUTENSOR.CUTENSOR_OP_IDENTITY |
475 |
| - #opOut = CUTENSOR.CUTENSOR_OP_IDENTITY |
476 |
| - #dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB, |
477 |
| - # 0, dC, indsC, opC, opOut) |
478 |
| - #C = collect(dC) |
479 |
| - #mC = reshape(permutedims(C, ipC), (loA, loB)) |
480 |
| - #@test mC ≈ conj(mA) * mB |
481 |
| - # # not supported yet |
482 |
| - # opOut = CUTENSOR.CUTENSOR_OP_CONJ |
483 |
| - # dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyC), dC, indsC, opC, opOut) |
484 |
| - # C = collect(dC) |
485 |
| - # mC = reshape(permutedims(C, ipC), (loA, loB)) |
486 |
| - # @test mC ≈ conj(α * conj(mA) * mB) |
| 474 | + opA = CUTENSOR.CUTENSOR_OP_CONJ |
| 475 | + opB = CUTENSOR.CUTENSOR_OP_IDENTITY |
| 476 | + opOut = CUTENSOR.CUTENSOR_OP_IDENTITY |
| 477 | + dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB, |
| 478 | + 0, dC, indsC, opC, opOut) |
| 479 | + C = collect(dC) |
| 480 | + mC = reshape(permutedims(C, ipC), (loA, loB)) |
| 481 | + @test mC ≈ conj(mA) * mB |
487 | 482 | end
|
488 | 483 | if eltyB <: Complex
|
489 | 484 | opA = CUTENSOR.CUTENSOR_OP_IDENTITY
|
|
494 | 489 | C = collect(dC)
|
495 | 490 | mC = reshape(permutedims(C, ipC), (loA, loB))
|
496 | 491 | @test mC ≈ mA*conj(mB)
|
497 |
| - # # not supported yet |
498 |
| - # opOut = CUTENSOR.CUTENSOR_OP_CONJ |
499 |
| - # dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyC), dC, indsC, opC, opOut) |
500 |
| - # C = collect(dC) |
501 |
| - # mC = reshape(permutedims(C, ipC), (loA, loB)) |
502 |
| - # @test mC ≈ conj(α * mA * conj(mB)) |
503 | 492 | end
|
504 | 493 | if eltyA <: Complex && eltyB <: Complex
|
505 | 494 | opA = CUTENSOR.CUTENSOR_OP_CONJ
|
|
510 | 499 | C = collect(dC)
|
511 | 500 | mC = reshape(permutedims(C, ipC), (loA, loB))
|
512 | 501 | @test mC ≈ conj(mA)*conj(mB)
|
513 |
| - # # not supported yet |
514 |
| - # opOut = CUTENSOR.CUTENSOR_OP_CONJ |
515 |
| - # dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyC), dC, indsC, opC, opOut) |
516 |
| - # C = collect(dC) |
517 |
| - # mC = reshape(permutedims(C, ipC), (loA, loB)) |
518 |
| - # @test mC ≈ conj(α * conj(mA) * conj(mB)) |
519 | 502 | end
|
520 | 503 | end
|
521 | 504 | end
|
|
0 commit comments