Skip to content

Format code of branch "main" #4

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions ext/ReactantCUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -301,20 +301,22 @@ end

Base.IndexStyle(::Type{<:CuTracedArray}) = Base.IndexLinear()

Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} =
arrayref(A, i1)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} =
arrayset(A, convert(T, x)::T, i1)
Base.@propagate_inbounds Base.getindex(A::CuTracedArray{T}, i1::Integer) where {T} = arrayref(
A, i1
)
Base.@propagate_inbounds Base.setindex!(A::CuTracedArray{T}, x, i1::Integer) where {T} = arrayset(
A, convert(T, x)::T, i1
)

# preserve the specific integer type when indexing device arrays,
# to avoid extending 32-bit hardware indices to 64-bit.
Base.to_index(::CuTracedArray, i::Integer) = i

# Base doesn't like Integer indices, so we need our own ND get and setindex! routines.
# See also: https://github.com/JuliaLang/julia/pull/42289
Base.@propagate_inbounds Base.getindex(
A::CuTracedArray, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)]
Base.@propagate_inbounds Base.getindex(A::CuTracedArray, I::Union{Integer,CartesianIndex}...) = A[Base._to_linear_index(
A, to_indices(A, I)...
)]
Base.@propagate_inbounds Base.setindex!(
A::CuTracedArray, x, I::Union{Integer,CartesianIndex}...
) = A[Base._to_linear_index(A, to_indices(A, I)...)] = x
Expand Down
3 changes: 1 addition & 2 deletions ext/ReactantKernelAbstractionsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,8 @@ function ka_with_reactant end # defined in the CUDA extension

Reactant.@reactant_overlay @noinline Base.@nospecializeinfer function (
obj::KA.Kernel{ReactantBackend}
)(
args...; ndrange=nothing, workgroupsize=nothing
)
(args...; ndrange=nothing, workgroupsize=nothing)
@nospecialize
return Reactant.call_with_reactant(
ka_with_reactant, ndrange, workgroupsize, obj, args...
Expand Down
16 changes: 7 additions & 9 deletions lib/ReactantCore/src/ReactantCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -182,12 +182,10 @@ function trace_for(mod, expr; track_numbers)
step = length(range.args) == 3 ? 1 : range.args[3]
limit = range.args[end]

body_symbols = ExpressionExplorer.compute_symbols_state(
quote
$(Expr(:local, assign))
$body
end,
)
body_symbols = ExpressionExplorer.compute_symbols_state(quote
$(Expr(:local, assign))
$body
end)

external_syms = body_symbols.assignments ∪ body_symbols.references
filter!(∉(SPECIAL_SYMBOLS), external_syms)
Expand Down Expand Up @@ -243,8 +241,8 @@ function trace_for(mod, expr; track_numbers)
cond_fn,
body_fn,
args;
track_numbers=$(track_numbers),
verify_arg_names=$(QuoteNode(args_names)),
track_numbers=($(track_numbers)),
verify_arg_names=($(QuoteNode(args_names))),
)
end
end
Expand Down Expand Up @@ -420,7 +418,7 @@ function trace_if(mod, expr; store_last_line=nothing, depth=0, track_numbers)
$(true_branch_fn_name),
$(false_branch_fn_name),
($(all_input_vars...),);
track_numbers=$(track_numbers),
track_numbers=($(track_numbers)),
)
end

Expand Down
12 changes: 6 additions & 6 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ end
return Base.setproperty!(obj, field, val)
end

@inline traced_setfield!(@nospecialize(obj), field, val, path) =
Base.setfield!(obj, field, val)
@inline traced_setfield!(@nospecialize(obj), field, val, path) = Base.setfield!(
obj, field, val
)

@inline function traced_setfield!(
@nospecialize(obj::AbstractArray{T}), field, val, path
Expand Down Expand Up @@ -1858,7 +1859,7 @@ function compile_call_expr(mod, compiler, options::Dict, args...)
$(compiled_symbol) = $(compiler)(
$(f_symbol),
$(args_symbol);
fn_kwargs=$(kwargs_symbol),
fn_kwargs=($(kwargs_symbol)),
$(Expr.(:kw, keys(options), values(options))...),
)
end,
Expand Down Expand Up @@ -2879,9 +2880,8 @@ end

@generated function (
thunk::Thunk{FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM}
)(
args...
) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM}
)
(args...) where {FTy,tag,IsClosure,ArgTypes,ExecTy,DeviceTy,ClientTy,GD,DAM}
FoundTypes = Tuple{args...}
if ArgTypes != FoundTypes
return quote
Expand Down
10 changes: 7 additions & 3 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -357,7 +357,9 @@ function Base.setindex!(a::ConcreteIFRTArray, v, args::Vararg{Int,N}) where {N}
end

