Skip to content

Commit 5922ef0

Browse files
committed
one more test
1 parent 7a08273 commit 5922ef0

File tree

1 file changed

+24
-1
lines changed

1 file changed

+24
-1
lines changed

test/control_flow.jl

Lines changed: 24 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -589,7 +589,7 @@ end
589589
# With different operand sizes, different functions need to be generated:
590590
c = rand(4, 5)
591591
c_ra = Reactant.to_rarray(c)
592-
592+
593593
@test @jit(call1(a_ra, c_ra)) call1(a, c)
594594
ir = @code_hlo optimize=false call1(a_ra, c_ra)
595595
ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))]
@@ -608,3 +608,26 @@ end
608608
@test @jit(call2(a_rn)) == call2(a)
609609
end
610610

611+
function _call3(x::Int, y)
612+
if x > 10
613+
return y .+ y
614+
else
615+
return y .* y
616+
end
617+
end
618+
619+
function call3(y)
620+
z = @trace _call3(1, y)
621+
@trace _call3(1, z) # doesn't generate new function because y.shape == z.shape
622+
@trace _call3(11, y) # new function because x changed.
623+
end
624+
625+
@testset "call: caching for Julia operands" begin
626+
y = rand(3)
627+
y_ra = Reactant.to_rarray(y)
628+
629+
ir = @code_hlo optimize=false call3(y_ra)
630+
ops = [op for op in Reactant.MLIR.IR.OperationIterator(Reactant.MLIR.IR.body(ir))]
631+
@test length(ops) == 5 # call3, .+, .*, _call3 (2X)
632+
end
633+

0 commit comments

Comments
 (0)