|
| 1 | + |
| 2 | +""" |
| 3 | + BroadcastAxis |
| 4 | +
|
| 5 | +An abstract trait that is used to determine how axes are combined when calling `broadcast_axis`. |
| 6 | +""" |
| 7 | +abstract type BroadcastAxis end |
| 8 | + |
| 9 | +struct BroadcastAxisDefault <: BroadcastAxis end |
| 10 | + |
| 11 | +BroadcastAxis(x) = BroadcastAxis(typeof(x)) |
| 12 | +BroadcastAxis(::Type{T}) where {T} = BroadcastAxisDefault() |
| 13 | + |
| 14 | +""" |
| 15 | + broadcast_axis(x, y) |
| 16 | +
|
| 17 | +Broadcast axis `x` and `y` into a common space. The resulting axis should be equal in length |
| 18 | +to both `x` and `y` unless one has a length of `1`, in which case the longest axis will be |
| 19 | +equal to the output. |
| 20 | +
|
| 21 | +```julia |
| 22 | +julia> ArrayInterface.broadcast_axis(1:10, 1:10) |
| 23 | +
|
| 24 | +julia> ArrayInterface.broadcast_axis(1:10, 1) |
| 25 | +1:10 |
| 26 | +
|
| 27 | +``` |
| 28 | +""" |
| 29 | +broadcast_axis(x, y) = broadcast_axis(BroadcastAxis(x), x, y) |
| 30 | +# stagger default broadcasting in case y has something other than default |
| 31 | +broadcast_axis(::BroadcastAxisDefault, x, y) = _broadcast_axis(BroadcastAxis(y), x, y) |
| 32 | +function _broadcast_axis(::BroadcastAxisDefault, x, y) |
| 33 | + return One():_combine_length(static_length(x), static_length(y)) |
| 34 | +end |
| 35 | +_broadcast_axis(s::BroadcastAxis, x, y) = broadcast_axis(s, x, y) |
| 36 | + |
| 37 | +# we can use a similar trick as we do with `indices` where unequal sizes error and we just |
| 38 | +# keep the static value. However, axes can be unequal if one of them is `1` so we have to |
| 39 | +# fall back to dynamic values in those cases |
| 40 | +_combine_length(x::StaticInt{X}, y::StaticInt{Y}) where {X,Y} = static(_combine_length(X, Y)) |
| 41 | +_combine_length(x::StaticInt{X}, ::Int) where {X} = x |
| 42 | +_combine_length(x::StaticInt{1}, y::Int) = y |
| 43 | +_combine_length(x::StaticInt{1}, y::StaticInt{1}) = y |
| 44 | +_combine_length(x::Int, y::StaticInt{Y}) where {Y} = y |
| 45 | +_combine_length(x::Int, y::StaticInt{1}) = x |
| 46 | +@inline function _combine_length(x::Int, y::Int) |
| 47 | + if x === y |
| 48 | + return x |
| 49 | + elseif y === 1 |
| 50 | + return x |
| 51 | + elseif x === 1 |
| 52 | + return y |
| 53 | + else |
| 54 | + _dimerr(x, y) |
| 55 | + end |
| 56 | +end |
| 57 | + |
| 58 | +function _dimerr(@nospecialize(x), @nospecialize(y)) |
| 59 | + throw(DimensionMismatch("axes could not be broadcast to a common size; " * |
| 60 | + "got axes with lengths $(x) and $(y)")) |
| 61 | +end |
0 commit comments