# TODO is there any way to allocate an uninitialized buffer in XLA?
function Base.similar(a::ConcretePJRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
function Base.similar(
a::ConcretePJRTArray{T}, (::Type{S})=T, dims::Dims=size(a)
) where {T,S}
return ConcretePJRTArray(
Array{S}(undef, dims); client=XLA.client(a), device=XLA.device(a), a.sharding
)
Expand All @@ -367,7 +369,9 @@ function Base.similar(::Type{ConcretePJRTArray{T}}, dims) where {T}
return ConcretePJRTArray(similar(Array{T}, dims))
end

function Base.similar(a::ConcreteIFRTArray{T}, ::Type{S}=T, dims::Dims=size(a)) where {T,S}
function Base.similar(
a::ConcreteIFRTArray{T}, (::Type{S})=T, dims::Dims=size(a)
) where {T,S}
return ConcreteIFRTArray(
Array{S}(undef, dims); client=XLA.client(a), device=XLA.device(a), a.sharding
)
Expand Down Expand Up @@ -403,7 +407,7 @@ function Base.copy(bc::Base.Broadcast.Broadcasted{Broadcast.ArrayStyle{ConcreteP
if all(buffer_on_cpu, bc.args) && all(
x ->
!(x isa ConcretePJRTArray) ||
(x isa ConcretePJRTArray && !Sharding.is_sharded(x)),
(x isa ConcretePJRTArray && !Sharding.is_sharded(x)),
bc.args,
)
ElType = Base.Broadcast.combine_eltypes(bc.f, bc.args)
Expand Down
57 changes: 32 additions & 25 deletions src/Interpreter.jl
Original file line number Diff line number Diff line change
Expand Up @@ -83,11 +83,11 @@ end
ReactantCacheToken(),
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
false,
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#set_reactant_abi,
)
end
else
Expand All @@ -100,11 +100,11 @@ else
REACTANT_CACHE,
REACTANT_METHOD_TABLE,
world,
false, #=forward_rules=#
false, #=reverse_rules=#
false, #=inactive_rules=#
false, #=broadcast_rewrite=#
set_reactant_abi,
false,
#=forward_rules=#false,
#=reverse_rules=#false,
#=inactive_rules=#false,
#=broadcast_rewrite=#set_reactant_abi,
)
end
end
Expand All @@ -116,20 +116,25 @@ const enzyme_dupnoneed = 3
const enzyme_outnoneed = 4
const enzyme_constnoneed = 5

@inline act_from_type(x, reverse, needs_primal=true) =
throw(AssertionError("Unhandled activity $(typeof(x))"))
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) =
act_from_type(Enzyme.Const, reverse, needs_primal)
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(x, reverse, needs_primal=true) = throw(
AssertionError("Unhandled activity $(typeof(x))")
)
@inline act_from_type(::Enzyme.Const, reverse, needs_primal=true) = act_from_type(
Enzyme.Const, reverse, needs_primal
)
@inline act_from_type(::Enzyme.Duplicated, reverse, needs_primal=true) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Enzyme.DuplicatedNoNeed, reverse, needs_primal=true) =
reverse ? enzyme_out : enzyme_dupnoneed
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(::Enzyme.BatchDuplicated, reverse, needs_primal=true) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Enzyme.BatchDuplicatedNoNeed, reverse, needs_primal=true) =
reverse ? enzyme_out : enzyme_dupnoneed
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) =
act_from_type(Enzyme.Active, reverse, needs_primal)
@inline act_from_type(::Enzyme.Active, reverse, needs_primal=true) = act_from_type(
Enzyme.Active, reverse, needs_primal
)
@inline act_from_type(::Type{<:Enzyme.Const}, reverse, needs_primal) =
if needs_primal
enzyme_const
Expand All @@ -151,10 +156,12 @@ const enzyme_constnoneed = 5
end
end

@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) =
act_from_type(Enzyme.Duplicated, reverse, needs_primal)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) =
act_from_type(Enzyme.DuplicatedNoNeed, Reverse, needs_primal)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicated}, reverse, needs_primal) = act_from_type(
Enzyme.Duplicated, reverse, needs_primal
)
@inline act_from_type(::Type{<:Enzyme.BatchDuplicatedNoNeed}, reverse, needs_primal) = act_from_type(
Enzyme.DuplicatedNoNeed, Reverse, needs_primal
)

