@@ -129,7 +129,7 @@ function tensor!(dest::AbstractArray, A::AbstractArray, B::AbstractArray)
129
129
return dest
130
130
end
131
131
132
- export boxdot, ⊡ , boxdot!
132
+ export boxdot, ⊡ , ⊡ ₂, boxdot!
133
133
134
134
"""
135
135
boxdot(A,B) = A ⊡ B # \\ boxdot
@@ -177,40 +177,54 @@ Float64
177
177
```
178
178
See also `boxdot!(Y,A,B)`, which is to `⊡` as `mul!` is to `*`.
179
179
"""
180
- function boxdot (A:: AbstractArray , B:: AbstractArray )
181
- Amat = _squash_left (A)
182
- Bmat = _squash_right (B)
180
+ function boxdot (A:: AbstractArray , B:: AbstractArray , nth:: Val )
181
+ _check_boxdot_axes (A, B, nth)
182
+ Amat = _squash_left (A, nth)
183
+ Bmat = _squash_right (B, nth)
183
184
184
185
axA, axB = axes (Amat,2 ), axes (Bmat,1 )
185
186
axA == axB || _throw_dmm (axA, axB)
186
187
187
- return _boxdot_reshape (Amat * Bmat, A, B)
188
+ return _boxdot_reshape (Amat * Bmat, A, B, nth )
188
189
end
189
190
191
+ boxdot (A:: AbstractArray , B:: AbstractArray ) = boxdot (A, B, Val (1 ))
192
+ boxdot2 (A:: AbstractArray , B:: AbstractArray ) = boxdot (A, B, Val (2 ))
193
+
190
194
const ⊡ = boxdot
195
+ const ⊡ ₂ = boxdot2
191
196
192
197
@noinline _throw_dmm (axA, axB) = throw (DimensionMismatch (" neighbouring axes of `A` and `B` must match, got $axA and $axB " ))
198
+ @noinline _throw_boxdot_nth (n) = throw (ArgumentError (" boxdot order should be ≥ 1, got $n " ))
199
+
200
+ function _check_boxdot_axes (A:: AbstractArray{<:Any,N} , B:: AbstractArray{<:Any,M} , :: Val{K} ) where {N,M,K}
201
+ K:: Int
202
+ (K >= 1 ) || _throw_boxdot_nth (K)
203
+ for i in 1 : K
204
+ axA, axB = axes (A)[N- K+ i], axes (B)[i]
205
+ axA == axB || _throw_dmm (axA, axB)
206
+ end
207
+ end
193
208
194
- _squash_left (A:: AbstractArray ) = reshape (A, :, size (A, ndims (A)) )
195
- _squash_left (A:: AbstractMatrix ) = A
209
+ _squash_left (A:: AbstractArray , :: Val{N} ) where {N} = reshape (A, prod ( size (A)[ 1 : end - N]),: )
210
+ _squash_left (A:: AbstractMatrix , :: Val{1} ) = A
196
211
197
- _squash_right (B:: AbstractArray ) = reshape (B, size (B, 1 ),: )
198
- _squash_right (B:: AbstractVecOrMat ) = B
212
+ _squash_right (B:: AbstractArray , :: Val{N} ) where {N} = reshape (B, :, prod ( size (B)[ 1 + N : end ]) )
213
+ _squash_right (B:: AbstractVecOrMat , :: Val{1} ) = B
199
214
200
- function _boxdot_reshape (AB:: AbstractArray , A:: AbstractArray{T,N} , B:: AbstractArray{S,M} ) where {T,N,S,M}
201
- ax = ntuple (i -> i< N ? axes (A, i) : axes (B, i- N+ 2 ), Val (N+ M- 2 ))
215
+ function _boxdot_reshape (AB:: AbstractArray , A:: AbstractArray{T,N} , B:: AbstractArray{S,M} , :: Val{K} ) where {T,N,S,M,K }
216
+ ax = ntuple (i -> i≤ N - K ? axes (A, i) : axes (B, i- N+ 2 K ), Val (N+ M- 2 K ))
202
217
reshape (AB, ax) # some cases don't come here, so this doesn't really support OffsetArrays
203
218
end
204
219
205
220
# These can skip final reshape:
206
- _boxdot_reshape (AB:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat ) = AB
221
+ _boxdot_reshape (AB:: AbstractVecOrMat , A:: AbstractMatrix , B:: AbstractVecOrMat , :: Val ) = AB
207
222
208
223
# These produce scalar output:
209
- function boxdot (A:: AbstractVector , B:: AbstractVector )
210
- axA, axB = axes (A,1 ), axes (B,1 )
211
- axA == axB || _throw_dmm (axA, axB)
224
+ function boxdot (A:: AbstractArray{<:Any,N} , B:: AbstractArray{<:Any,N} , :: Val{N} ) where {N}
225
+ _check_boxdot_axes (A, B, Val (N))
212
226
if eltype (A) <: Number
213
- return transpose (A) * B
227
+ return transpose (vec (A)) * vec (B)
214
228
else
215
229
return sum (a* b for (a,b) in zip (A,B))
216
230
end
@@ -224,30 +238,30 @@ boxdot(a::Number, b::Number) = a*b
224
238
using LinearAlgebra: AdjointAbsVec, TransposeAbsVec, AdjOrTransAbsVec
225
239
226
240
# Adjont and Transpose, vectors or almost (returning a scalar)
227
- boxdot (A:: AdjointAbsVec , B:: AbstractVector ) = A * B
228
- boxdot (A:: TransposeAbsVec , B:: AbstractVector ) = A * B
241
+ boxdot (A:: AdjointAbsVec , B:: AbstractVector , :: Val{1} ) = A * B
242
+ boxdot (A:: TransposeAbsVec , B:: AbstractVector , :: Val{1} ) = A * B
229
243
230
- boxdot (A:: AbstractVector , B:: AdjointAbsVec ) = A ⊡ vec (B)
231
- boxdot (A:: AbstractVector , B:: TransposeAbsVec ) = A ⊡ vec (B)
244
+ boxdot (A:: AbstractVector , B:: AdjointAbsVec , :: Val{1} ) = A ⊡ vec (B)
245
+ boxdot (A:: AbstractVector , B:: TransposeAbsVec , :: Val{1} ) = A ⊡ vec (B)
232
246
233
- boxdot (A:: AdjointAbsVec , B:: AdjointAbsVec ) = adjoint (adjoint (B) ⊡ adjoint (A))
234
- boxdot (A:: AdjointAbsVec , B:: TransposeAbsVec ) = vec (A) ⊡ vec (B)
235
- boxdot (A:: TransposeAbsVec , B:: AdjointAbsVec ) = vec (A) ⊡ vec (B)
236
- boxdot (A:: TransposeAbsVec , B:: TransposeAbsVec ) = transpose (transpose (B) ⊡ transpose (A))
247
+ boxdot (A:: AdjointAbsVec , B:: AdjointAbsVec , :: Val{1} ) = adjoint (adjoint (B) ⊡ adjoint (A))
248
+ boxdot (A:: AdjointAbsVec , B:: TransposeAbsVec , :: Val{1} ) = vec (A) ⊡ vec (B)
249
+ boxdot (A:: TransposeAbsVec , B:: AdjointAbsVec , :: Val{1} ) = vec (A) ⊡ vec (B)
250
+ boxdot (A:: TransposeAbsVec , B:: TransposeAbsVec , :: Val{1} ) = transpose (transpose (B) ⊡ transpose (A))
237
251
238
252
# ... with a matrix (returning another such)
239
- boxdot (A:: AdjointAbsVec , B:: AbstractMatrix ) = A * B
240
- boxdot (A:: TransposeAbsVec , B:: AbstractMatrix ) = A * B
253
+ boxdot (A:: AdjointAbsVec , B:: AbstractMatrix , :: Val{1} ) = A * B
254
+ boxdot (A:: TransposeAbsVec , B:: AbstractMatrix , :: Val{1} ) = A * B
241
255
242
- boxdot (A:: AbstractMatrix , B:: AdjointAbsVec ) = (B' ⊡ A' )'
243
- boxdot (A:: AbstractMatrix , B:: TransposeAbsVec ) = transpose (transpose (B) ⊡ transpose (A))
256
+ boxdot (A:: AbstractMatrix , B:: AdjointAbsVec , :: Val{1} ) = (B' ⊡ A' )'
257
+ boxdot (A:: AbstractMatrix , B:: TransposeAbsVec , :: Val{1} ) = transpose (transpose (B) ⊡ transpose (A))
244
258
245
259
# ... and with higher-dim (returning a plain array)
246
- boxdot (A:: AdjointAbsVec , B:: AbstractArray ) = vec (A) ⊡ B
247
- boxdot (A:: TransposeAbsVec , B:: AbstractArray ) = vec (A) ⊡ B
260
+ boxdot (A:: AdjointAbsVec , B:: AbstractArray , :: Val{1} ) = vec (A) ⊡ B
261
+ boxdot (A:: TransposeAbsVec , B:: AbstractArray , :: Val{1} ) = vec (A) ⊡ B
248
262
249
- boxdot (A:: AbstractArray , B:: AdjointAbsVec ) = A ⊡ vec (B)
250
- boxdot (A:: AbstractArray , B:: TransposeAbsVec ) = A ⊡ vec (B)
263
+ boxdot (A:: AbstractArray , B:: AdjointAbsVec , :: Val{1} ) = A ⊡ vec (B)
264
+ boxdot (A:: AbstractArray , B:: TransposeAbsVec , :: Val{1} ) = A ⊡ vec (B)
251
265
252
266
253
267
"""
@@ -260,25 +274,30 @@ function boxdot! end
260
274
261
275
if VERSION < v " 1.3" # Then 5-arg mul! isn't defined
262
276
263
- function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray )
264
- szY = prod (size (A)[1 : end - 1 ]), prod (size (B)[2 : end ])
265
- mul! (reshape (Y, szY), _squash_left (A), _squash_right (B))
277
+ function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , :: Val{N} ) where {N}
278
+ _check_boxdot_axes (A, B, Val (N))
279
+ szY = prod (size (A)[1 : end - N]), prod (size (B)[1 + N: end ])
280
+ mul! (reshape (Y, szY), _squash_left (A, Val (N)), _squash_right (B, Val (N)))
266
281
Y
267
282
end
268
283
269
- boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ) = boxdot! (Y, A, vec (B))
284
+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray ) = boxdot! (Y, A, B, Val (1 ))
285
+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ) = boxdot! (Y, A, vec (B), Val (1 ))
270
286
271
287
else
272
288
273
- function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number = true , β:: Number = false )
274
- szY = prod (size (A)[1 : end - 1 ]), prod (size (B)[2 : end ])
275
- mul! (reshape (Y, szY), _squash_left (A), _squash_right (B), α, β)
289
+ function boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , :: Val{N} , α:: Number = true , β:: Number = false ) where {N}
290
+ _check_boxdot_axes (A, B, Val (N))
291
+ szY = prod (size (A)[1 : end - N]), prod (size (B)[1 + N: end ])
292
+ mul! (reshape (Y, szY), _squash_left (A, Val (N)), _squash_right (B, Val (N)), α, β)
276
293
Y
277
294
end
278
295
296
+ boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AbstractArray , α:: Number = true , β:: Number = false ) = boxdot! (Y, A, B, Val (1 ), α, β)
297
+
279
298
# For boxdot!, only where mul! behaves differently:
280
299
boxdot! (Y:: AbstractArray , A:: AbstractArray , B:: AdjOrTransAbsVec ,
281
- α:: Number = true , β:: Number = false ) = boxdot! (Y, A, vec (B), α, β)
300
+ α:: Number = true , β:: Number = false ) = boxdot! (Y, A, vec (B), Val ( 1 ), α, β)
282
301
283
302
end
284
303
0 commit comments