Skip to content

Commit 8a9f06c

Browse files
committed
fix: special handling for concatenation of numbers
1 parent 60b614b commit 8a9f06c

File tree

1 file changed

+31
-0
lines changed

1 file changed

+31
-0
lines changed

Diff for: src/TracedRNumber.jl

+31
Original file line numberDiff line numberDiff line change
@@ -209,3 +209,34 @@ struct TypeCast{T<:ReactantPrimitives} <: Function end
209209
(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)
210210

211211
Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x)
212+
213+
# Concatenation. Numbers in Julia are handled in a much less generic fashion than arrays
214+
Base.vcat(x::TracedRNumber...) = Base.typed_vcat(Base.promote_eltypeof(x...), x...)
215+
function Base.typed_vcat(::Type{T}, x::TracedRNumber...) where {T}
216+
return Base.typed_vcat(T, map(Base.Fix2(broadcast_to_size, (1,)), x)...)
217+
end
218+
219+
Base.hcat(x::TracedRNumber...) = Base.typed_hcat(Base.promote_eltypeof(x...), x...)
220+
function Base.typed_hcat(::Type{T}, x::TracedRNumber...) where {T}
221+
return Base.typed_hcat(T, map(Base.Fix2(broadcast_to_size, (1, 1)), x)...)
222+
end
223+
224+
function Base.hvcat(rows::Tuple{Vararg{Int}}, xs::TracedRNumber...)
225+
return Base.typed_hvcat(Base.promote_eltypeof(xs...), rows, xs...)
226+
end
227+
function Base.typed_hvcat(
228+
::Type{T}, rows::Tuple{Vararg{Int}}, xs::TracedRNumber...
229+
) where {T}
230+
xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs)
231+
return Base.typed_hvcat(T, rows, xs...)
232+
end
233+
234+
function Base.hvncat(dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber...)
235+
return Base.typed_hvncat(Base.promote_eltypeof(xs...), dims, row_first, xs...)
236+
end
237+
function Base.typed_hvncat(
238+
::Type{T}, dims::Tuple{Vararg{Int}}, row_first::Bool, xs::TracedRNumber...
239+
) where {T}
240+
xs = map(Base.Fix2(broadcast_to_size, (1, 1)), xs)
241+
return Base.typed_hvncat(T, dims, row_first, xs...)
242+
end

0 commit comments

Comments
 (0)