Skip to content

Commit f15350b

Browse files
committed
fix: broadcasting of closures
1 parent 841376d commit f15350b

File tree

2 files changed

+6
-2
lines changed

2 files changed

+6
-2
lines changed

Diff for: src/TracedRArray.jl

+2-1
Original file line numberDiff line numberDiff line change
@@ -246,7 +246,8 @@ function elem_apply(f, args::Vararg{Any,Nargs}) where {Nargs}
246246
invmap[v] = k
247247
end
248248

249-
input_shapes = size.(keys(seen_args))
249+
keys_seen = [k for k in keys(seen_args) if k isa TracedTypes]
250+
input_shapes = size.(keys_seen)
250251
# by the time we reach here all args must have same size
251252
@assert allequal(input_shapes) "input shapes are $(input_shapes)"
252253
OutShape = isempty(seen_args) ? nothing : first(input_shapes)

Diff for: src/utils.jl

+4-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,10 @@ end
2929

3030
function make_mlir_fn(f, args, kwargs, name="main", concretein=true; toscalar=false)
3131
if sizeof(typeof(f)) != 0 || f isa BroadcastFunction
32-
return (true, make_mlir_fn(apply, (f, args...), kwargs, name, concretein)[2:end]...)
32+
return (
33+
true,
34+
make_mlir_fn(apply, (f, args...), kwargs, name, concretein; toscalar)[2:end]...,
35+
)
3336
end
3437

3538
N = length(args)

0 commit comments

Comments
 (0)