@@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
129
129
return dest
130
130
end
131
131
132
- export boxdot, ⊡ , boxdot!
132
+ export boxdot, ⊡ , ⊡ ₂, boxdot!
133
133
134
134
"""
135
135
boxdot(A,B) = A ⊡ B # \\ boxdot
@@ -177,40 +177,55 @@ Float64
177
177
```
178
178
See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
179
179
"""
180
- function boxdot (A:: AbstractArray , B:: AbstractArray )
181
- Amat = _squash_left (A)
182
- Bmat = _squash_right (B)
180
+ function boxdot (A:: AbstractArray , B:: AbstractArray , nth:: Val )
181
+ _check_boxdot_axes (A, B, nth)
182
+ Amat = _squash_left (A, nth)
183
+ Bmat = _squash_right (B, nth)
183
184
184
185
axA, axB = axes (Amat,2 ), axes (Bmat,1 )
185
186
axA == axB || _throw_dmm (axA, axB)
186
187
187
- return _boxdot_reshape (Amat * Bmat, A, B)
188
+ return _boxdot_reshape (Amat * Bmat, A, B, nth )
188
189
end
189
190
191
+ boxdot (A:: AbstractArray , B:: AbstractArray ) = boxdot (A, B, Val (1 ))
192
+ boxdot2 (A:: AbstractArray , B:: AbstractArray ) = boxdot (A, B, Val (2 ))
193
+
190
194
const ⊡ = boxdot
195
+ const ⊡ ₂ = boxdot2
191
196
192
197
@noinline _throw_dmm (axA, axB) = throw (DimensionMismatch (" neighbouring axes of `A` and `B` must match, got $axA and $axB " ))
198
+ @noinline _throw_boxdot_nth (n) = throw (ArgumentError (" boxdot order should be ≥ 1, got $n " ))
199
+
200
+ function _check_boxdot_axes (A:: AbstractArray{<:Any,N} , B:: AbstractArray{<:Any,M} , :: Val{K} ) where {N,M,K}
201
+ K:: Int
202
+ (K >= 1 ) || _throw_boxdot_nth (K)
203
+ for i in 1 : K
204
+ axA, axB = axes (A)[N- K+ i], axes (B)[i]
205
+ axA == axB || _throw_dmm (axA, axB)
206
+ end
207
+ end
193
208
194
- _squash_left (A:: AbstractArray ) = reshape (A, :, size (A, ndims (A)) )
195
- _squash_left (A:: AbstractMatrix ) = A
209
+ _squash_left (A:: AbstractArray , :: Val{N} ) where {N} = reshape (A, prod ( size (A)[ 1 : end - N]),: )
210
+ _squash_left (A:: AbstractMatrix , :: Val{1} ) = A
196
211
197
- _squash_right (B:: AbstractArray ) = reshape (B, size (B, 1 ),: )
198
- _squash_right (B:: AbstractVecOrMat ) = B
212
+ _squash_right (B:: AbstractArray , :: Val{N} ) where {N} = reshape (B, :, prod ( size (B)[ 1 + N : end ]) )
213
+ _squash_right (B:: AbstractVecOrMat , :: Val{1} ) = B
199
214
200
- function _boxdot_reshape (AB:: AbstractArray , A:: AbstractArray{T,N} , B:: AbstractArray{S,M} ) where {T,N,S,M}
201
- ax = ntuple (i -> i< N ? axes (A, i) : axes (B, i- N+ 2 ), Val (N+ M- 2 ))
215
+ function _boxdot_reshape (AB:: AbstractArray , A:: AbstractArray{T,N} , B:: AbstractArray{S,M} , :: Val{K} ) where {T,N,S,M,K}
216
+ N- K ≥ 1 && M- K ≥ 1 && N+ M- 2 K ≤ 2 && return AB # These can skip final reshape
217
+ ax = ntuple (i -> i≤ N- K ? axes (A, i) : axes (B, i- N+ 2 K), Val (N+ M- 2 K))
202
218
reshape (AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
203
219
end
204
220
205
221
# These can skip final reshape:
206
- _boxdot_reshape (AB:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat ) = AB
222
+ _boxdot_reshape (AB:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , :: Val ) = AB
207
223
208
224
# These produce scalar output:
209
- function boxdot (A:: AbstractVector , B:: AbstractVector )
210
- axA, axB = axes (A,1 ), axes (B,1 )
211
- axA == axB || _throw_dmm (axA, axB)
225
+ function boxdot (A:: AbstractArray{<:Any,N} , B:: AbstractArray{<:Any,N} , :: Val{N} ) where {N}
226
+ _check_boxdot_axes (A, B, Val (N))
212
227
if eltype (A) <: Number
213
- return transpose (A) * B
228
+ return transpose (vec (A)) * vec (B)
214
229
else
215
230
return sum (a* b for (a,b) in zip (A,B))
216
231
end
@@ -224,30 +239,30 @@ boxdot(a::Number, b::Number) = a*b
224
239
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
225
240
226
241
# Adjont and Transpose, vectors or almost (returning a scalar)
227
- boxdot (A:: AdjointAbsVec , B:: AbstractVector ) = A * B
228
- boxdot (A:: TransposeAbsVec , B:: AbstractVector ) = A * B
242
+ boxdot (A:: AdjointAbsVec , B:: AbstractVector , :: Val{1} ) = A * B
243
+ boxdot (A:: TransposeAbsVec , B:: AbstractVector , :: Val{1} ) = A * B
229
244
230
- boxdot (A:: AbstractVector , B:: AdjointAbsVec ) = A ⊡ vec (B)
231
- boxdot (A:: AbstractVector , B:: TransposeAbsVec ) = A ⊡ vec (B)
245
+ boxdot (A:: AbstractVector , B:: AdjointAbsVec , :: Val{1} ) = A ⊡ vec (B)
246
+ boxdot (A:: AbstractVector , B:: TransposeAbsVec , :: Val{1} ) = A ⊡ vec (B)
232
247
233
- boxdot (A:: AdjointAbsVec , B:: AdjointAbsVec ) = adjoint (adjoint (B) ⊡ adjoint (A))
234
- boxdot (A:: AdjointAbsVec , B:: TransposeAbsVec ) = vec (A) ⊡ vec (B)
235
- boxdot (A:: TransposeAbsVec , B:: AdjointAbsVec ) = vec (A) ⊡ vec (B)
236
- boxdot (A:: TransposeAbsVec , B:: TransposeAbsVec ) = transpose (transpose (B) ⊡ transpose (A))
248
+ boxdot (A:: AdjointAbsVec , B:: AdjointAbsVec , :: Val{1} ) = adjoint (adjoint (B) ⊡ adjoint (A))
249
+ boxdot (A:: AdjointAbsVec , B:: TransposeAbsVec , :: Val{1} ) = vec (A) ⊡ vec (B)
250
+ boxdot (A:: TransposeAbsVec , B:: AdjointAbsVec , :: Val{1} ) = vec (A) ⊡ vec (B)
251
+ boxdot (A:: TransposeAbsVec , B:: TransposeAbsVec , :: Val{1} ) = transpose (transpose (B) ⊡ transpose (A))
237
252
238
253
# ... with a matrix (returning another such)
239
- boxdot (A:: AdjointAbsVec , B:: AbstractMatrix ) = A * B
240
- boxdot (A:: TransposeAbsVec , B:: AbstractMatrix ) = A * B
254
+ boxdot (A:: AdjointAbsVec , B:: AbstractMatrix , :: Val{1} ) = A * B
255
+ boxdot (A:: TransposeAbsVec , B:: AbstractMatrix , :: Val{1} ) = A * B
241
256
242
- boxdot (A:: AbstractMatrix , B:: AdjointAbsVec ) = (B' ⊡ A' )'
243
- boxdot (A:: AbstractMatrix , B:: TransposeAbsVec ) = transpose (transpose (B) ⊡ transpose (A))
257
+ boxdot (A:: AbstractMatrix , B:: AdjointAbsVec , :: Val{1} ) = (B' ⊡ A' )'
258
+ boxdot (A:: AbstractMatrix , B:: TransposeAbsVec , :: Val{1} ) = transpose (transpose (B) ⊡ transpose (A))
244
259
245
260
# ... and with higher-dim (returning a plain array)
246
- boxdot (A:: AdjointAbsVec , B:: AbstractArray ) = vec (A) ⊡ B
247
- boxdot (A:: TransposeAbsVec , B:: AbstractArray ) = vec (A) ⊡ B
261
+ boxdot (A:: AdjointAbsVec , B:: AbstractArray , :: Val{1} ) = vec (A) ⊡ B
262
+ boxdot (A:: TransposeAbsVec , B:: AbstractArray , :: Val{1} ) = vec (A) ⊡ B
248
263
249
- boxdot (A:: AbstractArray , B:: AdjointAbsVec ) = A ⊡ vec (B)
250
- boxdot (A:: AbstractArray , B:: TransposeAbsVec ) = A ⊡ vec (B)
264
+ boxdot (A:: AbstractArray , B:: AdjointAbsVec , :: Val{1} ) = A ⊡ vec (B)
265
+ boxdot (A:: AbstractArray , B:: TransposeAbsVec , :: Val{1} ) = A ⊡ vec (B)
251
266
252
267
253
268
"""
@@ -260,25 +275,30 @@ function boxdot! end
260
275
261
276
if VERSION < v " 1.3" # Then 5-arg mul! isn't defined
262
277
263
- function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray )
264
- szY = prod (size (A)[1 : end - 1 ]), prod (size (B)[2 : end ])
265
- mul! (reshape (Y, szY), _squash_left (A), _squash_right (B))
278
+ function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , :: Val{N} ) where {N}
279
+ _check_boxdot_axes (A, B, Val (N))
280
+ szY = prod (size (A)[1 : end - N]), prod (size (B)[1 + N: end ])
281
+ mul! (reshape (Y, szY), _squash_left (A, Val (N)), _squash_right (B, Val (N)))
266
282
Y
267
283
end
268
284
269
- boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ) = boxdot! (Y, A, vec (B))
285
+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray ) = boxdot! (Y, A, B, Val (1 ))
286
+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ) = boxdot! (Y, A, vec (B), Val (1 ))
270
287
271
288
else
272
289
273
- function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number = true , β:: Number = false )
274
- szY = prod (size (A)[1 : end - 1 ]), prod (size (B)[2 : end ])
275
- mul! (reshape (Y, szY), _squash_left (A), _squash_right (B), α, β)
290
+ function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , :: Val{N} , α:: Number = true , β:: Number = false ) where {N}
291
+ _check_boxdot_axes (A, B, Val (N))
292
+ szY = prod (size (A)[1 : end - N]), prod (size (B)[1 + N: end ])
293
+ mul! (reshape (Y, szY), _squash_left (A, Val (N)), _squash_right (B, Val (N)), α, β)
276
294
Y
277
295
end
278
296
297
+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number = true , β:: Number = false ) = boxdot! (Y, A, B, Val (1 ), α, β)
298
+
279
299
# For boxdot!, only where mul! behaves differently:
280
300
boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ,
281
- α:: Number = true , β:: Number = false ) = boxdot! (Y, A, vec (B), α, β)
301
+ α:: Number = true , β:: Number = false ) = boxdot! (Y, A, vec (B), Val ( 1 ), α, β)
282
302
283
303
end
284
304
0 commit comments