Skip to content

Commit ab9713f

Browse files
committed
Implement depth-limited rewriting
1 parent 70862c9 commit ab9713f

File tree

3 files changed

+92
-31
lines changed

3 files changed

+92
-31
lines changed

Sources/TestingMacros/Support/AttributeDiscovery.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ struct AttributeInfo {
126126
var expr = argument.expression
127127
if argument.expression.tokens(viewMode: .sourceAccurate).map(\.tokenKind).contains(.keyword(.Self)) {
128128
let selfRemover = _SelfRemover(in: context)
129-
expr = selfRemover.rewrite(Syntax(argument.expression)).cast(ExprSyntax.self)
129+
expr = selfRemover.rewrite(Syntax(argument.expression), detach: true).cast(ExprSyntax.self)
130130
}
131131
return Argument(label: argument.label, expression: expr)
132132
}

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

Lines changed: 90 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,12 @@ func removeParentheses(from expr: ExprSyntax) -> ExprSyntax? {
9191

9292
// MARK: - Inserting expression context callouts
9393

94+
/// The maximum value of `_rewriteDepth` allowed by `_rewrite()` before it will
95+
/// start bailing early.
96+
private let _maximumRewriteDepth = {
97+
Int.max // disable rewrite-limiting (need to evaluate possible heuristics)
98+
}()
99+
94100
/// A type that inserts calls to an `__ExpectationContext` instance into an
95101
/// expression's syntax tree.
96102
private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansionContext, M: FreestandingMacroExpansionSyntax {
@@ -123,6 +129,12 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
123129
super.init()
124130
}
125131

132+
/// The number of calls to `_rewrite()` made along the current node hierarchy.
133+
///
134+
/// This value is incremented with each call to `_rewrite()` and managed by
135+
/// `_visitChild()`.
136+
private var _rewriteDepth = 0
137+
126138
/// Rewrite a given syntax node by inserting a call to the expression context
127139
/// (or rather, its `callAsFunction(_:_:)` member).
128140
///
@@ -137,14 +149,27 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
137149
///
138150
/// - Returns: A rewritten copy of `node` that calls into the expression
139151
/// context when it is evaluated at runtime.
140-
private func _rewrite(_ node: some ExprSyntaxProtocol, originalWas originalNode: some ExprSyntaxProtocol, calling functionName: TokenSyntax? = nil, passing additionalArguments: [Argument] = []) -> ExprSyntax {
141-
guard rewrittenNodes.insert(Syntax(originalNode)).inserted else {
142-
// If this node has already been rewritten, we don't need to rewrite it
143-
// again. (Currently, this can only happen when expanding binary operators
144-
// which need a bit of extra help.)
145-
return ExprSyntax(node)
152+
private func _rewrite(
153+
_ node: @autoclosure () -> some ExprSyntaxProtocol,
154+
originalWas originalNode: @autoclosure () -> some ExprSyntaxProtocol,
155+
calling functionName: @autoclosure () -> TokenSyntax? = nil,
156+
passing additionalArguments: @autoclosure () -> [Argument] = []
157+
) -> ExprSyntax {
158+
_rewriteDepth += 1
159+
if _rewriteDepth > _maximumRewriteDepth {
160+
// At least 2 ancestors of this node have already been rewritten, so do
161+
// not recursively rewrite further. This is necessary to limit the added
162+
// exponentional complexity we're throwing at the type checker.
163+
return ExprSyntax(originalNode())
146164
}
147165

166+
// We're going to rewrite the node, so we'll evaluate the arguments now.
167+
let node = node()
168+
let originalNode = originalNode()
169+
let functionName = functionName()
170+
let additionalArguments = additionalArguments()
171+
rewrittenNodes.insert(Syntax(originalNode))
172+
148173
let calledExpr: ExprSyntax = if let functionName {
149174
ExprSyntax(MemberAccessExprSyntax(base: expressionContextNameExpr, name: functionName))
150175
} else {
@@ -200,6 +225,43 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
200225
_rewrite(node, originalWas: node, calling: functionName, passing: additionalArguments)
201226
}
202227

228+
/// Visit `node` as a child of another previously-visited node.
229+
///
230+
/// - Parameters:
231+
/// - node: The node to visit.
232+
///
233+
/// - Returns: `node`, or a modified copy thereof if `node` or a child node
234+
/// was rewritten.
235+
///
236+
/// Use this function instead of calling `visit(_:)` or `rewrite(_:detach:)`
237+
/// recursively.
238+
///
239+
/// This overload simply visits `node` and is used for nodes that cannot be
240+
/// rewritten directly (because they are not expressions.)
241+
@_disfavoredOverload
242+
private func _visitChild<S>(_ node: S) -> S where S: SyntaxProtocol {
243+
rewrite(node, detach: true).cast(S.self)
244+
}
245+
246+
/// Visit `node` as a child of another previously-visited node.
247+
///
248+
/// - Parameters:
249+
/// - node: The node to visit.
250+
///
251+
/// - Returns: `node`, or a modified copy thereof if `node` or a child node
252+
/// was rewritten.
253+
///
254+
/// Use this function instead of calling `visit(_:)` or `rewrite(_:detach:)`
255+
/// recursively.
256+
private func _visitChild(_ node: some ExprSyntaxProtocol) -> ExprSyntax {
257+
let oldRewriteDepth = _rewriteDepth
258+
defer {
259+
_rewriteDepth = oldRewriteDepth
260+
}
261+
262+
return rewrite(node, detach: true).cast(ExprSyntax.self)
263+
}
264+
203265
/// Whether or not the parent node of the given node is capable of containing
204266
/// a rewritten `DeclReferenceExprSyntax` instance.
205267
///
@@ -281,7 +343,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
281343
return _rewrite(
282344
TupleExprSyntax {
283345
for element in node.elements {
284-
visit(element).trimmed
346+
_visitChild(element).trimmed
285347
}
286348
},
287349
originalWas: node
@@ -302,28 +364,28 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
302364
// expressions can be directly extracted out.
303365
if _isParentOfDeclReferenceExprValidForRewriting(node) {
304366
return _rewrite(
305-
node.with(\.base, node.base.map(visit)),
367+
node.with(\.base, node.base.map(_visitChild)),
306368
originalWas: node
307369
)
308370
}
309371

310-
return ExprSyntax(node.with(\.base, node.base.map(visit)))
372+
return ExprSyntax(node.with(\.base, node.base.map(_visitChild)))
311373
}
312374

313375
override func visit(_ node: FunctionCallExprSyntax) -> ExprSyntax {
314376
_rewrite(
315377
node
316-
.with(\.calledExpression, visit(node.calledExpression))
317-
.with(\.arguments, visit(node.arguments)),
378+
.with(\.calledExpression, _visitChild(node.calledExpression))
379+
.with(\.arguments, _visitChild(node.arguments)),
318380
originalWas: node
319381
)
320382
}
321383

322384
override func visit(_ node: SubscriptCallExprSyntax) -> ExprSyntax {
323385
_rewrite(
324386
node
325-
.with(\.calledExpression, visit(node.calledExpression))
326-
.with(\.arguments, visit(node.arguments)),
387+
.with(\.calledExpression, _visitChild(node.calledExpression))
388+
.with(\.arguments, _visitChild(node.arguments)),
327389
originalWas: node
328390
)
329391
}
@@ -355,7 +417,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
355417

356418
return _rewrite(
357419
node
358-
.with(\.expression, visit(node.expression)),
420+
.with(\.expression, _visitChild(node.expression)),
359421
originalWas: node
360422
)
361423
}
@@ -377,18 +439,18 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
377439
originalWas: node,
378440
calling: .identifier("__cmp"),
379441
passing: [
380-
Argument(expression: visit(node.leftOperand)),
442+
Argument(expression: _visitChild(node.leftOperand)),
381443
Argument(expression: node.leftOperand.expressionID(rootedAt: effectiveRootNode)),
382-
Argument(expression: visit(node.rightOperand)),
444+
Argument(expression: _visitChild(node.rightOperand)),
383445
Argument(expression: node.rightOperand.expressionID(rootedAt: effectiveRootNode))
384446
]
385447
)
386448
}
387449

388450
return _rewrite(
389451
node
390-
.with(\.leftOperand, visit(node.leftOperand))
391-
.with(\.rightOperand, visit(node.rightOperand)),
452+
.with(\.leftOperand, _visitChild(node.leftOperand))
453+
.with(\.rightOperand, _visitChild(node.rightOperand)),
392454
originalWas: node
393455
)
394456
}
@@ -399,12 +461,11 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
399461
// `inout`, so it should be sufficient to capture it in a `defer` statement
400462
// that runs after the expression is evaluated.
401463

