Skip to content

Commit 14aaf64

Browse files
committed
refactor: ensure consistency, migrate, sanitise
migrate from aggregating test blocks in pure .rs file testing, to a proper project structure to manage dependancies properly(tokio,...)
1 parent 00be2c4 commit 14aaf64

File tree

4 files changed

+29955
-29909
lines changed

4 files changed

+29955
-29909
lines changed

exts/rust-code-runner/__init__.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,26 +1,49 @@
11
from . import rust_examples_aggregate
22
from . import rustc
33
import os
4+
from pathlib import Path
45

56
def setup(app):
67

7-
app.output_rust_file = "build/rust-code-blocks/generated.rs"
88

9-
# create build dir
10-
if not os.path.exists("build/rust-code-blocks"):
11-
os.makedirs("build/rust-code-blocks")
12-
if os.path.isfile(app.output_rust_file):
13-
with open(app.output_rust_file, 'w'):
14-
pass
9+
# Define output directory
10+
app.output_rust = "build/rust-code-blocks/"
11+
12+
# Ensure the src directory exists
13+
base_dir = Path(app.output_rust)
14+
src_dir = base_dir / "src"
15+
src_dir.mkdir(parents=True, exist_ok=True)
16+
17+
18+
# Write Cargo.toml with required dependencies
19+
cargo_toml = base_dir / "Cargo.toml"
20+
cargo_toml.write_text(
21+
"""[package]
22+
name = "sc_generated_tests"
23+
version = "0.1.0"
24+
edition = "2024"
25+
26+
[dependencies]
27+
tokio = { version = "1", features = ["macros", "rt-multi-thread"] }
28+
""",
29+
encoding="utf-8",
30+
)
31+
32+
33+
print(f"Setup complete in '{base_dir.resolve()}'")
1534

1635
# we hook into 'source-read' because data is mutable at this point and easier to parse
1736
# and it also makes this extension indepandant from `needs`.
18-
#
19-
app.connect('source-read', rust_examples_aggregate.preprocess_rst_for_rust_code)
20-
21-
if app.config.test_rust_blocks:
37+
if not app.config.test_rust_blocks:
38+
# empty lib.rs on every run (incremental build is not supported)
39+
with open(app.output_rust + "src/lib.rs", "w", encoding="utf-8"):
40+
pass
41+
app.connect('source-read', rust_examples_aggregate.preprocess_rst_for_rust_code)
42+
else:
2243
app.connect('build-finished', rustc.check_rust_test_errors)
44+
2345
return {
2446
'version': '0.1',
2547
'parallel_read_safe': False,
48+
'parallel_write_safe': False,
2649
}

exts/rust-code-runner/rust_examples_aggregate.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -64,23 +64,51 @@ def replacer(match):
6464
modified_text = code_block_re.sub(replacer, source_text)
6565
return modified_text
6666

67+
import re
68+
69+
def sanitize_code_blocks(code_blocks):
70+
"""
71+
Removes unwanted attributes from each Rust code block:
72+
- `#[macro_export]` (to avoid exported-macro conflicts)
73+
- `#[tokio::main]` (to keep compilation as a library/test)
74+
"""
75+
patterns = [
76+
r'\s*#\s*\[macro_export\]',
77+
r'\s*#\s*\[tokio::main\]'
78+
]
79+
sanitized = []
80+
for block in code_blocks:
81+
lines = block.splitlines()
82+
cleaned = [
83+
line for line in lines
84+
if not any(re.match(pat, line) for pat in patterns)
85+
]
86+
sanitized.append("\n".join(cleaned))
87+
return sanitized
6788

6889
def preprocess_rst_for_rust_code(app, docname, source):
6990

7091
original_content = source[0]
7192
code_blocks = extract_code_blocks(original_content)
93+
code_blocks = sanitize_code_blocks(code_blocks)
7294
modified_content = remove_hidden_blocks_from_document(original_content)
7395
source[0] = modified_content
7496

7597
# print(f"Original content length: {len(original_content)}")
7698
# print(f"Extracted {len(code_blocks)} code blocks")
7799

78100
safe_docname = docname.replace("/", "_").replace("-", "_")
79-
with open(app.output_rust_file, "a", encoding="utf-8") as f:
80-
for i, block in enumerate(code_blocks, start=1):
81-
f.write(f"// ==== Code Block {i} ====\n")
82-
f.write("#[test]\n")
83-
f.write(f"fn test_block_{safe_docname}_{i}() {{\n")
84-
for line in block.splitlines():
85-
f.write(f" {line}\n")
86-
f.write("}\n\n")
101+
try:
102+
with open(app.output_rust + "src/lib.rs", "a", encoding="utf-8") as f:
103+
for i, block in enumerate(code_blocks, start=1):
104+
f.write(f"// ==== Code Block {i} ====\n")
105+
f.write(f"mod code_block_{i}_{safe_docname} {{\n")
106+
f.write(" #[test]\n")
107+
f.write(f" fn test_block_{safe_docname}_{i}() {{\n")
108+
for line in block.splitlines():
109+
f.write(f" {line}\n") # extra indent for the module
110+
f.write(" }\n") # close fn
111+
f.write("}\n\n") # close mod
112+
except Exception as e:
113+
print("Error writing file:", e)
114+

