Skip to content

Commit 3002d23

Browse files
committed
Add advanced source transformations to reduce type checking overhead
The new 'munge' module performs transformations on the source code. It uses the AST (abstract syntax tree) representation of Python code to recognize some idioms such as `if STATIC_TYPING:` and transforms them into alternatives that have zero overhead in mpy-compiled files (e.g., `if STATIC_TYPING:` is transformed into `if 0:`, which is eliminated at compile time due to mpy-cross constant-propagation and dead branch elimination) The code assumes the input file is black-formatted. In particular, it would malfunction if an if-statement and its body are on the same line: `if STATIC_TYPING: print("boo")` would be incorrectly munged.
1 parent 1f3c5bd commit 3002d23

File tree

8 files changed

+228
-30
lines changed

8 files changed

+228
-30
lines changed

.github/workflows/build.yml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@ jobs:
4848
git clone --recurse-submodules https://github.com/adafruit/CircuitPython_Community_Bundle.git
4949
cd CircuitPython_Community_Bundle
5050
circuitpython-build-bundles --filename_prefix test-bundle --library_location libraries --library_depth 2
51+
- name: Munge tests
52+
run: pytest
5153
- name: Build Python package
5254
run: |
5355
pip install --upgrade setuptools wheel twine readme_renderer testresources