@inline act_from_type(::Type{<:Enzyme.Active}, reverse, needs_primal) =
if needs_primal
Expand Down Expand Up @@ -498,7 +505,7 @@ function overload_autodiff(
false,
TracedUtils.transpose_val(MLIR.IR.result(res, residx));
emptypaths=true,
) #=reverse=#
)#=reverse=#
residx += 1
continue
end
Expand Down
102 changes: 49 additions & 53 deletions src/Ops.jl
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,7 @@ for (dialect, op) in
res = MLIR.IR.result(
$(:($dialect.$op))(
x.mlir_data;
$(result)=mlir_type(TracedRArray{Bool,N}, size(x)),
($(result))=mlir_type(TracedRArray{Bool,N}, size(x)),
location,
),
)
Expand All @@ -404,7 +404,7 @@ for (dialect, op) in
) where {T}
res = MLIR.IR.result(
$(:($dialect.$op))(
x.mlir_data; $(result)=mlir_type(TracedRArray{Bool,0}, ()), location
x.mlir_data; ($(result))=mlir_type(TracedRArray{Bool,0}, ()), location
),
)
return TracedRNumber{Bool}((), res)
Expand Down Expand Up @@ -1127,16 +1127,15 @@ end
sample_inputs[2i - 1] = Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0)
sample_inputs[2i] = Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0)
end
func =
Reactant.TracedUtils.make_mlir_fn(
comparator,
(sample_inputs...,),
(),
"comparator",
false;
args_in_result=:none,
return_dialect=:stablehlo,
).f
func = Reactant.TracedUtils.make_mlir_fn(
comparator,
(sample_inputs...,),
(),
"comparator",
false;
args_in_result=:none,
return_dialect=:stablehlo,
).f
@assert MLIR.IR.nregions(func) == 1
fn_name = String(
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
Expand Down Expand Up @@ -1745,36 +1744,34 @@ end

input_types = [mlir_type(arg) for arg in linear_args]

cond_fn_compiled =
Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:result,
do_transpose=false,
argprefix=gensym("loop_condarg"),
resprefix=gensym("loop_condres"),
resargprefix=gensym("loop_condresarg"),
).f

body_fn_compiled =
Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:all,
do_transpose=false,
verify_arg_names,
argprefix=gensym("loop_bodyarg"),
resprefix=gensym("loop_bodyres"),
resargprefix=gensym("loop_bodyresarg"),
).f
cond_fn_compiled = Reactant.TracedUtils.make_mlir_fn(
cond_fn,
traced_args,
(),
string(gensym("cond_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:result,
do_transpose=false,
argprefix=gensym("loop_condarg"),
resprefix=gensym("loop_condres"),
resargprefix=gensym("loop_condresarg"),
).f

body_fn_compiled = Reactant.TracedUtils.make_mlir_fn(
body_fn,
traced_args,
(),
string(gensym("body_fn")),
false;
return_dialect=:stablehlo,
args_in_result=:all,
do_transpose=false,
verify_arg_names,
argprefix=gensym("loop_bodyarg"),
resprefix=gensym("loop_bodyres"),
resargprefix=gensym("loop_bodyresarg"),
).f

cond_reg = Reactant.TracedUtils.__take_region(cond_fn_compiled)
body_reg = Reactant.TracedUtils.__take_region(body_fn_compiled)
Expand Down Expand Up @@ -2384,7 +2381,7 @@ end
@assert ndevices == length(logical_device_ids) "length(logical_device_ids) should be \
same as prod(last, mesh_axes)"
@assert all(Base.Fix2(≥, 0), logical_device_ids) "logical_device_ids must be \
non-negative"
non-negative"

sorted_logical_device_ids = Base.sort(logical_device_ids)
@assert sorted_logical_device_ids == 0:(ndevices - 1) "sorted logical_device_ids \
Expand Down Expand Up @@ -2526,16 +2523,15 @@ Applies a reduction function `fn` along the specified `dimensions` of input `x`,
Reactant.TracedUtils.promote_to(TracedRNumber{T}, 0),
]

func =
Reactant.TracedUtils.make_mlir_fn(
fn,
(sample_inputs),
(),
"reduce_fn",
false;
args_in_result=:none,
return_dialect=:stablehlo,
).f
func = Reactant.TracedUtils.make_mlir_fn(
fn,
(sample_inputs),
(),
"reduce_fn",
false;
args_in_result=:none,
return_dialect=:stablehlo,
).f
@assert MLIR.IR.nregions(func) == 1
fn_name = String(
MLIR.IR.attr(func, String(MLIR.API.mlirSymbolTableGetSymbolAttributeName()))
Expand Down
2 changes: 1 addition & 1 deletion src/Overlay.jl
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ for randfun in (:rand, :randn, :randexp)

# scalars
@reactant_overlay @noinline function Random.$(randfun)(
rng::AbstractRNG, ::Type{T}=Float64
rng::AbstractRNG, (::Type{T})=Float64
) where {T}
if T <: ReactantPrimitive
return TracedRandom.$(overload_randfun)(rng, T)
Expand Down
2 changes: 1 addition & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ include("Compiler.jl")
include("Overlay.jl")

function Enzyme.make_zero(
::Type{RT}, seen::IdDict, prev::RT, ::Val{copy_if_inactive}=Val(false)
::Type{RT}, seen::IdDict, prev::RT, (::Val{copy_if_inactive})=Val(false)
)::RT where {copy_if_inactive,RT<:Union{RArray,RNumber}}
if haskey(seen, prev)
return seen[prev]
Expand Down
Loading