Skip to content

Commit 5125f84

Browse files
committed
Support boxdot with n neighboring indices
1 parent ca2e4ca commit 5125f84

File tree

2 files changed

+113
-40
lines changed

2 files changed

+113
-40
lines changed

src/TensorCore.jl

+60-40
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
129129
return dest
130130
end
131131

132-
export boxdot, , boxdot!
132+
export boxdot, , ₂, boxdot!
133133

134134
"""
135135
boxdot(A,B) = A ⊡ B # \\boxdot
@@ -177,40 +177,55 @@ Float64
177177
```
178178
See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
179179
"""
180-
function boxdot(A::AbstractArray, B::AbstractArray)
181-
Amat = _squash_left(A)
182-
Bmat = _squash_right(B)
180+
function boxdot(A::AbstractArray, B::AbstractArray, nth::Val)
181+
_check_boxdot_axes(A, B, nth)
182+
Amat = _squash_left(A, nth)
183+
Bmat = _squash_right(B, nth)
183184

184185
axA, axB = axes(Amat,2), axes(Bmat,1)
185186
axA == axB || _throw_dmm(axA, axB)
186187

187-
return _boxdot_reshape(Amat * Bmat, A, B)
188+
return _boxdot_reshape(Amat * Bmat, A, B, nth)
188189
end
189190

191+
boxdot(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(1))
192+
boxdot2(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(2))
193+
190194
const = boxdot
195+
const = boxdot2
191196

192197
@noinline _throw_dmm(axA, axB) = throw(DimensionMismatch("neighbouring axes of `A` and `B` must match, got $axA and $axB"))
198+
@noinline _throw_boxdot_nth(n) = throw(ArgumentError("boxdot order should be ≥ 1, got $n"))
199+
200+
function _check_boxdot_axes(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,M}, ::Val{K}) where {N,M,K}
201+
K::Int
202+
(K >= 1) || _throw_boxdot_nth(K)
203+
for i in 1:K
204+
axA, axB = axes(A)[N-K+i], axes(B)[i]
205+
axA == axB || _throw_dmm(axA, axB)
206+
end
207+
end
193208

194-
_squash_left(A::AbstractArray) = reshape(A, :,size(A,ndims(A)))
195-
_squash_left(A::AbstractMatrix) = A
209+
_squash_left(A::AbstractArray, ::Val{N}) where {N} = reshape(A, prod(size(A)[1:end-N]),:)
210+
_squash_left(A::AbstractMatrix, ::Val{1}) = A
196211

197-
_squash_right(B::AbstractArray) = reshape(B, size(B,1),:)
198-
_squash_right(B::AbstractVecOrMat) = B
212+
_squash_right(B::AbstractArray, ::Val{N}) where {N} = reshape(B, :,prod(size(B)[1+N:end]))
213+
_squash_right(B::AbstractVecOrMat, ::Val{1}) = B
199214

200-
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}) where {T,N,S,M}
201-
ax = ntuple(i -> i<N ? axes(A, i) : axes(B, i-N+2), Val(N+M-2))
215+
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}, ::Val{K}) where {T,N,S,M,K}
216+
N-K 1 && M-K 1 && N+M-2K 2 && return AB # These can skip final reshape
217+
ax = ntuple(i -> iN-K ? axes(A, i) : axes(B, i-N+2K), Val(N+M-2K))
202218
reshape(AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
203219
end
204220

205221
# These can skip final reshape:
206-
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat) = AB
222+
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val) = AB
207223

208224
# These produce scalar output:
209-
function boxdot(A::AbstractVector, B::AbstractVector)
210-
axA, axB = axes(A,1), axes(B,1)
211-
axA == axB || _throw_dmm(axA, axB)
225+
function boxdot(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N}, ::Val{N}) where {N}
226+
_check_boxdot_axes(A, B, Val(N))
212227
if eltype(A) <: Number
213-
return transpose(A)*B
228+
return transpose(vec(A))*vec(B)
214229
else
215230
return sum(a*b for (a,b) in zip(A,B))
216231
end
@@ -224,30 +239,30 @@ boxdot(a::Number, b::Number) = a*b
224239
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
225240

