Skip to content

Commit 9f51145

Browse files
authored
Allow anonymous functions to recursively call themselves (#139)
1 parent 466c861 commit 9f51145

File tree

2 files changed

+49
-10
lines changed

2 files changed

+49
-10
lines changed

basilisp/compiler.py

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -721,22 +721,34 @@ def _compose_ifs(if_stmts: List[Dict[str, ast.AST]], orelse: List[ast.AST] = Non
721721
orelse=Maybe(orelse).or_else_get([]))
722722

723723

724-
def _single_arity_fn_ast(ctx: CompilerContext, name: str, fndef: llist.List) -> ASTStream:
724+
def _fn_name(s: Optional[sym.Symbol]) -> str:
725+
"""Generate a safe Python function name from a function name symbol.
726+
If no symbol is provided, generate a name with a default prefix."""
727+
return genname("__" + munge(Maybe(s).map(lambda s: s.name).or_else_get(_FN_PREFIX)))
728+
729+
730+
def _single_arity_fn_ast(ctx: CompilerContext, name: Optional[sym.Symbol], fndef: llist.List) -> ASTStream:
725731
"""Generate Python AST nodes for a single-arity function."""
726-
with ctx.new_symbol_table(name), ctx.new_recur_point(name, fndef.first):
732+
py_fn_name = _fn_name(name)
733+
with ctx.new_symbol_table(py_fn_name), ctx.new_recur_point(py_fn_name, fndef.first):
734+
# Allow named anonymous functions to recursively call themselves
735+
if name is not None:
736+
ctx.symbol_table.new_symbol(name, py_fn_name, _SYM_CTX_LOCAL)
737+
727738
args, body, vargs = _fn_args_body(ctx, fndef.first, fndef.rest)
728739

729-
yield _dependency(_expressionize(body, name, args=args, vargs=vargs))
740+
yield _dependency(_expressionize(body, py_fn_name, args=args, vargs=vargs))
730741
if ctx.recur_point.has_recur:
731742
yield _node(ast.Call(func=_TRAMPOLINE_FN_NAME,
732743
args=[ast.Name(id=ctx.recur_point.name, ctx=ast.Load())],
733744
keywords=[]))
734745
else:
735-
yield _node(ast.Name(id=name, ctx=ast.Load()))
746+
yield _node(ast.Name(id=py_fn_name, ctx=ast.Load()))
736747
return
737748

738749

739-
def _multi_arity_fn_ast(ctx: CompilerContext, name: str, arities: List[FunctionArityDetails]) -> ASTStream:
750+
def _multi_arity_fn_ast(ctx: CompilerContext, name: Optional[sym.Symbol],
751+
arities: List[FunctionArityDetails]) -> ASTStream:
740752
"""Generate Python AST nodes for multi-arity Basilisp function definitions.
741753
742754
For example, a multi-arity function like this:
@@ -774,14 +786,19 @@ def __f_68(*multi_arity_args):
774786
775787
776788
f = __f_68"""
789+
py_fn_name = _fn_name(name)
777790
if_stmts: List[Dict[str, ast.AST]] = []
778791
multi_arity_args_arg = _load_attr(_MULTI_ARITY_ARG_NAME)
779792
has_rest = False
780793

781794
for arg_count, is_rest, arity in arities:
782-
with ctx.new_recur_point(name, arity.first):
795+
with ctx.new_recur_point(py_fn_name, arity.first):
796+
# Allow named anonymous functions to recursively call themselves
797+
if name is not None:
798+
ctx.symbol_table.new_symbol(name, py_fn_name, _SYM_CTX_LOCAL)
799+
783800
has_rest = any([has_rest, is_rest])
784-
arity_name = f"{name}__arity{'_rest' if is_rest else arg_count}"
801+
arity_name = f"{py_fn_name}__arity{'_rest' if is_rest else arg_count}"
785802

786803
with ctx.new_symbol_table(arity_name):
787804
# Generate the arity function
@@ -813,7 +830,7 @@ def __f_68(*multi_arity_args):
813830
assert len(if_stmts) == len(arities)
814831

815832
yield _dependency(ast.FunctionDef(
816-
name=name,
833+
name=py_fn_name,
817834
args=ast.arguments(
818835
args=[],
819836
kwarg=None,
@@ -832,14 +849,14 @@ def __f_68(*multi_arity_args):
832849
decorator_list=[],
833850
returns=None))
834851

835-
yield _node(ast.Name(id=name, ctx=ast.Load()))
852+
yield _node(ast.Name(id=py_fn_name, ctx=ast.Load()))
836853

837854

838855
def _fn_ast(ctx: CompilerContext, form: llist.List) -> ASTStream:
839856
"""Generate a Python AST Nodes for function definitions."""
840857
assert form.first == _FN
841858
has_name = isinstance(form[1], sym.Symbol)
842-
name = genname("__" + (munge(form[1].name) if has_name else _FN_PREFIX))
859+
name = form[1] if has_name else None
843860

844861
rest_idx = 1 + int(has_name)
845862
arities = list(_fn_arities(ctx, form[rest_idx:]))

tests/compiler_test.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -521,6 +521,28 @@ def test_disallow_recur_outside_tail(ns_var: Var):
521521
lcompile("(fn [a] (try :b (finally (do (recur :a) :c))))")
522522

523523

524+
def test_named_anonymous_fn_recursion(ns_var: Var):
525+
code = """
526+
(let [compute-sum (fn sum [n]
527+
(if (operator/eq 0 n)
528+
0
529+
(operator/add n (sum (operator/sub n 1)))))]
530+
(compute-sum 5))
531+
"""
532+
assert 15 == lcompile(code)
533+
534+
code = """
535+
(let [compute-sum (fn sum
536+
([] 0)
537+
([n]
538+
(if (operator/eq 0 n)
539+
0
540+
(operator/add n (sum (operator/sub n 1))))))]
541+
(compute-sum 5))
542+
"""
543+
assert 15 == lcompile(code)
544+
545+
524546
def test_syntax_quoting(test_ns: str, ns_var: Var, resolver: reader.Resolver):
525547
code = """
526548
(def some-val \"some value!\")

0 commit comments

Comments
 (0)