Skip to content

Commit ed9d83f

Browse files
committed
Support boxdot with n neighboring indices
1 parent ca2e4ca commit ed9d83f

File tree

2 files changed

+128
-40
lines changed

2 files changed

+128
-40
lines changed

src/TensorCore.jl

+59-40
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
129129
return dest
130130
end
131131

132-
export boxdot, , boxdot!
132+
export boxdot, , ₂, boxdot!
133133

134134
"""
135135
boxdot(A,B) = A ⊡ B # \\boxdot
@@ -177,40 +177,54 @@ Float64
177177
```
178178
See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
179179
"""
180-
function boxdot(A::AbstractArray, B::AbstractArray)
181-
Amat = _squash_left(A)
182-
Bmat = _squash_right(B)
180+
function boxdot(A::AbstractArray, B::AbstractArray, nth::Val)
181+
_check_boxdot_axes(A, B, nth)
182+
Amat = _squash_left(A, nth)
183+
Bmat = _squash_right(B, nth)
183184

184185
axA, axB = axes(Amat,2), axes(Bmat,1)
185186
axA == axB || _throw_dmm(axA, axB)
186187

187-
return _boxdot_reshape(Amat * Bmat, A, B)
188+
return _boxdot_reshape(Amat * Bmat, A, B, nth)
188189
end
189190

191+
boxdot(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(1))
192+
boxdot2(A::AbstractArray, B::AbstractArray) = boxdot(A, B, Val(2))
193+
190194
const = boxdot
195+
const = boxdot2
191196

192197
@noinline _throw_dmm(axA, axB) = throw(DimensionMismatch("neighbouring axes of `A` and `B` must match, got $axA and $axB"))
198+
@noinline _throw_boxdot_nth(n) = throw(ArgumentError("boxdot order should be ≥ 1, got $n"))
199+
200+
function _check_boxdot_axes(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,M}, ::Val{K}) where {N,M,K}
201+
K::Int
202+
(K >= 1) || _throw_boxdot_nth(K)
203+
for i in 1:K
204+
axA, axB = axes(A)[N-K+i], axes(B)[i]
205+
axA == axB || _throw_dmm(axA, axB)
206+
end
207+
end
193208

194-
_squash_left(A::AbstractArray) = reshape(A, :,size(A,ndims(A)))
195-
_squash_left(A::AbstractMatrix) = A
209+
_squash_left(A::AbstractArray, ::Val{N}) where {N} = reshape(A, prod(size(A)[1:end-N]),:)
210+
_squash_left(A::AbstractMatrix, ::Val{1}) = A
196211

197-
_squash_right(B::AbstractArray) = reshape(B, size(B,1),:)
198-
_squash_right(B::AbstractVecOrMat) = B
212+
_squash_right(B::AbstractArray, ::Val{N}) where {N} = reshape(B, :,prod(size(B)[1+N:end]))
213+
_squash_right(B::AbstractVecOrMat, ::Val{1}) = B
199214

200-
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}) where {T,N,S,M}
201-
ax = ntuple(i -> i<N ? axes(A, i) : axes(B, i-N+2), Val(N+M-2))
215+
function _boxdot_reshape(AB::AbstractArray, A::AbstractArray{T,N}, B::AbstractArray{S,M}, ::Val{K}) where {T,N,S,M,K}
216+
ax = ntuple(i -> iN-K ? axes(A, i) : axes(B, i-N+2K), Val(N+M-2K))
202217
reshape(AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
203218
end
204219

205220
# These can skip final reshape:
206-
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat) = AB
221+
_boxdot_reshape(AB::AbstractVecOrMat, A::AbstractMatrix, B::AbstractVecOrMat, ::Val) = AB
207222

208223
# These produce scalar output:
209-
function boxdot(A::AbstractVector, B::AbstractVector)
210-
axA, axB = axes(A,1), axes(B,1)
211-
axA == axB || _throw_dmm(axA, axB)
224+
function boxdot(A::AbstractArray{<:Any,N}, B::AbstractArray{<:Any,N}, ::Val{N}) where {N}
225+
_check_boxdot_axes(A, B, Val(N))
212226
if eltype(A) <: Number
213-
return transpose(A)*B
227+
return transpose(vec(A))*vec(B)
214228
else
215229
return sum(a*b for (a,b) in zip(A,B))
216230
end
@@ -224,30 +238,30 @@ boxdot(a::Number, b::Number) = a*b
224238
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
225239

226240
# Adjont and Transpose, vectors or almost (returning a scalar)
227-
boxdot(A::AdjointAbsVec, B::AbstractVector) = A * B
228-
boxdot(A::TransposeAbsVec, B::AbstractVector) = A * B
241+
boxdot(A::AdjointAbsVec, B::AbstractVector, ::Val{1}) = A * B
242+
boxdot(A::TransposeAbsVec, B::AbstractVector, ::Val{1}) = A * B
229243

