Skip to content

Commit f11c7cb

Browse files
committed
check dimensions of output matrix in mul!
1 parent 160ef4e commit f11c7cb

File tree

2 files changed

+18
-3
lines changed

2 files changed

+18
-3
lines changed

src/linalg.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,18 @@ end
3636

3737
function LinearAlgebra.mul!(Y::AbstractVecOrMat, A::AbstractMatrix, B::OneHotLike)
3838
_isonehot(B) || return invoke(mul!, Tuple{AbstractArray,AbstractMatrix,AbstractMatrix}, Y, A, B)
39-
size(A,2) == size(B,1) || throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2)) != $(size(B,1))")
40-
)
39+
if size(A,2) != size(B,1)
40+
throw(DimensionMismatch("Matrix column must correspond with the OneHot Size $(size(A,2)) != $(size(B,1))"))
41+
end
42+
if !(size(Y,1) == size(A,1) && size(Y,2) == size(B,2))
43+
throw(DimensionMismatch("Invalid output matrix size for multiplication of matrix sizes $(size(A)) and $(size(B))"))
44+
end
4145
# matmul sometimes wraps in ReshapedArray, taking parent is a simple way to handle that case
42-
copyto!(Y, view(A, :, onecold(parent(B))))
46+
idxs = onecold(parent(B))
47+
if idxs isa Integer # occurs when B is AbstractVector
48+
copyto!(Y, view(A, :, idxs))
49+
else
50+
NNlib.gather!(Y, A, idxs)
51+
end
4352
end
4453

test/gpu.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ end
2929

3030
# some specialized implementations call only mul! and not *, so we must ensure this works
3131
@test LinearAlgebra.mul!(similar(gA, 3, 3), gA, y) ≈ gA*y
32+
@test LinearAlgebra.mul!(similar(gA, 3, 1), gA, onehot(1, 1:2)) ≈ gA*onehot(1, 1:2)
33+
34+
@test_throws DimensionMismatch LinearAlgebra.mul!(similar(gA, 3, 4), gA, y)
35+
36+
gB = rand(3, 3) |> cu
37+
@test_throws DimensionMismatch LinearAlgebra.mul!(similar(gB, 3, 3), gB, y)
3238

3339
#TODO: the below fails due to method ambiguity and GPU scalar indexing
3440
y = reshape(y, 3, 2)

0 commit comments

Comments
 (0)