Skip to content

Commit 0f65564

Browse files
authored
Refactor parsing of function CodeInput using ast (#86)
With `ast` we have to use less regex tricks and can rely that the function is properly parsed. Docstring is now also considering linebreaks at the beginning and end. Annotations of input arguments are now supported. Default arguments are now supported. We even support expressions as default arguments e.g lambda functions. Arbitrary keyword arguments are now supported.
1 parent 934a46a commit 0f65564

File tree

2 files changed

+123
-17
lines changed

2 files changed

+123
-17
lines changed

src/scwidgets/code/_widget_code_input.py

Lines changed: 82 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
1+
import ast
12
import inspect
23
import re
34
import sys
5+
import textwrap
46
import traceback
57
import types
68
import warnings
79
from functools import wraps
8-
from typing import List, Optional
10+
from typing import List, Optional, Tuple
911

1012
from widget_code_input import WidgetCodeInput
1113
from widget_code_input.utils import (
@@ -20,6 +22,18 @@
2022
class CodeInput(WidgetCodeInput):
2123
"""
2224
Small wrapper around WidgetCodeInput that controls the output
25+
26+
:param function: We can automatically parse the function. Note that during
27+
parsing the source code might be differently formatted and certain
28+
python functionalities are not formatted. If you notice undesired
29+
changes by the parsing, please directly specify the function as string
30+
using the other parameters.
31+
:param function_name: The name of the function
32+
:param function_paramaters: The parameters as continuous string as specified in
33+
the signature of the function. e.g for `foo(x, y = 5)` it should be
34+
`"x, y = 5"`
35+
:param docstring: The docstring of the function
36+
:param function_body: The function definition without indentation
2337
"""
2438

2539
valid_code_themes = ["nord", "solarizedLight", "basicLight"]
@@ -38,13 +52,15 @@ def __init__(
3852
function.__name__ if function_name is None else function_name
3953
)
4054
function_parameters = (
41-
", ".join(inspect.getfullargspec(function).args)
55+
self.get_function_parameters(function)
4256
if function_parameters is None
4357
else function_parameters
4458
)
45-
docstring = inspect.getdoc(function) if docstring is None else docstring
59+
docstring = self.get_docstring(function) if docstring is None else docstring
4660
function_body = (
47-
self.get_code(function) if function_body is None else function_body
61+
self.get_function_body(function)
62+
if function_body is None
63+
else function_body
4864
)
4965

5066
# default parameters from WidgetCodeInput
@@ -105,8 +121,68 @@ def function_parameters_name(self) -> List[str]:
105121
return self.function_parameters.replace(",", "").split(" ")
106122

107123
@staticmethod
108-
def get_code(func: types.FunctionType) -> str:
109-
source_lines, _ = inspect.getsourcelines(func)
124+
def get_docstring(function: types.FunctionType) -> str:
125+
docstring = function.__doc__
126+
return "" if docstring is None else textwrap.dedent(docstring)
127+
128+
@staticmethod
129+
def _get_function_source_and_def(
130+
function: types.FunctionType,
131+
) -> Tuple[str, ast.FunctionDef]:
132+
function_source = inspect.getsource(function)
133+
function_source = textwrap.dedent(function_source)
134+
module = ast.parse(function_source)
135+
if len(module.body) != 1:
136+
raise ValueError(
137+
f"Expected code with one function definition but found {module.body}"
138+
)
139+
function_definition = module.body[0]
140+
if not isinstance(function_definition, ast.FunctionDef):
141+
raise ValueError(
142+
f"While parsing code found {module.body[0]}"
143+
" but only ast.FunctionDef is supported."
144+
)
145+
return function_source, function_definition
146+
147+
@staticmethod
148+
def get_function_parameters(function: types.FunctionType) -> str:
149+
function_parameters = []
150+
function_source, function_definition = CodeInput._get_function_source_and_def(
151+
function
152+
)
153+
idx_start_defaults = len(function_definition.args.args) - len(
154+
function_definition.args.defaults
155+
)
156+
for i, arg in enumerate(function_definition.args.args):
157+
function_parameter = ast.get_source_segment(function_source, arg)
158+
# Following PEP 8 in formatting
159+
if arg.annotation:
160+
annotation = function_parameter = ast.get_source_segment(
161+
function_source, arg.annotation
162+
)
163+
function_parameter = f"{arg.arg}: {annotation}"
164+
else:
165+
function_parameter = f"{arg.arg}"
166+
if i >= idx_start_defaults:
167+
default_val = ast.get_source_segment(
168+
function_source,
169+
function_definition.args.defaults[i - idx_start_defaults],
170+
)
171+
# Following PEP 8 in formatting
172+
if arg.annotation:
173+
function_parameter = f"{function_parameter} = {default_val}"
174+
else:
175+
function_parameter = f"{function_parameter}={default_val}"
176+
function_parameters.append(function_parameter)
177+
178+
if function_definition.args.kwarg is not None:
179+
function_parameters.append(f"**{function_definition.args.kwarg.arg}")
180+
181+
return ", ".join(function_parameters)
182+
183+
@staticmethod
184+
def get_function_body(function: types.FunctionType) -> str:
185+
source_lines, _ = inspect.getsourcelines(function)
110186

111187
found_def = False
112188
def_index = 0

tests/test_code.py

Lines changed: 41 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,15 @@ def mock_function_0():
2323
return 0
2424

2525
@staticmethod
26-
def mock_function_1(x, y):
26+
def mock_function_1(x: int, y: int = 5, z=lambda: 0):
2727
"""
2828
This is an example function.
2929
It adds two numbers.
3030
"""
3131
if x > 0:
3232
return x + y
3333
else:
34-
return y
34+
return y + z()
3535

3636
@staticmethod
3737
def mock_function_2(x):
@@ -53,26 +53,56 @@ def x():
5353
@staticmethod
5454
def mock_function_6(x: List[int]) -> List[int]:
5555
return x
56+
57+
@staticmethod
58+
def mock_function_7(x, **kwargs):
59+
return kwargs
5660
# fmt: on
5761

58-
def test_get_code(self):
62+
def test_get_function_paramaters(self):
63+
assert (
64+
CodeInput.get_function_parameters(self.mock_function_1)
65+
== "x: int, y: int = 5, z=lambda: 0"
66+
)
67+
assert CodeInput.get_function_parameters(self.mock_function_2) == "x"
68+
assert CodeInput.get_function_parameters(self.mock_function_6) == "x: List[int]"
69+
assert CodeInput.get_function_parameters(self.mock_function_7) == "x, **kwargs"
70+
71+
def test_get_docstring(self):
72+
assert (
73+
CodeInput.get_docstring(self.mock_function_1)
74+
== "\nThis is an example function.\nIt adds two numbers.\n"
75+
)
76+
assert (
77+
CodeInput.get_docstring(self.mock_function_2)
78+
== "This is an example function. It adds two numbers."
79+
)
80+
assert (
81+
CodeInput.get_docstring(self.mock_function_2)
82+
== "This is an example function. It adds two numbers."
83+
)
84+
85+
def test_get_function_body(self):
86+
assert (
87+
CodeInput.get_function_body(self.mock_function_1)
88+
== "if x > 0:\n return x + y\nelse:\n return y + z()\n"
89+
)
90+
assert CodeInput.get_function_body(self.mock_function_2) == "return x\n"
91+
assert CodeInput.get_function_body(self.mock_function_3) == "return x\n"
5992
assert (
60-
CodeInput.get_code(self.mock_function_1)
61-
== "if x > 0:\n return x + y\nelse:\n return y\n"
93+
CodeInput.get_function_body(self.mock_function_4)
94+
== "return x # noqa: E702\n"
6295
)
63-
assert CodeInput.get_code(self.mock_function_2) == "return x\n"
64-
assert CodeInput.get_code(self.mock_function_3) == "return x\n"
65-
assert CodeInput.get_code(self.mock_function_4) == "return x # noqa: E702\n"
6696
assert (
67-
CodeInput.get_code(self.mock_function_5)
97+
CodeInput.get_function_body(self.mock_function_5)
6898
== "def x():\n return 5\nreturn x()\n"
6999
)
70-
assert CodeInput.get_code(self.mock_function_6) == "return x\n"
100+
assert CodeInput.get_function_body(self.mock_function_6) == "return x\n"
71101
with pytest.raises(
72102
ValueError,
73103
match=r"Did not find any def definition. .*",
74104
):
75-
CodeInput.get_code(lambda x: x)
105+
CodeInput.get_function_body(lambda x: x)
76106

77107
def test_invalid_code_theme_raises_error(self):
78108
with pytest.raises(

0 commit comments

Comments
 (0)