Skip to content

Commit 40af781

Browse files
committed
fix: import ordering
1 parent 36a6fea commit 40af781

File tree

3 files changed

+7
-5
lines changed

3 files changed

+7
-5
lines changed

Diff for: src/Reactant.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,8 +65,8 @@ include("XLA.jl")
6565
include("Interpreter.jl")
6666
include("utils.jl")
6767
include("ConcreteRArray.jl")
68-
include("TracedRArray.jl")
6968
include("TracedRNumber.jl")
69+
include("TracedRArray.jl")
7070
include("Tracing.jl")
7171
include("Compiler.jl")
7272

Diff for: src/TracedRArray.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -213,10 +213,10 @@ end
213213

214214
promote_to(::TracedRArray{T,N}, rhs) where {T,N} = promote_to(TracedRArray{T,N}, rhs)
215215

216-
struct TypeCast{T<:Number} <: Function end
217-
218-
elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:Number} = x
219-
function elem_apply(::Type{T}, x::TracedRArray{T2}) where {T<:Number,T2<:Number}
216+
elem_apply(::Type{T}, x::TracedRArray{T}) where {T<:ReactantPrimitives} = x
217+
function elem_apply(
218+
::Type{T}, x::TracedRArray{T2}
219+
) where {T<:ReactantPrimitives,T2<:ReactantPrimitives}
220220
# Special Path to prevent going down a despecialized path
221221
return elem_apply(TypeCast{T}(), x)
222222
end

Diff for: src/TracedRNumber.jl

+2
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,8 @@ for (jlop, hloop) in (
158158
end
159159
end
160160

161+
struct TypeCast{T<:ReactantPrimitives} <: Function end
162+
161163
(::TypeCast{T})(x::TracedRNumber{T2}) where {T,T2} = promote_to(TracedRNumber{T}, x)
162164

163165
Base.float(x::TracedRNumber{T}) where {T} = promote_to(TracedRNumber{float(T)}, x)

0 commit comments

Comments
 (0)