226241
# Adjont and Transpose, vectors or almost (returning a scalar)
227-
boxdot(A::AdjointAbsVec, B::AbstractVector) = A * B
228-
boxdot(A::TransposeAbsVec, B::AbstractVector) = A * B
242+
boxdot(A::AdjointAbsVec, B::AbstractVector, ::Val{1}) = A * B
243+
boxdot(A::TransposeAbsVec, B::AbstractVector, ::Val{1}) = A * B
229244

230-
boxdot(A::AbstractVector, B::AdjointAbsVec) = A vec(B)
231-
boxdot(A::AbstractVector, B::TransposeAbsVec) = A vec(B)
245+
boxdot(A::AbstractVector, B::AdjointAbsVec, ::Val{1}) = A vec(B)
246+
boxdot(A::AbstractVector, B::TransposeAbsVec, ::Val{1}) = A vec(B)
232247

233-
boxdot(A::AdjointAbsVec, B::AdjointAbsVec) = adjoint(adjoint(B) adjoint(A))
234-
boxdot(A::AdjointAbsVec, B::TransposeAbsVec) = vec(A) vec(B)
235-
boxdot(A::TransposeAbsVec, B::AdjointAbsVec) = vec(A) vec(B)
236-
boxdot(A::TransposeAbsVec, B::TransposeAbsVec) = transpose(transpose(B) transpose(A))
248+
boxdot(A::AdjointAbsVec, B::AdjointAbsVec, ::Val{1}) = adjoint(adjoint(B) adjoint(A))
249+
boxdot(A::AdjointAbsVec, B::TransposeAbsVec, ::Val{1}) = vec(A) vec(B)
250+
boxdot(A::TransposeAbsVec, B::AdjointAbsVec, ::Val{1}) = vec(A) vec(B)
251+
boxdot(A::TransposeAbsVec, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) transpose(A))
237252

238253
# ... with a matrix (returning another such)
239-
boxdot(A::AdjointAbsVec, B::AbstractMatrix) = A * B
240-
boxdot(A::TransposeAbsVec, B::AbstractMatrix) = A * B
254+
boxdot(A::AdjointAbsVec, B::AbstractMatrix, ::Val{1}) = A * B
255+
boxdot(A::TransposeAbsVec, B::AbstractMatrix, ::Val{1}) = A * B
241256

242-
boxdot(A::AbstractMatrix, B::AdjointAbsVec) = (B' A')'
243-
boxdot(A::AbstractMatrix, B::TransposeAbsVec) = transpose(transpose(B) transpose(A))
257+
boxdot(A::AbstractMatrix, B::AdjointAbsVec, ::Val{1}) = (B' A')'
258+
boxdot(A::AbstractMatrix, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) transpose(A))
244259

245260
# ... and with higher-dim (returning a plain array)
246-
boxdot(A::AdjointAbsVec, B::AbstractArray) = vec(A) B
247-
boxdot(A::TransposeAbsVec, B::AbstractArray) = vec(A) B
261+
boxdot(A::AdjointAbsVec, B::AbstractArray, ::Val{1}) = vec(A) B
262+
boxdot(A::TransposeAbsVec, B::AbstractArray, ::Val{1}) = vec(A) B
248263

249-
boxdot(A::AbstractArray, B::AdjointAbsVec) = A vec(B)
250-
boxdot(A::AbstractArray, B::TransposeAbsVec) = A vec(B)
264+
boxdot(A::AbstractArray, B::AdjointAbsVec, ::Val{1}) = A vec(B)
265+
boxdot(A::AbstractArray, B::TransposeAbsVec, ::Val{1}) = A vec(B)
251266

252267

