-
Notifications
You must be signed in to change notification settings - Fork 28
[oneMKL] Interface variants of trsm! and trmm! #479
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Conversation
@kballeda It seems that the new variants of T = Float32
alpha = rand(T)
beta = rand(T)
@testset "trmm!" begin
A = triu(rand(T, m, m))
B = rand(T, m, n)
dA = oneArray(A)
dB = oneArray(B)
# Test without beta
C = alpha * A * B
oneMKL.trmm!('L', 'U', 'N', 'N', alpha, dA, dB)
# Move to host and compare
h_C = Array(dB)
@test C ≈ h_C
# Test with beta
C = rand(T, m, n)
dC = oneArray(C)
oneMKL.trmm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC) # <-- fail
h_C = Array(dC)
D = alpha * A * B + beta * C
@test D ≈ h_C
end
@testset "left trsm!" begin
A = triu(rand(T, m, m))
B = rand(T, m, n)
dA = oneArray(A)
dB = oneArray(B)
# Test without beta
C = alpha * (A \ B)
dC = copy(dB)
oneMKL.trsm!('L', 'U', 'N', 'N', alpha, dA, dC)
@test C ≈ Array(dC)
# Test with beta
C = rand(T, m, n)
dC = oneArray(C)
oneMKL.trsm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC) # <-- fail
h_C = Array(dC)
D = alpha * (A \ B) + beta * C
@test D ≈ h_C
end
@testset "right trsm!" begin
A = rand(T, m, m)
B = triu(rand(T, m, m))
dA = oneArray(A)
dB = oneArray(B)
# Test without beta
C = alpha * (A / B)
dC = copy(dA)
oneMKL.trsm!('R', 'U', 'N', 'N', alpha, dB, dC)
@test C ≈ Array(dC)
# Test with beta
C = rand(T, m, m)
dC = oneArray(C)
oneMKL.trsm!('R', 'U', 'N', 'N', alpha, beta, dA, dB, dC) # <-- fail
h_C = Array(dC)
D = alpha * (A / B) + beta * C
@test D ≈ h_C
end |
Thanks for reporting! Let me check this at my end with C reproducer. |
No, the build over at JuliaPackaging/Yggdrasil#9552 fails probably because of bugs in the Intel libraries. I was going to wait for a new version of the base toolkit before investigating. I see 2025.0.0 has been released now, so we should probably try again. |
I will check how to update the wrappers for oneMKL. [1/4] Building CXX object CMakeFiles/oneapi_support.dir/src/sycl.cpp.o
FAILED: CMakeFiles/oneapi_support.dir/src/sycl.cpp.o
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/icpx -Doneapi_support_EXPORTS -fsycl -isystem /home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/include -isystem /home/alexis/.julia/artifacts/4acaedf5204fc60d0f11bb5d32020fa91c5b3d10/include -std=gnu++17 -fPIC -MD -MT CMakeFiles/oneapi_support.dir/src/sycl.cpp.o -MF CMakeFiles/oneapi_support.dir/src/sycl.cpp.o.d -o CMakeFiles/oneapi_support.dir/src/sycl.cpp.o -c /home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:9:72: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
9 | auto sycl_platform = sycl::ext::oneapi::level_zero::make_platform((pi_native_handle) driver);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:9:57: error: no member named 'make_platform' in namespace 'sycl::ext::oneapi::level_zero'
9 | auto sycl_platform = sycl::ext::oneapi::level_zero::make_platform((pi_native_handle) driver);
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:22:68: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
22 | sycl::ext::oneapi::level_zero::make_device(platform->val, (pi_native_handle) device);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:22:9: error: no member named 'make_device' in namespace 'sycl::ext::oneapi::level_zero'; did you mean 'sycl::ext::oneapi::level_zero::detail::make_device'?
22 | sycl::ext::oneapi::level_zero::make_device(platform->val, (pi_native_handle) device);
| ^~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
| sycl::ext::oneapi::level_zero::detail::make_device
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/include/sycl/ext/oneapi/backend/level_zero.hpp:44:22: note: 'sycl::ext::oneapi::level_zero::detail::make_device' declared here
44 | __SYCL_EXPORT device make_device(const platform &Platform,
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:40:68: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
40 | sycl::ext::oneapi::level_zero::make_context(sycl_devices, (pi_native_handle) context, keep_ownership);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:40:40: error: no member named 'make_context' in namespace 'sycl::ext::oneapi::level_zero'
40 | sycl::ext::oneapi::level_zero::make_context(sycl_devices, (pi_native_handle) context, keep_ownership);
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:54:93: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
54 | auto sycl_queue = sycl::ext::oneapi::level_zero::make_queue(context->val, device->val, (pi_native_handle) queue, false, keep_ownership, {});
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:54:54: error: no member named 'make_queue' in namespace 'sycl::ext::oneapi::level_zero'
54 | auto sycl_queue = sycl::ext::oneapi::level_zero::make_queue(context->val, device->val, (pi_native_handle) queue, false, keep_ownership, {});
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:66:79: error: unknown type name 'pi_native_handle'; did you mean 'ur_native_handle_t'?
66 | auto sycl_event = sycl::ext::oneapi::level_zero::make_event(context->val, (pi_native_handle) event, keep_ownership);
| ^~~~~~~~~~~~~~~~
| ur_native_handle_t
/home/alexis/.julia/scratchspaces/8f75cd03-7ff8-4ecb-9b8f-daf728133b1b/conda/bin/compiler/../../include/sycl/ur_api.h:400:19: note: 'ur_native_handle_t' declared here
400 | typedef uintptr_t ur_native_handle_t;
| ^
/home/alexis/Bureau/git/oneAPI.jl/deps/src/sycl.cpp:66:53: error: no member named 'make_event' in namespace 'sycl::ext::oneapi::level_zero'
66 | auto sycl_event = sycl::ext::oneapi::level_zero::make_event(context->val, (pi_native_handle) event, keep_ownership);
| ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~^
10 errors generated.
[2/4] Building CXX object CMakeFiles/oneapi_support.dir/src/onemkl.cpp.o
ninja: build stopped: subcommand failed.
ERROR: LoadError: failed process: Process(`/home/alexis/.julia/artifacts/7e62c00e1f15f21da3a56196bac84e23e6d629c3/bin/ninja -C /tmp/jl_NRO6kE install`, ProcessExited(1)) [1] |
dce5866
to
08e7dfa
Compare
Doesn't seem fixed on v2025.0.0 |
1a8482f
to
653050a
Compare
653050a
to
211e6f1
Compare
Your PR requires formatting changes to meet the project's style guidelines. Click here to view the suggested changes.diff --git a/lib/mkl/linalg.jl b/lib/mkl/linalg.jl
index db16da6..22f2ba0 100644
--- a/lib/mkl/linalg.jl
+++ b/lib/mkl/linalg.jl
@@ -5,7 +5,7 @@ using LinearAlgebra: Transpose, Adjoint, AdjOrTrans,
Hermitian, Symmetric,
LowerTriangular, UnitLowerTriangular,
UpperTriangular, UnitUpperTriangular,
- UpperOrLowerTriangular, MulAddMul, wrap
+ UpperOrLowerTriangular, MulAddMul, wrap
#
# BLAS 1
@@ -163,13 +163,13 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
end
-const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}}
+const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T, <:oneStridedMatrix}}
function LinearAlgebra.generic_trimatmul!(
- C::oneStridedMatrix{T}, uplocA, isunitcA,
- tfunA::Function, A::oneStridedMatrix{T},
- triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}},
-) where {T<:onemklFloat}
+ C::oneStridedMatrix{T}, uplocA, isunitcA,
+ tfunA::Function, A::oneStridedMatrix{T},
+ triB::UpperOrLowerTriangular{T, <:AdjOrTransOroneMatrix{T}},
+ ) where {T <: onemklFloat}
uplocB = LinearAlgebra.uplo_char(triB)
isunitcB = LinearAlgebra.isunit_char(triB)
B = parent(triB)
@@ -206,7 +206,7 @@ LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::F
trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
-LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat} =
+LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T <: onemklFloat} =
trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C)
-LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
+LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T <: onemklFloat} =
trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C)
diff --git a/lib/mkl/wrappers_blas.jl b/lib/mkl/wrappers_blas.jl
index e01ffd2..a986e0b 100644
--- a/lib/mkl/wrappers_blas.jl
+++ b/lib/mkl/wrappers_blas.jl
@@ -1140,73 +1140,91 @@ function trsm(side::Char,
end
for (mmname_variant, smname_variant, elty) in
- ((:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
- (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
- (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
- (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32))
+ (
+ (:onemklDtrmm_variant, :onemklDtrsm_variant, :Float64),
+ (:onemklStrmm_variant, :onemklStrsm_variant, :Float32),
+ (:onemklZtrmm_variant, :onemklZtrsm_variant, :ComplexF64),
+ (:onemklCtrmm_variant, :onemklCtrsm_variant, :ComplexF32),
+ )
@eval begin
- function trmm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- beta::Number,
- A::oneStridedMatrix{$elty},
- B::oneStridedMatrix{$elty},
- C::oneStridedMatrix{$elty})
+ function trmm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ beta::Number,
+ A::oneStridedMatrix{$elty},
+ B::oneStridedMatrix{$elty},
+ C::oneStridedMatrix{$elty}
+ )
m, n = size(B)
mA, nA = size(A)
- if mA != nA throw(DimensionMismatch("A must be square")) end
- if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trmm!")) end
- lda = max(1,stride(A,2))
- ldb = max(1,stride(B,2))
- ldc = max(1,stride(C,2))
+ if mA != nA
+ throw(DimensionMismatch("A must be square"))
+ end
+ if nA != (side == 'L' ? m : n)
+ throw(DimensionMismatch("trmm!"))
+ end
+ lda = max(1, stride(A, 2))
+ ldb = max(1, stride(B, 2))
+ ldc = max(1, stride(C, 2))
queue = global_queue(context(A), device())
$mmname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
- B
+ return B
end
- function trsm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- beta::Number,
- A::oneStridedMatrix{$elty},
- B::oneStridedMatrix{$elty},
- C::oneStridedMatrix{$elty})
+ function trsm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ beta::Number,
+ A::oneStridedMatrix{$elty},
+ B::oneStridedMatrix{$elty},
+ C::oneStridedMatrix{$elty}
+ )
m, n = size(B)
mA, nA = size(A)
- if mA != nA throw(DimensionMismatch("A must be square")) end
- if nA != (side == 'L' ? m : n) throw(DimensionMismatch("trsm!")) end
- lda = max(1,stride(A,2))
- ldb = max(1,stride(B,2))
- ldc = max(1,stride(C,2))
+ if mA != nA
+ throw(DimensionMismatch("A must be square"))
+ end
+ if nA != (side == 'L' ? m : n)
+ throw(DimensionMismatch("trsm!"))
+ end
+ lda = max(1, stride(A, 2))
+ ldb = max(1, stride(B, 2))
+ ldc = max(1, stride(C, 2))
queue = global_queue(context(A), device())
$smname_variant(sycl_queue(queue), side, uplo, transa, diag, m, n, alpha, A, lda, B, ldb, beta, C, ldc)
- B
+ return B
end
end
end
-function trmm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- A::oneStridedMatrix{T},
- B::oneStridedMatrix{T},
- C::oneStridedMatrix{T}) where T
- trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
-end
-function trsm!(side::Char,
- uplo::Char,
- transa::Char,
- diag::Char,
- alpha::Number,
- A::oneStridedMatrix{T},
- B::oneStridedMatrix{T},
- C::oneStridedMatrix{T}) where T
- trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
+function trmm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ A::oneStridedMatrix{T},
+ B::oneStridedMatrix{T},
+ C::oneStridedMatrix{T}
+ ) where {T}
+ return trmm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
+end
+function trsm!(
+ side::Char,
+ uplo::Char,
+ transa::Char,
+ diag::Char,
+ alpha::Number,
+ A::oneStridedMatrix{T},
+ B::oneStridedMatrix{T},
+ C::oneStridedMatrix{T}
+ ) where {T}
+ return trsm!(side, uplo, transa, diag, alpha, zero(T), A, B, C)
end
## hemm
diff --git a/test/onemkl.jl b/test/onemkl.jl
index bbafaed..33f1ba7 100644
--- a/test/onemkl.jl
+++ b/test/onemkl.jl
@@ -662,13 +662,13 @@ end
h_C = Array(dB)
@test C ≈ h_C
- C = rand(T,m,n)
- dC = oneArray(C)
- beta = zero(T) # rand(T)
- oneMKL.trmm!('L','U','N','N',alpha,beta,dA,dB,dC)
- h_C = Array(dC)
- D = alpha*A*B + beta*C
- @test D ≈ h_C
+ C = rand(T, m, n)
+ dC = oneArray(C)
+ beta = zero(T) # rand(T)
+ oneMKL.trmm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC)
+ h_C = Array(dC)
+ D = alpha * A * B + beta * C
+ @test D ≈ h_C
end
@testset "trmm" begin
@@ -693,13 +693,13 @@ end
oneMKL.trsm!('L','U','N','N',alpha,dA,dC)
@test C ≈ Array(dC)
- C = rand(T,m,n)
- dC = oneArray(C)
- beta = rand(T)
- oneMKL.trsm!('L','U','N','N',alpha,beta,dA,dB,dC)
- h_C = Array(dC)
- D = alpha*(A\B) + beta*C
- @test D ≈ h_C
+ C = rand(T, m, n)
+ dC = oneArray(C)
+ beta = rand(T)
+ oneMKL.trsm!('L', 'U', 'N', 'N', alpha, beta, dA, dB, dC)
+ h_C = Array(dC)
+ D = alpha * (A \ B) + beta * C
+ @test D ≈ h_C
end
@testset "left trsm" begin
@@ -742,13 +742,13 @@ end
oneMKL.trsm!('R','U','N','N',alpha,dB,dC)
@test C ≈ Array(dC)
- C = rand(T,m,m)
- dC = oneArray(C)
- beta = rand(T)
- oneMKL.trsm!('R','U','N','N',alpha,beta,dA,dB,dC)
- h_C = Array(dC)
- D = alpha*(A/B) + beta*C
- @test D ≈ h_C
+ C = rand(T, m, m)
+ dC = oneArray(C)
+ beta = rand(T)
+ oneMKL.trsm!('R', 'U', 'N', 'N', alpha, beta, dA, dB, dC)
+ h_C = Array(dC)
+ D = alpha * (A / B) + beta * C
+ @test D ≈ h_C
end
@testset "right trsm" begin |
CI failures still seem related. |
Interface variants of
trsm!
andtrmm!
with additional arguments.