Skip to content

Commit 903e6ef

Browse files
authored
Implement broadcast_axis (#133)
1 parent c886707 commit 903e6ef

File tree

4 files changed

+97
-0
lines changed

4 files changed

+97
-0
lines changed

src/ArrayInterface.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -792,6 +792,7 @@ include("dimensions.jl")
792792
include("axes.jl")
793793
include("size.jl")
794794
include("stridelayout.jl")
795+
include("broadcast.jl")
795796

796797

797798
abstract type AbstractArray2{T,N} <: AbstractArray{T,N} end

src/broadcast.jl

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,61 @@
1+
2+
"""
3+
BroadcastAxis
4+
5+
An abstract trait that is used to determine how axes are combined when calling `broadcast_axis`.
6+
"""
7+
abstract type BroadcastAxis end
8+
9+
struct BroadcastAxisDefault <: BroadcastAxis end
10+
11+
BroadcastAxis(x) = BroadcastAxis(typeof(x))
12+
BroadcastAxis(::Type{T}) where {T} = BroadcastAxisDefault()
13+
14+
"""
15+
broadcast_axis(x, y)
16+
17+
Broadcast axis `x` and `y` into a common space. The resulting axis should be equal in length
18+
to both `x` and `y` unless one has a length of `1`, in which case the longest axis will be
19+
equal to the output.
20+
21+
```julia
22+
julia> ArrayInterface.broadcast_axis(1:10, 1:10)
23+
24+
julia> ArrayInterface.broadcast_axis(1:10, 1)
25+
1:10
26+
27+
```
28+
"""
29+
broadcast_axis(x, y) = broadcast_axis(BroadcastAxis(x), x, y)
30+
# stagger default broadcasting in case y has something other than default
31+
broadcast_axis(::BroadcastAxisDefault, x, y) = _broadcast_axis(BroadcastAxis(y), x, y)
32+
function _broadcast_axis(::BroadcastAxisDefault, x, y)
33+
return One():_combine_length(static_length(x), static_length(y))
34+
end
35+
_broadcast_axis(s::BroadcastAxis, x, y) = broadcast_axis(s, x, y)
36+
37+
# we can use a similar trick as we do with `indices` where unequal sizes error and we just
38+
# keep the static value. However, axes can be unequal if one of them is `1` so we have to
39+
# fall back to dynamic values in those cases
40+
_combine_length(x::StaticInt{X}, y::StaticInt{Y}) where {X,Y} = static(_combine_length(X, Y))
41+
_combine_length(x::StaticInt{X}, ::Int) where {X} = x
42+
_combine_length(x::StaticInt{1}, y::Int) = y
43+
_combine_length(x::StaticInt{1}, y::StaticInt{1}) = y
44+
_combine_length(x::Int, y::StaticInt{Y}) where {Y} = y
45+
_combine_length(x::Int, y::StaticInt{1}) = x
46+
@inline function _combine_length(x::Int, y::Int)
47+
if x === y
48+
return x
49+
elseif y === 1
50+
return x
51+
elseif x === 1
52+
return y
53+
else
54+
_dimerr(x, y)
55+
end
56+
end
57+
58+
function _dimerr(@nospecialize(x), @nospecialize(y))
59+
throw(DimensionMismatch("axes could not be broadcast to a common size; " *
60+
"got axes with lengths $(x) and $(y)"))
61+
end

test/broadcast.jl

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
2+
s5 = static(1):static(5)
3+
s4 = static(1):static(4)
4+
s1 = static(1):static(1)
5+
d5 = static(1):5
6+
d4 = static(1):static(4)
7+
d1 = static(1):static(1)
8+
9+
struct DummyBroadcast <: ArrayInterface.BroadcastAxis end
10+
11+
struct DummyAxis end
12+
13+
ArrayInterface.BroadcastAxis(::Type{DummyAxis}) = DummyBroadcast()
14+
15+
ArrayInterface.broadcast_axis(::DummyBroadcast, x, y) = y
16+
17+
@inferred(ArrayInterface.broadcast_axis(s1, s1)) === s1
18+
@inferred(ArrayInterface.broadcast_axis(s5, s5)) === s5
19+
@inferred(ArrayInterface.broadcast_axis(s5, s1)) === s5
20+
@inferred(ArrayInterface.broadcast_axis(s1, s5)) === s5
21+
@inferred(ArrayInterface.broadcast_axis(s5, d5)) === s5
22+
@inferred(ArrayInterface.broadcast_axis(d5, s5)) === s5
23+
@inferred(ArrayInterface.broadcast_axis(d5, d1)) === d5
24+
@inferred(ArrayInterface.broadcast_axis(d1, d5)) === d5
25+
@inferred(ArrayInterface.broadcast_axis(s1, d5)) === d5
26+
@inferred(ArrayInterface.broadcast_axis(d5, s1)) === d5
27+
@inferred(ArrayInterface.broadcast_axis(s5, DummyAxis())) === s5
28+
29+
@test_throws DimensionMismatch ArrayInterface.broadcast_axis(s5, s4)
30+

test/runtests.jl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -726,3 +726,8 @@ end
726726
include("indexing.jl")
727727
include("dimensions.jl")
728728

729+
@testset "broadcast" begin
730+
include("broadcast.jl")
731+
end
732+
733+

0 commit comments

Comments
 (0)