exts/rust-code-runner/rustc.py

Lines changed: 47 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
import json
2+
import sys
3+
import os
24
import subprocess
35

46
def print_code_snippet(file_path, line_num, context=3):
@@ -31,52 +33,51 @@ def print_code_snippet(file_path, line_num, context=3):
3133
print(f"Could not read file {file_path}: {e}")
3234

3335

34-
def parse_rustc_json(stderr: str, file_path):
35-
"""
36-
Parses rustc's JSON output and prints only the first error with a single snippet.
37-
38-
Args:
39-
stderr (str): JSON-formatted stderr output from rustc.
40-
file_path: Path to the Rust file.
36+
import json
4137

42-
Returns:
43-
None
38+
def parse_cargo_errors(output: str, output_rust):
39+
"""
40+
Parses Cargo’s JSON output and prints only the first compiler error it finds.
41+
Ignores warnings and notes entirely.
4442
"""
45-
for line in stderr.splitlines():
43+
for line in output.splitlines():
4644
line = line.strip()
4745
if not line:
4846
continue
4947

5048
try:
51-
diagnostic = json.loads(line)
49+
rec = json.loads(line)
5250
except json.JSONDecodeError:
5351
continue
5452

55-
if diagnostic.get("$message_type") != "diagnostic":
53+
# Only look at compiler messages
54+
if rec.get("reason") != "compiler-message":
5655
continue
5756

58-
if diagnostic.get("level") != "error":
57+
msg = rec["message"]
58+
# Skip anything that isn't an error
59+
if msg.get("level") != "error":
5960
continue
6061

61-
message = diagnostic.get("message", "")
62-
spans = diagnostic.get("spans", [])
62+
text = msg.get("message", "")
63+
spans = msg.get("spans", [])
6364

64-
# Prefer the primary span in the current file
65+
# Print the high-level error first
66+
print(f"\nerror: {text}")
67+
68+
# Then try to show its primary location
6569
for span in spans:
66-
if span.get("is_primary") and span["file_name"] == file_path:
67-
line_num = span["line_start"]
70+
if span.get("is_primary"):
71+
file = span.get("file_name")
72+
line_start = span.get("line_start")
6873
label = span.get("label", "")
69-
print(f"error: line {line_num}: {message}")
70-
if label:
71-
print(f"--> {label}")
72-
print("=" * 25)
73-
snippet = print_code_snippet(file_path, line_num, context=8)
74-
print(snippet)
75-
print("=" * 25)
76-
return # we return because we only print the first error--in json format there can be multiple error messages(primary and non primary) for 1 error-- if you want to see them comment this line.
77-
78-
# fallback: print the error message if no span in the current file
79-
print(f"error: {message}")
74+
print(f" --> {file}:{line_start} {label}".rstrip(), file= sys.stderr)
75+
# and a snippet
76+
snippet = print_code_snippet(output_rust + file, line_start, context=5)
77+
print("\n" + snippet, file = sys.stderr)
78+
break
79+
80+
# Stop after the first error
8081
return
8182

8283
def check_rust_test_errors(app, exception):
@@ -86,31 +87,31 @@ def check_rust_test_errors(app, exception):
8687
This function is connected to the Sphinx build lifecycle and is executed after the build finishes.
8788
It invokes `rustc` in test mode on the generated Rust file and reports any compilation or test-related
8889
errors.
89-
90-
Args:
91-
app: The Sphinx application object. Must have an `output_rust_file` attribute containing
92-
the path to the generated Rust source file.
93-
exception: Exception raised during the build process, or None if the build completed successfully.
9490
"""
95-
rs_path = app.output_rust_file
91+
rs_path = app.output_rust
92+
cargo_toml_path = os.path.join(rs_path, "Cargo.toml")
9693
# Run the Rust compiler in test mode with JSON error output format.
9794
# capturing stdout and stderr as text.
9895
result = subprocess.run(
99-
["rustc", "--test", "--edition=2024", "--error-format=json", "--emit=metadata", rs_path],
100-
# --emit=metadata or else rustc will produce a binary ./generated
96+
[
97+
"cargo",
98+
"test",
99+
"--message-format=json",
100+
"--manifest-path",
101+
cargo_toml_path
102+
],
101103
capture_output=True,
102-
text=True
104+
text=True,
103105
)
104106

105107
if result.returncode != 0:
106-
print("--- rustc Errors/Warnings ---")
107-
parse_rustc_json(result.stderr, app.output_rust_file)
108+
print("\033[31m--- Cargo test errors ---\033[0m")
109+
parse_cargo_errors(result.stdout, app.output_rust) # parse stdout JSON lines
108110
# print("--- rustc Output ---")
109111
# print(result.stdout)
110-
111112
else:
112-
print("--- rustc Output ---")
113-
print(result.stdout)
114-
if result.stderr:
115-
print("\n\n--- rustc Warnings---")
116-
print(result.stderr)
113+
print("\033[1;32mAll tests succeeded\033[0m") # ANSI magic
114+
# print(result.stdout)
115+
# if result.stderr:
116+
# print("\n\n--- rustc Warnings ---")
117+
# print(result.stderr)

0 commit comments

Comments
 (0)