Skip to content

Commit 1efe487

Browse files
blegatandreasnoack
authored andcommitted
Make matrix multiplication work for more types (#18218)
* Make matrix multiplication work for more types Currently it is assumed that the type of a sum of x::T and y::T is T but this may not be the case * Remove arithtype in matmul and deprecate it
1 parent f80ea1c commit 1efe487

File tree

3 files changed

+62
-18
lines changed

3 files changed

+62
-18
lines changed

base/deprecated.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1058,4 +1058,22 @@ function reduced_dims0(dims::Dims, region)
10581058
map(last, reduced_dims0(map(n->OneTo(n), dims), region))
10591059
end
10601060

1061+
# #18218
1062+
eval(Base.LinAlg, quote
1063+
function arithtype(T)
1064+
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
1065+
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
1066+
"if you need its functionality, consider defining it locally."),
1067+
:arithtype)
1068+
T
1069+
end
1070+
function arithtype(::Type{Bool})
1071+
depwarn(string("arithtype is now deprecated. If you were using it inside a ",
1072+
"promote_op call, use promote_op(LinAlg.matprod, Ts...) instead. Otherwise, ",
1073+
"if you need its functionality, consider defining it locally."),
1074+
:arithtype)
1075+
Int
1076+
end
1077+
end)
1078+
10611079
# End deprecations scheduled for 0.6

base/linalg/matmul.jl

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,7 @@
22

33
# matmul.jl: Everything to do with dense matrix multiplication
44

5-
arithtype(T) = T
6-
arithtype(::Type{Bool}) = Int
5+
matprod(x, y) = x*y + x*y
76

