Skip to content

Commit eb3d1db

Browse files
committed
fix: broadcasting of closures
1 parent 841376d commit eb3d1db

File tree

3 files changed

+26
-2
lines changed

3 files changed

+26
-2
lines changed

Diff for: src/TracedRArray.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
246246
invmap[v] = k
247247
end
248248

249-
input_shapes = size.(keys(seen_args))
249+
keys_seen = [k for k in keys(seen_args) if k isa TracedTypes]
250+
input_shapes = size.(keys_seen)
250251
# by the time we reach here all args must have same size
251252
@assert allequal(input_shapes) "input shapes are $(input_shapes)"
252253
OutShape = isempty(seen_args) ? nothing : first(input_shapes)

Diff for: src/utils.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ end
2929

3030
function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=false)
3131
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
32-
return (true, make_mlir_fn(apply, (f, args...), kwargs, name, concretein)[2:end]...)
32+
return (
33+
true,
34+
make_mlir_fn(apply, (f, args...), kwargs, name, concretein; toscalar)[2:end]...,
35+
)
3336
end
3437

3538
N = length(args)

Diff for: test/bcast.jl

+20
Original file line numberDiff line numberDiff line change
@@ -112,3 +112,23 @@ pow(x, n) = x .^ n
112112

113113
@test pow_compiled(x_ra) pow(x, 2)
114114
end
115+
116+
struct CustomBCastFunction{X}
117+
x::X
118+
end
119+
120+
(f::CustomBCastFunction)(x::Number) = f.x + x
121+
122+
function custombcast(x)
123+
fn = CustomBCastFunction(3.0)
124+
return fn.(x)
125+
end
126+
127+
@testset "Broadcasting closures / functors" begin
128+
x = rand(2, 3)
129+
x_ra = Reactant.to_rarray(x)
130+
131+
custombcast_compiled = @compile custombcast(x_ra)
132+
133+
@test custombcast_compiled(x_ra) custombcast(x)
134+
end

0 commit comments

Comments
 (0)