From 82bc9a6675818648ed21ca7256324ccf08b81a16 Mon Sep 17 00:00:00 2001 From: Pangoraw <9824244+Pangoraw@users.noreply.github.com> Date: Thu, 18 Jul 2024 01:00:30 +0000 Subject: [PATCH] Format code --- ext/ReactantNNlibExt.jl | 8 +++++--- src/Reactant.jl | 2 +- test/nn.jl | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ext/ReactantNNlibExt.jl b/ext/ReactantNNlibExt.jl index 225cfd828..b761562e2 100644 --- a/ext/ReactantNNlibExt.jl +++ b/ext/ReactantNNlibExt.jl @@ -60,7 +60,7 @@ function NNlib.conv( in1, in2 = in1.mlir_data, in2.mlir_data if !NNlib.flipkernel(cdims) - rev = Reactant.MLIR.Dialects.stablehlo.reverse(in2; dimensions=[1,0]) + rev = Reactant.MLIR.Dialects.stablehlo.reverse(in2; dimensions=[1, 0]) in2 = Reactant.MLIR.IR.result(rev, 1) end @@ -68,7 +68,8 @@ function NNlib.conv( (), Reactant.MLIR.IR.result( Reactant.MLIR.Dialects.stablehlo.convolution( - in1, in2; + in1, + in2; result_0=output_type, dimension_numbers=parse( Reactant.MLIR.IR.Attribute, @@ -88,7 +89,8 @@ function NNlib.conv( ), rhs_dilation, lhs_dilation, - padding, window_strides, + padding, + window_strides, feature_group_count=1, batch_group_count=1, ), diff --git a/src/Reactant.jl b/src/Reactant.jl index 07c7636f6..02d4c4692 100644 --- a/src/Reactant.jl +++ b/src/Reactant.jl @@ -1174,7 +1174,7 @@ function run_pass_pipeline!(mod, pass_pipeline) pm = MLIR.IR.PassManager() opm = MLIR.IR.OpPassManager(pm) MLIR.IR.add_pipeline!(opm, pass_pipeline) - MLIR.IR.run!(pm, mod) + return MLIR.IR.run!(pm, mod) end function compile_to_module(mod, f, args; optimize=true) diff --git a/test/nn.jl b/test/nn.jl index a949c93ff..0dbc4a42c 100644 --- a/test/nn.jl +++ b/test/nn.jl @@ -75,7 +75,7 @@ mean((out2[1, :] .> 0.5) .== truth) # accuracy 94% so far! W = randn(Float32, 10, 10, 3, 1) x = randn(Float32, 64, 64, 3, 2) - cW, cx = (W, x) .|> Reactant.ConcreteRArray + cW, cx = Reactant.ConcreteRArray.((W, x)) cconv = Reactant.compile(NNlib.conv, (cx, cW)) out = NNlib.conv(x, W)