230-
boxdot(A::AbstractVector, B::AdjointAbsVec) = A vec(B)
231-
boxdot(A::AbstractVector, B::TransposeAbsVec) = A vec(B)
244+
boxdot(A::AbstractVector, B::AdjointAbsVec, ::Val{1}) = A vec(B)
245+
boxdot(A::AbstractVector, B::TransposeAbsVec, ::Val{1}) = A vec(B)
232246

233-
boxdot(A::AdjointAbsVec, B::AdjointAbsVec) = adjoint(adjoint(B) adjoint(A))
234-
boxdot(A::AdjointAbsVec, B::TransposeAbsVec) = vec(A) vec(B)
235-
boxdot(A::TransposeAbsVec, B::AdjointAbsVec) = vec(A) vec(B)
236-
boxdot(A::TransposeAbsVec, B::TransposeAbsVec) = transpose(transpose(B) transpose(A))
247+
boxdot(A::AdjointAbsVec, B::AdjointAbsVec, ::Val{1}) = adjoint(adjoint(B) adjoint(A))
248+
boxdot(A::AdjointAbsVec, B::TransposeAbsVec, ::Val{1}) = vec(A) vec(B)
249+
boxdot(A::TransposeAbsVec, B::AdjointAbsVec, ::Val{1}) = vec(A) vec(B)
250+
boxdot(A::TransposeAbsVec, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) transpose(A))
237251

238252
# ... with a matrix (returning another such)
239-
boxdot(A::AdjointAbsVec, B::AbstractMatrix) = A * B
240-
boxdot(A::TransposeAbsVec, B::AbstractMatrix) = A * B
253+
boxdot(A::AdjointAbsVec, B::AbstractMatrix, ::Val{1}) = A * B
254+
boxdot(A::TransposeAbsVec, B::AbstractMatrix, ::Val{1}) = A * B
241255

242-
boxdot(A::AbstractMatrix, B::AdjointAbsVec) = (B' A')'
243-
boxdot(A::AbstractMatrix, B::TransposeAbsVec) = transpose(transpose(B) transpose(A))
256+
boxdot(A::AbstractMatrix, B::AdjointAbsVec, ::Val{1}) = (B' A')'
257+
boxdot(A::AbstractMatrix, B::TransposeAbsVec, ::Val{1}) = transpose(transpose(B) transpose(A))
244258

245259
# ... and with higher-dim (returning a plain array)
246-
boxdot(A::AdjointAbsVec, B::AbstractArray) = vec(A) B
247-
boxdot(A::TransposeAbsVec, B::AbstractArray) = vec(A) B
260+
boxdot(A::AdjointAbsVec, B::AbstractArray, ::Val{1}) = vec(A) B
261+
boxdot(A::TransposeAbsVec, B::AbstractArray, ::Val{1}) = vec(A) B
248262

249-
boxdot(A::AbstractArray, B::AdjointAbsVec) = A vec(B)
250-
boxdot(A::AbstractArray, B::TransposeAbsVec) = A vec(B)
263+
boxdot(A::AbstractArray, B::AdjointAbsVec, ::Val{1}) = A vec(B)
264+
boxdot(A::AbstractArray, B::TransposeAbsVec, ::Val{1}) = A vec(B)
251265

252266

