Skip to content

Commit 86e823d

Browse files
committed
Add inout arg support
1 parent d19a8b7 commit 86e823d

File tree

8 files changed

+136
-26
lines changed

8 files changed

+136
-26
lines changed

Sources/Testing/Expectations/ExpectationContext.swift

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -198,6 +198,20 @@ extension __ExpectationContext {
198198
return value
199199
}
200200
#endif
201+
202+
/// Capture information about a value passed `inout` to a function call after
203+
/// the function has returned.
204+
///
205+
/// - Parameters:
206+
/// - value: The value that was passed `inout` (i.e. with the `&` operator.)
207+
/// - id: A value that uniquely identifies the represented expression in the
208+
/// context of the expectation currently being evaluated.
209+
///
210+
/// - Warning: This function is used to implement the `#expect()` and
211+
/// `#require()` macros. Do not call it directly.
212+
public mutating func __inoutAfter<T>(_ value: T, _ id: __ExpressionID) {
213+
runtimeValues[id] = { Expression.Value(reflecting: value, timing: .after) }
214+
}
201215
}
202216

203217
// MARK: - Collection comparison and diffing

Sources/Testing/SourceAttribution/Expression.swift

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,21 @@ public struct __Expression: Sendable {
6969
/// Information about the type of this value.
7070
public var typeInfo: TypeInfo
7171

72+
/// The timing when a runtime value was captured.
73+
@_spi(Experimental)
74+
public enum Timing: String, Sendable {
75+
/// The value was captured after the containing expression was evaluated.
76+
case after
77+
}
78+
79+
/// When the value represented by this instance was captured.
80+
///
81+
/// The value of this property is typically `nil`. It may be set to a
82+
/// non-`nil` value if this instance represents some `inout` argument passed
83+
/// to a function with the `&` operator.
84+
@_spi(Experimental)
85+
public var timing: Timing?
86+
7287
/// The label associated with this value, if any.
7388
///
7489
/// For non-child instances, or for child instances of members who do not
@@ -92,9 +107,10 @@ public struct __Expression: Sendable {
92107
///
93108
/// - Parameters:
94109
/// - subject: The subject this instance should describe.
95-
init(reflecting subject: Any) {
110+
/// - timing: When the value represented by this instance was captured.
111+
init(reflecting subject: Any, timing: Timing? = nil) {
96112
var seenObjects: [ObjectIdentifier: AnyObject] = [:]
97-
self.init(_reflecting: subject, label: nil, seenObjects: &seenObjects)
113+
self.init(_reflecting: subject, label: nil, timing: timing, seenObjects: &seenObjects)
98114
}
99115

100116
/// Initialize an instance of this type describing the specified subject and
@@ -105,13 +121,15 @@ public struct __Expression: Sendable {
105121
/// - label: An optional label for this value. This should be a non-`nil`
106122
/// value when creating instances of this type which describe
107123
/// substructural values.
124+
/// - timing: When the value represented by this instance was captured.
108125
/// - seenObjects: The objects which have been seen so far while calling
109126
/// this initializer recursively, keyed by their object identifiers.
110127
/// This is used to halt further recursion if a previously-seen object
111128
/// is encountered again.
112129
private init(
113130
_reflecting subject: Any,
114131
label: String?,
132+
timing: Timing?,
115133
seenObjects: inout [ObjectIdentifier: AnyObject]
116134
) {
117135
let mirror = Mirror(reflecting: subject)
@@ -160,6 +178,7 @@ public struct __Expression: Sendable {
160178
debugDescription = String(reflecting: subject)
161179
typeInfo = TypeInfo(describingTypeOf: subject)
162180
self.label = label
181+
self.timing = timing
163182

164183
isCollection = switch mirror.displayStyle {
165184
case .some(.collection),
@@ -172,7 +191,7 @@ public struct __Expression: Sendable {
172191

173192
if shouldIncludeChildren && (!mirror.children.isEmpty || isCollection) {
174193
self.children = mirror.children.map { child in
175-
Self(_reflecting: child.value, label: child.label, seenObjects: &seenObjects)
194+
Self(_reflecting: child.value, label: child.label, timing: timing, seenObjects: &seenObjects)
176195
}
177196
}
178197
}
@@ -227,12 +246,19 @@ public struct __Expression: Sendable {
227246
let runtimeValueDescription = String(describingForTest: runtimeValue)
228247
// Hack: don't print string representations of function calls.
229248
if runtimeValueDescription != "(Function)" && runtimeValueDescription != result {
249+
switch runtimeValue.timing {
250+
case .after:
251+
result = "\(result) (after)"
252+
default:
253+
break
254+
}
230255
result = "\(result)\(runtimeValueDescription)"
231256
}
232257
} else {
233258
result = "\(result) → <not evaluated>"
234259
}
235260

261+
236262
return result
237263
}
238264

@@ -264,6 +290,7 @@ public struct __Expression: Sendable {
264290
extension __Expression: Codable {}
265291
extension __Expression.Kind: Codable {}
266292
extension __Expression.Value: Codable {}
293+
extension __Expression.Value.Timing: Codable {}
267294

268295
// MARK: - CustomStringConvertible, CustomDebugStringConvertible
269296

Sources/TestingMacros/ConditionMacro.swift

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -169,7 +169,7 @@ extension ConditionMacro {
169169
let uniqueName = context.makeUniqueName("")
170170
expressionContextName = .identifier("\(expressionContextName)\(uniqueName)")
171171
}
172-
let (rewrittenArgumentExpr, rewrittenNodes) = insertCalls(
172+
let (rewrittenArgumentExpr, rewrittenNodes, prefixCodeBlockItems) = insertCalls(
173173
toExpressionContextNamed: expressionContextName,
174174
into: originalArgumentExpr,
175175
for: macro,
@@ -195,11 +195,20 @@ extension ConditionMacro {
195195
argumentExpr = closureArguments.rewrittenNode.cast(ExprSyntax.self)
196196
}
197197

198+
// If we're inserting any additional code into the closure before the
199+
// rewritten argument, we can't elide the return keyword for brevity.
200+
var returnKeyword: TokenSyntax?
201+
if !prefixCodeBlockItems.isEmpty {
202+
returnKeyword = .keyword(.return)
203+
.with(\.leadingTrivia, argumentExpr.leadingTrivia)
204+
argumentExpr.leadingTrivia = .space
205+
}
206+
198207
// Enclose the expression in a closure into which we pass our local
199208
// context object.
200209
argumentExpr = """
201210
{ \(closureArguments?.captureList) (\(expressionContextName): inout Testing.__ExpectationContext) in
202-
\(argumentExpr)
211+
\(prefixCodeBlockItems)\(returnKeyword)\(argumentExpr)
203212
}
204213
"""
205214

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

Lines changed: 49 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,10 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
111111
/// The nodes in this array are the _original_ nodes, not the rewritten nodes.
112112
var rewrittenNodes = Set<Syntax>()
113113

114+
/// Any postflight code the caller should insert into the closure containing
115+
/// the rewritten syntax tree.
116+
var teardownItems = [CodeBlockItemSyntax]()
117+
114118
init(in context: C, for macro: M, rootedAt effectiveRootNode: Syntax, expressionContextName: TokenSyntax) {
115119
self.context = context
116120
self.macro = macro
@@ -123,12 +127,13 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
123127
/// (or rather, its `callAsFunction(_:_:)` member).
124128
///
125129
/// - Parameters:
126-
/// - functionNameExpr: If not `nil`, the name of the function to call (as a
127-
/// member of the expression context.)
128130
/// - node: The node to rewrite.
129131
/// - originalNode: The original node in the original syntax tree, if `node`
130132
/// has already been partially rewritten or substituted. If `node` has not
131133
/// been rewritten, this argument should equal it.
134+
/// - functionName: If not `nil`, the name of the function to call (as a
135+
/// member function of the expression context.)
136+
/// - additionalArguments: Any additional arguments to pass to the function.
132137
///
133138
/// - Returns: A rewritten copy of `node` that calls into the expression
134139
/// context when it is evaluated at runtime.
@@ -379,8 +384,24 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
379384
}
380385

381386
override func visit(_ node: InOutExprSyntax) -> ExprSyntax {
382-
// inout arguments cannot be forwarded through functions. In the future, we
383-
// could experiment with unsafe mutable pointers?
387+
// Swift's Law of Exclusivity means that only one subexpression in the
388+
// expectation ought to be interacting with `value` when it is passed
389+
// `inout`, so it should be sufficient to capture it in a `defer` statement
390+
// that runs after the expression is evaluated.
391+
392+
let teardownItem = CodeBlockItemSyntax(
393+
item: .expr(
394+
_rewrite(
395+
node.expression,
396+
originalWas: node,
397+
calling: .identifier("__inoutAfter")
398+
)
399+
)
400+
)
401+
teardownItems.append(teardownItem)
402+
403+
// The argument should not be expanded in-place as we can't return an
404+
// argument passed `inout` and expect it to remain semantically correct.
384405
return ExprSyntax(node)
385406
}
386407

@@ -525,23 +546,42 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
525546
/// the purposes of generating expression ID values.
526547
/// - context: The macro context in which the expression is being parsed.
527548
///
528-
/// - Returns: A tuple containing the rewritten copy of `node` as well as a list
529-
/// of all the nodes within `node` (possibly including `node` itself) that
530-
/// were rewritten.
549+
/// - Returns: A tuple containing the rewritten copy of `node`, a list of all
550+
/// the nodes within `node` (possibly including `node` itself) that were
551+
/// rewritten, and a code block containing code that should be inserted into
552+
/// the lexical scope of `node` _before_ its rewritten equivalent.
531553
func insertCalls(
532554
toExpressionContextNamed expressionContextName: TokenSyntax,
533555
into node: some SyntaxProtocol,
534556
for macro: some FreestandingMacroExpansionSyntax,
535557
rootedAt effectiveRootNode: some SyntaxProtocol,
536558
in context: some MacroExpansionContext
537-
) -> (Syntax, rewrittenNodes: Set<Syntax>) {
559+
) -> (Syntax, rewrittenNodes: Set<Syntax>, prefixCodeBlockItems: CodeBlockItemListSyntax) {
538560
if let node = node.as(ExprSyntax.self) {
539561
_diagnoseTrivialBooleanValue(from: node, for: macro, in: context)
540562
}
541563

542564
let contextInserter = _ContextInserter(in: context, for: macro, rootedAt: Syntax(effectiveRootNode), expressionContextName: expressionContextName)
543565
let result = contextInserter.rewrite(node)
544-
return (result, contextInserter.rewrittenNodes)
566+
let rewrittenNodes = contextInserter.rewrittenNodes
567+
568+
let prefixCodeBlockItems = CodeBlockItemListSyntax {
569+
if !contextInserter.teardownItems.isEmpty {
570+
CodeBlockItemSyntax(
571+
item: .stmt(
572+
StmtSyntax(
573+
DeferStmtSyntax {
574+
for teardownItem in contextInserter.teardownItems {
575+
teardownItem
576+
}
577+
}
578+
)
579+
)
580+
)
581+
}
582+
}.formatted().with(\.trailingTrivia, .newline).cast(CodeBlockItemListSyntax.self)
583+
584+
return (result, rewrittenNodes, prefixCodeBlockItems)
545585
}
546586

547587
// MARK: - Finding optional chains

Tests/SubexpressionShowcase/SubexpressionShowcase.swift

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,11 @@ func g() throws -> Int {
2020
22
2121
}
2222

23+
func io(_ x: inout Int) -> Int {
24+
x += 1
25+
return x + 1
26+
}
27+
2328
struct T {
2429
func h(_ i: Int) -> Bool { false }
2530
static func j(_ d: Double) -> Bool { false }
@@ -39,6 +44,12 @@ func subexpressionShowcase() async throws {
3944
#expect((123, 456) == (789, 0x12))
4045
#expect((try g() > 500) && true)
4146

47+
do {
48+
let n = Int.random(in: 0 ..< 100)
49+
var m = n
50+
#expect(io(&m) == n)
51+
}
52+
4253
let closure: (Int) -> Void = {
4354
#expect($0 == 0x10)
4455
}

0 commit comments

Comments
 (0)