402-
let teardownItem = CodeBlockItemSyntax(
403-
item: .expr(
404-
_rewrite(node.expression, calling: .identifier("__inoutAfter"))
405-
)
406-
)
407-
teardownItems.append(teardownItem)
464+
let rewrittenExpr = _rewrite(node.expression, calling: .identifier("__inoutAfter"))
465+
if rewrittenExpr != ExprSyntax(node.expression) {
466+
let teardownItem = CodeBlockItemSyntax(item: .expr(rewrittenExpr))
467+
teardownItems.append(teardownItem)
468+
}
408469

409470
// The argument should not be expanded in-place as we can't return an
410471
// argument passed `inout` and expect it to remain semantically correct.
@@ -427,7 +488,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
427488
rewrittenNodes.insert(Syntax(type))
428489

429490
return _rewrite(
430-
visit(valueExpr).trimmed,
491+
_visitChild(valueExpr).trimmed,
431492
originalWas: originalNode,
432493
calling: .identifier("__\(isAsKeyword)"),
433494
passing: [
@@ -503,7 +564,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
503564
node.with(
504565
\.elements, ArrayElementListSyntax {
505566
for element in node.elements {
506-
ArrayElementSyntax(expression: visit(element.expression).trimmed)
567+
ArrayElementSyntax(expression: _visitChild(element.expression).trimmed)
507568
}
508569
}
509570
),
@@ -520,7 +581,7 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
520581
\.content, .elements(
521582
DictionaryElementListSyntax {
522583
for element in elements {
523-
DictionaryElementSyntax(key: visit(element.key).trimmed, value: visit(element.value).trimmed)
584+
DictionaryElementSyntax(key: _visitChild(element.key).trimmed, value: _visitChild(element.value).trimmed)
524585
}
525586
}
526587
)
@@ -570,7 +631,7 @@ extension ConditionMacro {
570631
_diagnoseTrivialBooleanValue(from: ExprSyntax(node), for: macro, in: context)
571632

572633
let contextInserter = _ContextInserter(in: context, for: macro, rootedAt: Syntax(effectiveRootNode), expressionContextName: expressionContextName)
573-
var expandedExpr = contextInserter.rewrite(node).cast(ExprSyntax.self)
634+
var expandedExpr = contextInserter.rewrite(node, detach: true).cast(ExprSyntax.self)
574635
let rewrittenNodes = contextInserter.rewrittenNodes
575636

576637
// Insert additional effect keywords/thunks as needed.
@@ -606,7 +667,7 @@ extension ConditionMacro {
606667
var captureList: ClosureCaptureClauseSyntax?
607668
do {
608669
let dollarIDReplacer = _DollarIdentifierReplacer()
609-
codeBlockItems = dollarIDReplacer.rewrite(codeBlockItems).cast(CodeBlockItemListSyntax.self)
670+
codeBlockItems = dollarIDReplacer.rewrite(codeBlockItems, detach: true).cast(CodeBlockItemListSyntax.self)
610671
if !dollarIDReplacer.dollarIdentifierTokenKinds.isEmpty {
611672
let dollarIdentifierTokens = dollarIDReplacer.dollarIdentifierTokenKinds.map { tokenKind in
612673
TokenSyntax(tokenKind, presence: .present)

Tests/TestingTests/Support/CartesianProductTests.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ struct CartesianProductTests {
3333
#expect(product.underestimatedCount == c1.underestimatedCount * c2.underestimatedCount)
3434
let productCount = Array(product).count
3535
#expect(productCount == c1.count * c2.count)
36-
#expect(productCount == (26 * 100) as Int)
36+
#expect(productCount == 26 * 100)
3737
}
3838

3939
@Test("First element is correct")

0 commit comments

Comments
 (0)