Skip to content

Commit 4ec401f

Browse files
authored
ScratchVar: Disallow @subroutine Recursion (#261)
* Recursive by-ref treatment: Impose more constraints - Disallow all recursion for any ScratchVar in any `@Subroutine` * `DynamicScratchVar` treatment: Leave as is with explanatory inline comment explaining why it's like that * Break down associated integration tests in `tests/pass_by_ref_test.py` to smaller components * Add documentation regarding illegality of pass-by-ref recursion * Add a simple unit test to `pyteal/compiler/subroutines_test.py` that tests the new functionality of `spillLocalSlotsDuringRecursion`
1 parent a0ae619 commit 4ec401f

11 files changed

+956
-176
lines changed

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,4 @@ dmypy.json
132132

133133
# IDE
134134
.idea
135+
.vscode

docs/control_structures.rst

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,27 @@ argument is even, but uses recursion to do so.
364364
.Else(recursiveIsEven(i - Int(2)))
365365
)
366366
367+
Recursion and `ScratchVar`'s
368+
----------------------------
369+
370+
Recursion with parameters of type `ScratchVar` is disallowed. For example, the following
371+
subroutine is considered illegal and attempting compilation will result in a `TealInputError`:
372+
373+
.. code-block:: python
374+
375+
@Subroutine(TealType.none)
376+
def ILLEGAL_recursion(i: ScratchVar):
377+
return (
378+
If(i.load() == Int(0))
379+
.Then(i.store(Int(1)))
380+
.ElseIf(i.load() == Int(1))
381+
.Then(i.store(Int(0)))
382+
.Else(Seq(i.store(i.load() - Int(2)), ILLEGAL_recursion(i)))
383+
)
384+
385+
386+
387+
367388
Exiting Subroutines
368389
-------------------
369390