87
# multiply by diagonal matrix as vector
98
function scale!(C::AbstractMatrix, A::AbstractMatrix, b::AbstractVector)
@@ -76,11 +75,11 @@ At_mul_B{T<:BlasComplex}(x::StridedVector{T}, y::StridedVector{T}) = [BLAS.dotu(
7675

7776
# Matrix-vector multiplication
7877
function (*){T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
79-
TS = promote_op(*, arithtype(T), arithtype(S))
78+
TS = promote_op(matprod, T, S)
8079
A_mul_B!(similar(x, TS, size(A,1)), A, convert(AbstractVector{TS}, x))
8180
end
8281
function (*){T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
83-
TS = promote_op(*, arithtype(T), arithtype(S))
82+
TS = promote_op(matprod, T, S)
8483
A_mul_B!(similar(x,TS,size(A,1)),A,x)
8584
end
8685
(*)(A::AbstractVector, B::AbstractMatrix) = reshape(A,length(A),1)*B
@@ -99,22 +98,22 @@ end
9998
A_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'N', A, x)
10099

101100
function At_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
102-
TS = promote_op(*, arithtype(T), arithtype(S))
101+
TS = promote_op(matprod, T, S)
103102
At_mul_B!(similar(x,TS,size(A,2)), A, convert(AbstractVector{TS}, x))
104103
end
105104
function At_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
106-
TS = promote_op(*, arithtype(T), arithtype(S))
105+
TS = promote_op(matprod, T, S)
107106
At_mul_B!(similar(x,TS,size(A,2)), A, x)
108107
end
109108
At_mul_B!{T<:BlasFloat}(y::StridedVector{T}, A::StridedVecOrMat{T}, x::StridedVector{T}) = gemv!(y, 'T', A, x)
110109
At_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_matvecmul!(y, 'T', A, x)
111110

112111
function Ac_mul_B{T<:BlasFloat,S}(A::StridedMatrix{T}, x::StridedVector{S})
113-
TS = promote_op(*, arithtype(T), arithtype(S))
112+
TS = promote_op(matprod, T, S)
114113
Ac_mul_B!(similar(x,TS,size(A,2)),A,convert(AbstractVector{TS},x))
115114
end
116115
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, x::AbstractVector{S})
117-
TS = promote_op(*, arithtype(T), arithtype(S))
116+
TS = promote_op(matprod, T, S)
118117
Ac_mul_B!(similar(x,TS,size(A,2)), A, x)
119118
end
120119

@@ -132,7 +131,7 @@ Ac_mul_B!(y::AbstractVector, A::AbstractVecOrMat, x::AbstractVector) = generic_m
132131
Matrix multiplication.
133132
"""
134133
function (*){T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
135-
TS = promote_op(*, arithtype(T), arithtype(S))
134+
TS = promote_op(matprod, T, S)
136135
A_mul_B!(similar(B, TS, (size(A,1), size(B,2))), A, B)
137136
end
138137
A_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'N', 'N', A, B)
@@ -166,14 +165,14 @@ julia> Y
166165
A_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'N', A, B)
167166

168167
function At_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
169-
TS = promote_op(*, arithtype(T), arithtype(S))
168+
TS = promote_op(matprod, T, S)
170169
At_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
171170
end
172171
At_mul_B!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'T', A) : gemm_wrapper!(C, 'T', 'N', A, B)
173172
At_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'T', 'N', A, B)
174173

175174
function A_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
176-
TS = promote_op(*, arithtype(T), arithtype(S))
175+
TS = promote_op(matprod, T, S)
177176
A_mul_Bt!(similar(B, TS, (size(A,1), size(B,1))), A, B)
178177
end
179178
A_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? syrk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'T', A, B)
@@ -190,7 +189,7 @@ end
190189
A_mul_Bt!(C::AbstractVecOrMat, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'T', A, B)
191190

192191
function At_mul_Bt{T,S}(A::AbstractMatrix{T}, B::AbstractVecOrMat{S})
193-
TS = promote_op(*, arithtype(T), arithtype(S))
192+
TS = promote_op(matprod, T, S)
194193
At_mul_Bt!(similar(B, TS, (size(A,2), size(B,1))), A, B)
195194
end
196195
At_mul_Bt!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'T', 'T', A, B)
@@ -199,7 +198,7 @@ At_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generi
199198
Ac_mul_B{T<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{T}) = At_mul_B(A, B)
200199
Ac_mul_B!{T<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = At_mul_B!(C, A, B)
201200
function Ac_mul_B{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
202-
TS = promote_op(*, arithtype(T), arithtype(S))
201+
TS = promote_op(matprod, T, S)
203202
Ac_mul_B!(similar(B, TS, (size(A,2), size(B,2))), A, B)
204203
end
205204
Ac_mul_B!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C,'C',A) : gemm_wrapper!(C,'C', 'N', A, B)
@@ -208,14 +207,14 @@ Ac_mul_B!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic
208207
A_mul_Bc{T<:BlasFloat,S<:BlasReal}(A::StridedMatrix{T}, B::StridedMatrix{S}) = A_mul_Bt(A, B)
209208
A_mul_Bc!{T<:BlasFloat,S<:BlasReal}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{S}) = A_mul_Bt!(C, A, B)
210209
function A_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S})
211-
TS = promote_op(*, arithtype(T), arithtype(S))
210+
TS = promote_op(matprod, T, S)
212211
A_mul_Bc!(similar(B,TS,(size(A,1),size(B,1))),A,B)
213212
end
214213
A_mul_Bc!{T<:BlasComplex}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = A===B ? herk_wrapper!(C, 'N', A) : gemm_wrapper!(C, 'N', 'C', A, B)
215214
A_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'N', 'C', A, B)
216215

217216
Ac_mul_Bc{T,S}(A::AbstractMatrix{T}, B::AbstractMatrix{S}) =
218-
Ac_mul_Bc!(similar(B, promote_op(*, arithtype(T), arithtype(S)), (size(A,2), size(B,1))), A, B)
217+
Ac_mul_Bc!(similar(B, promote_op(matprod, T, S), (size(A,2), size(B,1))), A, B)
219218
Ac_mul_Bc!{T<:BlasFloat}(C::StridedMatrix{T}, A::StridedVecOrMat{T}, B::StridedVecOrMat{T}) = gemm_wrapper!(C, 'C', 'C', A, B)
220219
Ac_mul_Bc!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'C', A, B)
221220
Ac_mul_Bt!(C::AbstractMatrix, A::AbstractVecOrMat, B::AbstractVecOrMat) = generic_matmatmul!(C, 'C', 'T', A, B)
@@ -448,7 +447,7 @@ end
448447
function generic_matmatmul{T,S}(tA, tB, A::AbstractVecOrMat{T}, B::AbstractMatrix{S})
449448
mA, nA = lapack_size(tA, A)
450449
mB, nB = lapack_size(tB, B)
451-
C = similar(B, promote_op(*, arithtype(T), arithtype(S)), mA, nB)
450+
C = similar(B, promote_op(matprod, T, S), mA, nB)
452451
generic_matmatmul!(C, tA, tB, A, B)
453452
end
454453

@@ -642,7 +641,7 @@ end
642641

643642
# multiply 2x2 matrices
644643
function matmul2x2{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
645-
matmul2x2!(similar(B, promote_op(*, T, S), 2, 2), tA, tB, A, B)
644+
matmul2x2!(similar(B, promote_op(matprod, T, S), 2, 2), tA, tB, A, B)
646645
end
647646

648647
function matmul2x2!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
@@ -671,7 +670,7 @@ end
671670

672671
# Multiply 3x3 matrices
673672
function matmul3x3{T,S}(tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})
674-
matmul3x3!(similar(B, promote_op(*, T, S), 3, 3), tA, tB, A, B)
673+
matmul3x3!(similar(B, promote_op(matprod, T, S), 3, 3), tA, tB, A, B)
675674
end
676675

677676
function matmul3x3!{T,S,R}(C::AbstractMatrix{R}, tA, tB, A::AbstractMatrix{T}, B::AbstractMatrix{S})

test/linalg/matmul.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -389,3 +389,30 @@ let
389389
@test_throws DimensionMismatch A_mul_B!(full43, full43, tri44)
390390
end
391391
end
392+
393+
# #18218
394+
module TestPR18218
395+
using Base.Test
396+
import Base.*, Base.+, Base.zero
397+
immutable TypeA
398+
x::Int
399+
end
400+
Base.convert(::Type{TypeA}, x::Int) = TypeA(x)
401+
immutable TypeB
402+
x::Int
403+
end
404+
immutable TypeC
405+
x::Int
406+
end
407+
Base.convert(::Type{TypeC}, x::Int) = TypeC(x)
408+
zero(c::TypeC) = TypeC(0)
409+
zero(::Type{TypeC}) = TypeC(0)
410+
(*)(x::Int, a::TypeA) = TypeB(x*a.x)
411+
(*)(a::TypeA, x::Int) = TypeB(a.x*x)
412+
(+)(a::Union{TypeB,TypeC}, b::Union{TypeB,TypeC}) = TypeC(a.x+b.x)
413+
A = TypeA[1 2; 3 4]
414+
b = [1, 2]
415+
d = A * b
416+
@test typeof(d) == Vector{TypeC}
417+
@test d == TypeC[5, 11]
418+
end

0 commit comments

Comments
 (0)