253267
"""
@@ -260,25 +274,30 @@ function boxdot! end
260274

261275
if VERSION < v"1.3" # Then 5-arg mul! isn't defined
262276

263-
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray)
264-
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
265-
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B))
277+
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}) where {N}
278+
_check_boxdot_axes(A, B, Val(N))
279+
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
280+
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)))
266281
Y
267282
end
268283

269-
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B))
284+
boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray) = boxdot!(Y, A, B, Val(1))
285+
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec) = boxdot!(Y, A, vec(B), Val(1))
270286

271287
else
272288

273-
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false)
274-
szY = prod(size(A)[1:end-1]), prod(size(B)[2:end])
275-
mul!(reshape(Y, szY), _squash_left(A), _squash_right(B), α, β)
289+
function boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, ::Val{N}, α::Number=true, β::Number=false) where {N}
290+
_check_boxdot_axes(A, B, Val(N))
291+
szY = prod(size(A)[1:end-N]), prod(size(B)[1+N:end])
292+
mul!(reshape(Y, szY), _squash_left(A, Val(N)), _squash_right(B, Val(N)), α, β)
276293
Y
277294
end
278295

296+
boxdot!(Y::AbstractArray, A::AbstractArray, B::AbstractArray, α::Number=true, β::Number=false) = boxdot!(Y, A, B, Val(1), α, β)
297+
279298
# For boxdot!, only where mul! behaves differently:
280299
boxdot!(Y::AbstractArray, A::AbstractArray, B::AdjOrTransAbsVec,
281-
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), α, β)
300+
α::Number=true, β::Number=false) = boxdot!(Y, A, vec(B), Val(1), α, β)
282301

283302
end
284303

test/runtests.jl

+69
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,75 @@ end
281281
@test boxdot!(similar(c,1), c', d) == [dot(c, d)]
282282
end
283283

284+
@testset "higher-order boxdot" begin
285+
286+
# Arrays
287+
A = [1 2+im; 3 4im; 5 6+im]
288+
B = [5im 6; 7+im 8; 9im 10]
289+
E3 = cat(A, B, conj(A .+ 1), dims=3)
290+
F4 = cat(E3, conj(E3 .+ 1), dims=4)
291+
E3adjoint = conj(permutedims(E3, (3,2,1)))
292+
F4adjoint = conj(permutedims(F4, (4,3,2,1)))
293+
E3lazy = PermutedDimsArray(permutedims(E3, (3,2,1)), (3,2,1))
294+
F4lazy = PermutedDimsArray(permutedims(F4, (4,3,2,1)), (4,3,2,1))
295+
@test E3lazy == E3
296+
@test F4lazy == F4
297+
298+
@test A ₂ A isa Complex
299+
@test boxdot(E3, E3, Val(3)) isa Complex
300+
@test boxdot(F4, F4, Val(4)) isa Complex
301+
@test A ₂ A == sum(A .* A)
302+
@test boxdot(E3, E3, Val(3)) == sum(E3 .* E3)
303+
@test boxdot(F4, F4, Val(4)) == sum(F4 .* F4)
304+
305+
@test size(A ₂ E3) == (3,)
306+
@test A ₂ E3 == vec(reshape(A, 1,:) * reshape(E3, :,3))
307+
@test A ₂ E3lazy == A ₂ E3
308+
@test E3 ₂ A' == vec((A ₂ E3adjoint)')
309+
@test E3 transpose(A) == A conj(E3adjoint)
310+
311+
@test size(A ₂ F4) == (3,2)
312+
@test A ₂ F4 == reshape(reshape(A, 1,:) * reshape(F4, :,6), 3,2)
313+
@test F4 ₂ A == (A' ₂ F4adjoint)'
314+
@test A ₂ F4lazy == A ₂ F4
315+
@test F4lazy ₂ A == F4 ₂ A
316+
317+
@test size(F4 ₂ E3) == (3,2,3)
318+
@test F4 ₂ E3 == reshape(reshape(F4, 6,:) * reshape(E3, :,3), 3,2,3)
319+
@test F4 ₂ E3adjoint == conj(permutedims(E3 ₂ F4adjoint, (3,2,1)))
320+
@test F4 ₂ E3 == F4lazy ₂ E3lazy
321+
322+
# In-place
323+
c = A ₂ E3
324+
@test boxdot!(similar(c), A, E3, Val(2)) == A ₂ E3
325+
if VERSION >= v"1.3"
326+
@test boxdot!(similar(c), A, E3, Val(2), 100) == A ₂ E3 * 100
327+
@test boxdot!(copy(c), B, E3, Val(2), 100, -5) == B ₂ E3 * 100 .- 5 .* c
328+
end
329+
330+
@test boxdot!(similar(c,1), A, A, Val(2)) == [A ₂ A]
331+
@test boxdot!(similar(c,3,2), A, F4, Val(2)) == A ₂ F4
332+
@test boxdot!(similar(c,3,2,3), F4, E3, Val(2)) == F4 ₂ E3
333+
334+
# Errors
335+
@test_throws DimensionMismatch ones(2,2) ones(3,2)
336+
@test_throws DimensionMismatch ones(2,2) ones(2,3)
337+
@test_throws DimensionMismatch ones(2,2,2) ones(2,3,2)
338+
@test_throws BoundsError ones(2,2) ones(2)
339+
@test_throws BoundsError ones(2) ones(2,2)
340+
@test_throws ArgumentError boxdot(ones(2), ones(2), Val(-1))
341+
@test_throws TypeError boxdot(ones(2), ones(2), Val(UInt(1)))
342+
343+
@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(3,2), Val(2))
344+
@test_throws DimensionMismatch boxdot!(similar(c,1), ones(2,2), ones(2,3), Val(2))
345+
@test_throws DimensionMismatch boxdot!(similar(c,2,2), ones(2,2,2), ones(2,3,2), Val(2))
346+
@test_throws BoundsError boxdot!(similar(c,1), ones(2,2), ones(2), Val(2))
347+
@test_throws BoundsError boxdot!(similar(c,1), ones(2), ones(2,2), Val(2))
348+
@test_throws DimensionMismatch boxdot!(similar(c,2,3), ones(2,2,3), ones(2,3,2), Val(2))
349+
@test_throws ArgumentError boxdot!(similar(c,1), ones(2), ones(2), Val(-1))
350+
@test_throws TypeError boxdot!(similar(c,1), ones(2), ones(2), Val(UInt(1)))
351+
end
352+
284353
@testset "_adjoint" begin
285354
A = [1 2+im; 3 4im]
286355
E3 = cat(A, -A, dims=3)

0 commit comments

Comments
 (0)