pyteal/ast/scratchvar.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def index(self) -> Expr:
5959
class DynamicScratchVar(ScratchVar):
6060
"""
6161
Example of Dynamic Scratch space whereby the slot index is picked up from the stack:
62-
.. code-block:: python1
62+
.. code-block:: python
6363
6464
player_score = DynamicScratchVar(TealType.uint64)
6565
@@ -97,6 +97,11 @@ def set_index(self, index_var: ScratchVar) -> Expr:
9797
Followup `store`, `load` and `index` operations will use the provided `index_var` until
9898
`set_index()` is called again to reset the referenced ScratchVar.
9999
"""
100+
# Explanatory comment per Issue #242: Preliminary evidence shows that letting users
101+
# pass in any ScratchVar subtype (i.e. DynamicScratchVar) may in fact work.
102+
# However, we are leaving this guard in place pending further investigation.
103+
# TODO: gain confidence that DynamicScratchVar can be used here and
104+
# modify the below strict type equality to `isinstance(index_var, ScratchVar)`
100105
if type(index_var) is not ScratchVar:
101106
raise TealInputError(
102107
"Only allowed to use ScratchVar objects for setting indices, but was given a {}".format(

pyteal/compiler/subroutines.py

Lines changed: 46 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from collections import OrderedDict
44
from itertools import chain
55

6+
from ..errors import TealInputError
67
from ..types import TealType
78
from ..ast import SubroutineDefinition
89
from ..ir import TealComponent, TealOp, Op
@@ -11,7 +12,7 @@
1112
Node = TypeVar("Node")
1213

1314

14-
def depthFirstSearch(graph: Dict[Node, Set[Node]], start: Node, end: Node) -> bool:
15+
def graph_search(graph: Dict[Node, Set[Node]], start: Node, end: Node) -> bool:
1516
"""Check whether a path between start and end exists in the graph.
1617
1718
This works even if start == end, in which case True is only returned if the
@@ -56,12 +57,38 @@ def findRecursionPoints(
5657
reentryPoints[subroutine] = set(
5758
callee
5859
for callee in subroutineGraph[subroutine]
59-
if depthFirstSearch(subroutineGraph, callee, subroutine)
60+
if graph_search(subroutineGraph, callee, subroutine)
6061
)
6162

6263
return reentryPoints
6364

6465

66+
def find_recursive_path(
67+
subroutine_graph: Dict[SubroutineDefinition, Set[SubroutineDefinition]],
68+
subroutine: SubroutineDefinition,
69+
) -> List[SubroutineDefinition]:
70+
visited = set()
71+
loop = []
72+
73+
def dfs(x):
74+
if x in visited:
75+
return False
76+
77+
visited.add(x)
78+
loop.append(x)
79+
for y in subroutine_graph[x]:
80+
if y == subroutine:
81+
loop.append(y)
82+
return True
83+
if dfs(y):
84+
return True
85+
loop.pop()
86+
return False
87+
88+
found = dfs(subroutine)
89+
return loop if found else []
90+
91+
6592
def spillLocalSlotsDuringRecursion(
6693
version: int,
6794
subroutineMapping: Dict[Optional[SubroutineDefinition], List[TealComponent]],
@@ -90,6 +117,23 @@ def spillLocalSlotsDuringRecursion(
90117
"""
91118
recursivePoints = findRecursionPoints(subroutineGraph)
92119

120+
recursive_byref = None
121+
for k, v in recursivePoints.items():
122+
if v and k.by_ref_args:
123+
recursive_byref = k
124+
break
125+
126+
if recursive_byref:
127+
msg = "ScratchVar arguments not allowed in recursive subroutines, but a recursive call-path was detected: {}()"
128+
raise TealInputError(
129+
msg.format(
130+
"()-->".join(
131+
f.name()
132+
for f in find_recursive_path(subroutineGraph, recursive_byref)
133+
)
134+
)
135+
)
136+
93137
coverAvailable = version >= Op.cover.min_version
94138

95139
for subroutine, reentryPoints in recursivePoints.items():

pyteal/compiler/subroutines_test.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from collections import OrderedDict
22

3+
import pytest
4+
35
from .. import *
46

57
from .subroutines import (
@@ -1277,6 +1279,94 @@ def subImpl(a1, a2, a3):
12771279
}
12781280

12791281

1282+
def test_spillLocalSlotsDuringRecursion_recursive_with_scratchvar():
1283+
# modifying test_spillLocalSlotsDuringRecursion_multiple_subroutines_no_recursion()
1284+
# to be recursive and fail due to by-ref args
1285+
def sub1Impl(a1):
1286+
return None
1287+
1288+
def sub2Impl(a1, a2: ScratchVar):
1289+
return None
1290+
1291+
def sub3Impl(a1, a2, a3):
1292+
return None
1293+
1294+
subroutine1 = SubroutineDefinition(sub1Impl, TealType.uint64)
1295+
subroutine2 = SubroutineDefinition(sub2Impl, TealType.uint64)
1296+
subroutine3 = SubroutineDefinition(sub3Impl, TealType.none)
1297+
1298+
subroutine1L1Label = LabelReference("l1")
1299+
subroutine1Ops = [
1300+
TealOp(None, Op.store, 0),
1301+
TealOp(None, Op.load, 0),
1302+
TealOp(None, Op.int, 0),
1303+
TealOp(None, Op.eq),
1304+
TealOp(None, Op.bnz, subroutine1L1Label),
1305+
TealOp(None, Op.err),
1306+
TealLabel(None, subroutine1L1Label),
1307+
TealOp(None, Op.int, 1),
1308+
TealOp(None, Op.callsub, subroutine3),
1309+
TealOp(None, Op.retsub),
1310+
]
1311+
1312+
subroutine2L1Label = LabelReference("l1")
1313+
subroutine2Ops = [
1314+
TealOp(None, Op.store, 1),
1315+
TealOp(None, Op.load, 1),
1316+
TealOp(None, Op.int, 0),
1317+
TealOp(None, Op.eq),
1318+
TealOp(None, Op.bnz, subroutine2L1Label),
1319+
TealOp(None, Op.err),
1320+
TealLabel(None, subroutine2L1Label),
1321+
TealOp(None, Op.int, 1),
1322+
TealOp(None, Op.retsub),
1323+
]
1324+
1325+
subroutine3Ops = [
1326+
TealOp(None, Op.retsub),
1327+
]
1328+
1329+
l1Label = LabelReference("l1")
1330+
mainOps = [
1331+
TealOp(None, Op.txn, "Fee"),
1332+
TealOp(None, Op.int, 0),
1333+
TealOp(None, Op.eq),
1334+
TealOp(None, Op.bz, l1Label),
1335+
TealOp(None, Op.int, 100),
1336+
TealOp(None, Op.callsub, subroutine1),
1337+
TealOp(None, Op.return_),
1338+
TealLabel(None, l1Label),
1339+
TealOp(None, Op.int, 101),
1340+
TealOp(None, Op.callsub, subroutine2),
1341+
TealOp(None, Op.return_),
1342+
]
1343+
1344+
subroutineMapping = {
1345+
None: mainOps,
1346+
subroutine1: subroutine1Ops,
1347+
subroutine2: subroutine2Ops,
1348+
subroutine3: subroutine3Ops,
1349+
}
1350+
1351+
subroutineGraph = {
1352+
subroutine1: {subroutine2},
1353+
subroutine2: {subroutine1},
1354+
subroutine3: set(),
1355+
}
1356+
1357+
localSlots = {None: set(), subroutine1: {0}, subroutine2: {1}, subroutine3: {}}
1358+
1359+
with pytest.raises(TealInputError) as tie:
1360+
spillLocalSlotsDuringRecursion(
1361+
5, subroutineMapping, subroutineGraph, localSlots
1362+
)
1363+
1364+
assert (
1365+
"ScratchVar arguments not allowed in recursive subroutines, but a recursive call-path was detected: sub2Impl()-->sub1Impl()-->sub2Impl()"
1366+
in str(tie)
1367+
)
1368+
1369+
12801370
def test_resolveSubroutines():
12811371
def sub1Impl(a1):
12821372
return None

tests/compile_asserts.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
from pyteal.ir import Mode
55

66

7-
def compile_and_save(approval):
7+
def compile_and_save(approval, version):
88
teal = Path.cwd() / "tests" / "teal"
9-
compiled = compileTeal(approval(), mode=Mode.Application, version=6)
9+
compiled = compileTeal(approval(), mode=Mode.Application, version=version)
1010
name = approval.__name__
1111
with open(teal / (name + ".teal"), "w") as f:
1212
f.write(compiled)
@@ -38,8 +38,7 @@ def assert_teal_as_expected(path2actual, path2expected):
3838

3939
assert len(elines) == len(
4040
alines
41-
), f"""EXPECTED {len(elines)} lines for {path2expected}
42-
but ACTUALLY got {len(alines)} lines in {path2actual}"""
41+
), f"""EXPECTED {len(elines)} lines for {path2expected} but ACTUALLY got {len(alines)} lines in {path2actual}"""
4342

4443
for i, actual in enumerate(alines):
4544
expected = elines[i]
@@ -56,18 +55,15 @@ def assert_teal_as_expected(path2actual, path2expected):
5655
"""
5756

5857

59-
def assert_new_v_old(approve_func):
60-
try:
61-
teal_dir, name, compiled = compile_and_save(approve_func)
58+
def assert_new_v_old(approve_func, version):
59+
teal_dir, name, compiled = compile_and_save(approve_func, version)
6260

63-
print(
64-
f"""Compilation resulted in TEAL program of length {len(compiled)}.
65-
To view output SEE <{name}.teal> in ({teal_dir})
66-
--------------"""
67-
)
61+
print(
62+
f"""Compilation resulted in TEAL program of length {len(compiled)}.
63+
To view output SEE <{name}.teal> in ({teal_dir})
64+
--------------"""
65+
)
6866

69-
path2actual = teal_dir / (name + ".teal")
70-
path2expected = teal_dir / (name + "_expected.teal")
71-
assert_teal_as_expected(path2actual, path2expected)
72-
except Exception as e:
73-
assert not e, f"failed to ASSERT NEW v OLD for {approve_func.__name__}: {e}"
67+
path2actual = teal_dir / (name + ".teal")
68+
path2expected = teal_dir / (name + "_expected.teal")
69+
assert_teal_as_expected(path2actual, path2expected)

0 commit comments

Comments
 (0)