253268
"""
@@ -260,25 +275,30 @@ function boxdot! end
260275

261276
if VERSION < v"1.3" # Then 5-arg mul! isn't defined
262277

263-
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray)
264-
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
265-
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B))
278+
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
279+
_check_boxdot_axes(A, B, Val(N))
280+
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
281+
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)))
266282
Y
267283
end
268284

269-
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B))
285+
boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray) = boxdot!(Y, A, B, Val(1))
286+
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B), Val(1))
270287

271288
else
272289

273-
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false)
274-
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
275-
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B), α, β)
290+
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}, α::Number=true, β::Number=false) where {N}
291+
_check_boxdot_axes(A, B, Val(N))
292+
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
293+
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)), α, β)
276294
Y
277295
end
278296

297+
boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false) = boxdot!(Y, A, B, Val(1), α, β)
298+
279299
# For boxdot!, only where mul! behaves differently:
280300
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec,
281-
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), α, β)
301+
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), Val(1), α, β)
282302

283303
end
284304

test/runtests.jl

+53
Original file line numberDiff line numberDiff line change
@@ -279,6 +279,59 @@ end
279279
@test boxdot!(similar(c,1,2), c', A) == c' * A
280280

281281
@test boxdot!(similar(c,1), c', d) == [dot(c, d)]
282+
283+
@testset "higher-order boxdot" begin
284+
@test A ₂ A isa Complex
285+
@test boxdot(E3, E3, Val(3)) isa Complex
286+
@test boxdot(F4, F4, Val(4)) isa Complex
287+
@test A ₂ A == sum(A .* A)
288+
@test boxdot(E3, E3, Val(3)) == sum(E3 .* E3)
289+
@test boxdot(F4, F4, Val(4)) == sum(F4 .* F4)
290+
291+
@test size(A ₂ E3) == (2,)
292+
@test A ₂ E3 == vec(reshape(A, 1,:) * reshape(E3, :,2))
293+
@test A ₂ E3lazy == A ₂ E3
294+
@test E3 ₂ A' == vec((A ₂ E3adjoint)')
295+
@test E3 transpose(A) == A conj(E3adjoint)
296+
297+
@test size(A ₂ F4) == (2,2)
298+
@test A ₂ F4 == reshape(reshape(A, 1,:) * reshape(F4, :,4), 2,2)
299+
@test A ₂ F4lazy == A ₂ F4
300+
@test F4lazy ₂ A == F4 ₂ A
301+
302+
@test size(F4 ₂ E3) == (2,2,2)
303+
@test F4 ₂ E3 == reshape(reshape(F4, 4,:) * reshape(E3, :,2), 2,2,2)
304+
@test F4 ₂ E3 == F4lazy ₂ E3lazy
305+
306+
# In-place
307+
@test boxdot!(similar(c), A, E3, Val(2)) == A ₂ E3
308+
if VERSION >= v"1.3"
309+
@test boxdot!(similar(c), A, E3, Val(2), 100) == A ₂ E3 * 100
310+
@test boxdot!(copy(c), B, E3, Val(2), 100, -5) == B ₂ E3 * 100 .- 5 .* c
311+
end
312+
313+
@test boxdot!(similar(c,1), A, A, Val(2)) == [A ₂ A]
314+
@test boxdot!(similar(c,2,2), A, F4, Val(2)) == A ₂ F4
315+
@test boxdot!(similar(c,2,2,2), F4, E3, Val(2)) == F4 ₂ E3
316+
317+
# Errors
318+
@test_throws DimensionMismatch ones(2,2) ones(3,2)
319+
@test_throws DimensionMismatch ones(2,2) ones(2,3)
320+
@test_throws DimensionMismatch ones(2,2,2) ones(2,3,2)
321+
@test_throws BoundsError ones(2,2) ones(2)
322+
@test_throws BoundsError ones(2) ones(2,2)
323+
@test_throws ArgumentError boxdot(ones(2), ones(2), Val(-1))
324+
@test_throws TypeError boxdot(ones(2), ones(2), Val(UInt(1)))
325+
326+
@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(3,2), Val(2))
327+
@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(2,3), Val(2))
328+
@test_throws DimensionMismatch boxdot!(similar(c,2,2), ones(2,2,2), ones(2,3,2), Val(2))
329+
@test_throws BoundsError boxdot!(similar(c,1), ones(2,2), ones(2), Val(2))
330+
@test_throws BoundsError boxdot!(similar(c,1), ones(2), ones(2,2), Val(2))
331+
@test_throws DimensionMismatch boxdot!(similar(c,2,3), ones(2,2,3), ones(2,3,2), Val(2))
332+
@test_throws ArgumentError boxdot!(similar(c,1), ones(2), ones(2), Val(-1))
333+
@test_throws TypeError boxdot!(similar(c,1), ones(2), ones(2), Val(UInt(1)))
334+
end
282335
end
283336

284337
@testset "_adjoint" begin

0 commit comments

Comments
 (0)