.gitignore

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,4 @@ version.py
1010
.env/*
1111
.DS_Store
1212
.idea/*
13+
testcases/*.out

circuitpython_build_tools/build.py

Lines changed: 18 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,8 @@
3636
import subprocess
3737
import tempfile
3838

39+
from .munge import munge
40+
3941
# pyproject.toml `py_modules` values that are incorrect. These should all have PRs filed!
4042
# and should be removed when the fixed version is incorporated in its respective bundle.
4143

@@ -170,16 +172,6 @@ def mpy_cross(mpy_cross_filename, circuitpython_tag, quiet=False):
170172

171173
shutil.copy("build_deps/circuitpython/mpy-cross/mpy-cross", mpy_cross_filename)
172174

173-
def _munge_to_temp(original_path, temp_file, library_version):
174-
with open(original_path, "r", encoding="utf-8") as original_file:
175-
for line in original_file:
176-
line = line.strip("\n")
177-
if line.startswith("__version__"):
178-
line = line.replace("0.0.0-auto.0", library_version)
179-
line = line.replace("0.0.0+auto.0", library_version)
180-
print(line, file=temp_file)
181-
temp_file.flush()
182-
183175
def get_package_info(library_path, package_folder_prefix):
184176
lib_path = pathlib.Path(library_path)
185177
parent_idx = len(lib_path.parts)
@@ -289,25 +281,22 @@ def library(library_path, output_directory, package_folder_prefix,
289281
full_path = os.path.join(library_path, filename)
290282
output_file = output_directory / filename.relative_to(library_path)
291283
if filename.suffix == ".py":
292-
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file:
293-
temp_file_name = temp_file.name
294-
try:
295-
_munge_to_temp(full_path, temp_file, library_version)
296-
temp_file.close()
297-
if mpy_cross and os.stat(temp_file.name).st_size != 0:
298-
output_file = output_file.with_suffix(".mpy")
299-
mpy_success = subprocess.call([
300-
mpy_cross,
301-
"-o", output_file,
302-
"-s", str(filename.relative_to(library_path)),
303-
temp_file.name
304-
])
305-
if mpy_success != 0:
306-
raise RuntimeError("mpy-cross failed on", full_path)
307-
else:
308-
shutil.copyfile(temp_file_name, output_file)
309-
finally:
310-
os.remove(temp_file_name)
284+
content = munge(full_path, library_version)
285+
if mpy_cross and content:
286+
# TODO: Once 8.x bundles are no longer built, switch to
287+
# sending mpy-cross the code on stdin instead of via
288+
# temporary file (supports the "-" input argument)
289+
with tempfile.NamedTemporaryFile(delete=False, mode="w+") as temp_file:
290+
temp_file.write(content)
291+
temp_file.flush()
292+
subprocess.check_output([
293+
mpy_cross,
294+
"-o", output_file.with_suffix(".mpy"),
295+
"-s", str(filename.relative_to(library_path)),
296+
temp_file.name
297+
], input=content.encode('utf-8'))
298+
else:
299+
output_file.write_text(content, encoding="utf-8")
311300
else:
312301
shutil.copyfile(full_path, output_file)
313302

circuitpython_build_tools/munge.py

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
# The MIT License (MIT)
2+
#
3+
# Copyright (c) 2024 Jeff Epler for Adafruit Industries
4+
#
5+
# Permission is hereby granted, free of charge, to any person obtaining a copy
6+
# of this software and associated documentation files (the "Software"), to deal
7+
# in the Software without restriction, including without limitation the rights
8+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
# copies of the Software, and to permit persons to whom the Software is
10+
# furnished to do so, subject to the following conditions:
11+
#
12+
# The above copyright notice and this permission notice shall be included in
13+
# all copies or substantial portions of the Software.
14+
#
15+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
21+
# THE SOFTWARE.
22+
23+
# Filter program removes some code patterns introduced by type checking,
24+
# to move towards zero overhead static typing in circuitpython libraries
25+
#
26+
# Recognized:
27+
# from __future__ import ... -- eliminated
28+
# try: import typing -- eliminated, but first except: preserved
29+
# try: from typing import ... -- eliminated, but first except: preserved
30+
# if STATIC_TYPING: -- transformed to 'if 0:'
31+
# if sys.implementation_name... -- transformed to unconditional if
32+
# __version__ = ... -- set to library version string
33+
#
34+
# mpy-cross does constant propagation and dead branch elimination of
35+
# 'if 0:' and 'if 1:'
36+
#
37+
# Depends on the file being black-formatted!
38+
39+
import pathlib
40+
import sys
41+
import ast
42+
43+
VERBOSE = 0
44+
45+
# The canonical spelling of this test...
46+
sys_implementation_is_circuitpython = ast.unparse(ast.parse('sys.implementation.name == "circuitpython"'))
47+
sys_implementation_not_circuitpython = ast.unparse(ast.parse('not sys.implementation.name == "circuitpython"'))
48+
sys_implementation_not_circuitpython2 = ast.unparse(ast.parse('sys.implementation.name != "circuitpython"'))
49+
50+
def munge(src: pathlib.Path|str, version_str: str) -> str:
51+
path = pathlib.Path(src)
52+
replacements = {}
53+
54+
def replace(line, new):
55+
if VERBOSE:
56+
replacements[line] = f"{new:<40s} ### {lines[line]}"
57+
else:
58+
replacements[line] = new
59+
60+
def blank_range(node):
61+
for i in range(node.lineno, node.end_lineno+1):
62+
replace(i, "")
63+
64+
def unblank_range(node):
65+
for i in range(node.lineno, node.end_lineno+1):
66+
replacements.pop(i, None)
67+
68+
def imports_from_typing(node):
69+
if isinstance(node, ast.Import) and node.names[0].name == 'typing':
70+
return True
71+
if isinstance(node, ast.ImportFrom) and node.module == 'typing':
72+
return True
73+
return False
74+
75+
def process_statement(node):
76+
# filter out 'from future import...'
77+
if isinstance(node, ast.ImportFrom):
78+
if node.module == '__future__':
79+
blank_range(node)
80+
# filter out 'try: import typing...'
81+
# but preserve the first 'except:' or 'except ImportError'
82+
elif isinstance(node, ast.Try):
83+
b = node.body[0]
84+
if imports_from_typing(node.body[0]):
85+
blank_range(node)
86+
for h in node.handlers:
87+
if h.type is None or ast.unparse(h.type) == 'ImportError' or ast.unparse(h.type) == 'Exception':
88+
unblank_range(h)
89+
replace(h.lineno, 'if 1:')
90+
break
91+
return
92+
elif isinstance(node, ast.If):
93+
node_test = ast.unparse(node.test)
94+
# return the statements in the 'if' branch of 'if sys.implementation...: ...'
95+
if node_test == sys_implementation_is_circuitpython:
96+
replace(node.lineno, 'if 1:')
97+
# return the statements in the 'else' branch of 'if sys.implementation...: ...'
98+
elif node_test == sys_implementation_not_circuitpython or node_test == sys_implementation_not_circuitpython2:
99+
replace(node.lineno, 'if 0:')
100+
# return the statements in the else branch of 'if TYPE_CHECKING: ...'
101+
elif node_test == 'TYPE_CHECKING':
102+
replace(node.lineno, 'if 0:')
103+
elif isinstance(node, ast.Assign) and isinstance(node.targets[0], ast.Name) and node.targets[0].id == '__version__':
104+
replace(node.lineno, f"__version__ = \"{version_str}\"")
105+
106+
content = pathlib.Path(path).read_text(encoding="utf-8")
107+
# Insert a blank line 0 because ast line numbers are 1-based
108+
lines = [''] + content.rstrip().split('\n')
109+
a = ast.parse(content, path.name)
110+
111+
for node in a.body: process_statement(node)
112+
113+
result = []
114+
for i in range(1, len(lines)):
115+
result.append(replacements.get(i, lines[i]))
116+
117+
return "\n".join(result) + "\n"

requirements.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
Click
2+
pytest
23
requests
34
semver
4-
wheel
55
tomli; python_version < "3.11"
6+
wheel

testcases/test1.exp

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
2+
3+
4+
5+
if 1:
6+
pass
7+
8+
9+
10+
if 1:
11+
pass
12+
13+
14+
15+
16+
if 1:
17+
pass
18+
19+
20+
21+
if 1:
22+
pass
23+
24+
__version__ = "1.2.3"
25+
26+
if 1:
27+
print("is circuitpython")
28+
29+
if 0:
30+
print("not circuitpython (1)")
31+
32+
if 0:
33+
print("not circuitpython (2)")

testcases/test1.py

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
from __future__ import annotation
2+
3+
try:
4+
from typing import TYPE_CHECKING
5+
except ImportError:
6+
pass
7+
8+
try:
9+
from typing import TYPE_CHECKING as T
10+
except ImportError:
11+
pass
12+
13+
14+
try:
15+
import typing
16+
except:
17+
pass
18+
19+
try:
20+
import typing as T
21+
except:
22+
pass
23+
24+
__version__ = "0.0.0-auto"
25+
26+
if sys.implementation.name == "circuitpython":
27+
print("is circuitpython")
28+
29+
if sys.implementation.name != "circuitpython":
30+
print("not circuitpython (1)")
31+
32+
if not sys.implementation.name == "circuitpython":
33+
print("not circuitpython (2)")

tests/test_munge.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
import sys, pathlib
2+
import pytest
3+
4+
top = pathlib.Path(__file__).parent.parent
5+
sys.path.insert(0, str(top))
6+
7+
from circuitpython_build_tools.munge import munge
8+
9+
@pytest.mark.parametrize("test_path", top.glob("testcases/*.py"))
10+
def test_munge(test_path):
11+
result_path = test_path.with_suffix(".out")
12+
result_path.unlink(missing_ok = True)
13+
14+
result_content = munge(test_path, "1.2.3")
15+
result_path.write_text(result_content, encoding="utf-8")
16+
17+
expected_path = test_path.with_suffix(".exp")
18+
expected_content = expected_path.read_text(encoding="utf-8")
19+
20+
assert result_content == expected_content
21+
22+
result_path.unlink()

0 commit comments

Comments
 (0)