|
1 | 1 | # interfacing with LinearAlgebra standard library
|
2 | 2 |
|
3 | 3 | import LinearAlgebra
|
4 |
| -using LinearAlgebra: Transpose, Adjoint, |
| 4 | +using LinearAlgebra: Transpose, Adjoint, AdjOrTrans, |
5 | 5 | Hermitian, Symmetric,
|
6 | 6 | LowerTriangular, UnitLowerTriangular,
|
7 | 7 | UpperTriangular, UnitUpperTriangular,
|
8 |
| - MulAddMul, wrap |
| 8 | + UpperOrLowerTriangular, MulAddMul, wrap |
9 | 9 |
|
10 | 10 | #
|
11 | 11 | # BLAS 1
|
@@ -163,12 +163,50 @@ function LinearAlgebra.generic_matmatmul!(C::oneStridedMatrix, tA, tB, A::oneStr
|
163 | 163 | GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
|
164 | 164 | end
|
165 | 165 |
|
| 166 | +const AdjOrTransOroneMatrix{T} = Union{oneStridedMatrix{T}, AdjOrTrans{<:T,<:oneStridedMatrix}} |
| 167 | + |
| 168 | +function LinearAlgebra.generic_trimatmul!( |
| 169 | + C::oneStridedMatrix{T}, uplocA, isunitcA, |
| 170 | + tfunA::Function, A::oneStridedMatrix{T}, |
| 171 | + triB::UpperOrLowerTriangular{T, <: AdjOrTransOroneMatrix{T}}, |
| 172 | +) where {T<:onemklFloat} |
| 173 | + uplocB = LinearAlgebra.uplo_char(triB) |
| 174 | + isunitcB = LinearAlgebra.isunit_char(triB) |
| 175 | + B = parent(triB) |
| 176 | + tfunB = LinearAlgebra.wrapperop(B) |
| 177 | + transa = tfunA === identity ? 'N' : tfunA === transpose ? 'T' : 'C' |
| 178 | + transb = tfunB === identity ? 'N' : tfunB === transpose ? 'T' : 'C' |
| 179 | + if uplocA == 'L' && tfunA === identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' # lower * upper |
| 180 | + triu!(B) |
| 181 | + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) |
| 182 | + elseif uplocA == 'U' && tfunA === identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' # upper * lower |
| 183 | + tril!(B) |
| 184 | + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) |
| 185 | + elseif uplocA == 'U' && tfunA === identity && tfunB !== identity && uplocB == 'U' && isunitcA == 'N' |
| 186 | + # operation is reversed to avoid executing the tranpose |
| 187 | + triu!(A) |
| 188 | + trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C) |
| 189 | + elseif uplocA == 'L' && tfunA !== identity && tfunB === identity && uplocB == 'L' && isunitcB == 'N' |
| 190 | + tril!(B) |
| 191 | + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) |
| 192 | + elseif uplocA == 'U' && tfunA !== identity && tfunB === identity && uplocB == 'U' && isunitcB == 'N' |
| 193 | + triu!(B) |
| 194 | + trmm!('L', uplocA, transa, isunitcA, one(T), A, B, C) |
| 195 | + elseif uplocA == 'L' && tfunA === identity && tfunB !== identity && uplocB == 'L' && isunitcA == 'N' |
| 196 | + tril!(A) |
| 197 | + trmm!('R', uplocB, transb, isunitcB, one(T), parent(B), A, C) |
| 198 | + else |
| 199 | + throw("mixed triangular-triangular multiplication") # TODO: rethink |
| 200 | + end |
| 201 | + return C |
| 202 | +end |
| 203 | + |
166 | 204 | # triangular
|
167 | 205 | LinearAlgebra.generic_trimatmul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
|
168 |
| - trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) |
| 206 | + trmm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C) |
169 | 207 | LinearAlgebra.generic_mattrimul!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} =
|
170 |
| - trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) |
171 |
| -LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = |
172 |
| - trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, C === B ? C : copyto!(C, B)) |
173 |
| -LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = |
174 |
| - trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, C === A ? C : copyto!(C, A)) |
| 208 | + trmm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C) |
| 209 | +LinearAlgebra.generic_trimatdiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::oneStridedMatrix{T}, B::AbstractMatrix{T}) where {T<:onemklFloat} = |
| 210 | + trsm!('L', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), A, B, C) |
| 211 | +LinearAlgebra.generic_mattridiv!(C::oneStridedMatrix{T}, uploc, isunitc, tfun::Function, A::AbstractMatrix{T}, B::oneStridedMatrix{T}) where {T<:onemklFloat} = |
| 212 | + trsm!('R', uploc, tfun === identity ? 'N' : tfun === transpose ? 'T' : 'C', isunitc, one(T), B, A, C) |
0 commit comments