Skip to content

Commit 73b833f

Browse files
committed
Tweaks to how we generate the capturing function calls
1 parent b7d6de4 commit 73b833f

File tree

6 files changed

+109
-147
lines changed

6 files changed

+109
-147
lines changed

Sources/Testing/Expectations/ExpectationContext.swift

Lines changed: 35 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -269,12 +269,12 @@ extension __ExpectationContext {
269269
/// - Warning: This function is used to implement the `#expect()` and
270270
/// `#require()` macros. Do not call it directly.
271271
public mutating func __cmp<T, U, R>(
272+
_ op: (T, U) throws -> R,
273+
_ opID: __ExpressionID,
272274
_ lhs: T,
273275
_ lhsID: __ExpressionID,
274276
_ rhs: U,
275-
_ rhsID: __ExpressionID,
276-
_ op: (T, U) throws -> R,
277-
_ opID: __ExpressionID
277+
_ rhsID: __ExpressionID
278278
) rethrows -> R {
279279
try self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
280280
}
@@ -287,12 +287,12 @@ extension __ExpectationContext {
287287
/// - Warning: This function is used to implement the `#expect()` and
288288
/// `#require()` macros. Do not call it directly.
289289
public mutating func __cmp<C>(
290+
_ op: (C, C) -> Bool,
291+
_ opID: __ExpressionID,
290292
_ lhs: C,
291293
_ lhsID: __ExpressionID,
292294
_ rhs: C,
293-
_ rhsID: __ExpressionID,
294-
_ op: (C, C) -> Bool,
295-
_ opID: __ExpressionID
295+
_ rhsID: __ExpressionID
296296
) -> Bool where C: BidirectionalCollection, C.Element: Equatable {
297297
let result = self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
298298

@@ -309,17 +309,17 @@ extension __ExpectationContext {
309309
///
310310
/// This overload of `__cmp()` does _not_ perform a diffing operation on `lhs`
311311
/// and `rhs`. Range expressions are not usefully diffable the way other kinds
312-
/// of collections are. ([139222774](rdar://139222774))
312+
/// of collections are. ([#639](https://github.com/swiftlang/swift-testing/issues/639))
313313
///
314314
/// - Warning: This function is used to implement the `#expect()` and
315315
/// `#require()` macros. Do not call it directly.
316316
public mutating func __cmp<R>(
317+
_ op: (R, R) -> Bool,
318+
_ opID: __ExpressionID,
317319
_ lhs: R,
318320
_ lhsID: __ExpressionID,
319321
_ rhs: R,
320-
_ rhsID: __ExpressionID,
321-
_ op: (R, R) -> Bool,
322-
_ opID: __ExpressionID
322+
_ rhsID: __ExpressionID
323323
) -> Bool where R: RangeExpression & BidirectionalCollection, R.Element: Equatable {
324324
self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
325325
}
@@ -333,12 +333,12 @@ extension __ExpectationContext {
333333
/// - Warning: This function is used to implement the `#expect()` and
334334
/// `#require()` macros. Do not call it directly.
335335
public mutating func __cmp<S>(
336+
_ op: (S, S) -> Bool,
337+
_ opID: __ExpressionID,
336338
_ lhs: S,
337339
_ lhsID: __ExpressionID,
338340
_ rhs: S,
339-
_ rhsID: __ExpressionID,
340-
_ op: (S, S) -> Bool,
341-
_ opID: __ExpressionID
341+
_ rhsID: __ExpressionID
342342
) -> Bool where S: StringProtocol {
343343
let result = self(op(self(lhs, lhsID), self(rhs, rhsID)), opID)
344344

@@ -376,9 +376,11 @@ extension __ExpectationContext {
376376
///
377377
/// - Parameters:
378378
/// - value: The value to cast.
379+
/// - valueID: A value that uniquely identifies the expression represented
380+
/// by `value` in the context of the expectation being evaluated.
379381
/// - type: The type to cast `value` to.
380-
/// - typeID: The ID chain of the `type` expression as emitted during
381-
/// expansion of the `#expect()` or `#require()` macro.
382+
/// - valueID: A value that uniquely identifies the expression represented
383+
/// by `type` in the context of the expectation being evaluated.
382384
///
383385
/// - Returns: The result of the expression `value as? type`.
384386
///
@@ -388,12 +390,12 @@ extension __ExpectationContext {
388390
///
389391
/// - Warning: This function is used to implement the `#expect()` and
390392
/// `#require()` macros. Do not call it directly.
391-
public mutating func __as<T, U>(_ value: T, _ type: U.Type, _ typeID: __ExpressionID) -> U? {
392-
let result = value as? U
393+
public mutating func __as<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> U? {
394+
let result = self(value, valueID) as? U
393395

394396
if result == nil {
395397
let correctType = Swift.type(of: value as Any)
396-
runtimeValues[typeID] = { Expression.Value(reflecting: correctType) }
398+
_ = self(correctType, typeID)
397399
}
398400

399401
return result
@@ -403,9 +405,11 @@ extension __ExpectationContext {
403405
///
404406
/// - Parameters:
405407
/// - value: The value to cast.
408+
/// - valueID: A value that uniquely identifies the expression represented
409+
/// by `value` in the context of the expectation being evaluated.
406410
/// - type: The type `value` is expected to be.
407-
/// - typeID: The ID chain of the `type` expression as emitted during
408-
/// expansion of the `#expect()` or `#require()` macro.
411+
/// - valueID: A value that uniquely identifies the expression represented
412+
/// by `type` in the context of the expectation being evaluated.
409413
///
410414
/// - Returns: The result of the expression `value as? type`.
411415
///
@@ -415,12 +419,12 @@ extension __ExpectationContext {
415419
///
416420
/// - Warning: This function is used to implement the `#expect()` and
417421
/// `#require()` macros. Do not call it directly.
418-
public mutating func __is<T, U>(_ value: T, _ type: U.Type, _ typeID: __ExpressionID) -> Bool {
419-
let result = value is U
422+
public mutating func __is<T, U>(_ value: T, _ valueID: __ExpressionID, _ type: U.Type, _ typeID: __ExpressionID) -> Bool {
423+
let result = self(value, valueID) is U
420424

421425
if !result {
422426
let correctType = Swift.type(of: value as Any)
423-
runtimeValues[typeID] = { Expression.Value(reflecting: correctType) }
427+
_ = self(correctType, typeID)
424428
}
425429

426430
return result
@@ -445,46 +449,9 @@ extension __ExpectationContext {
445449
///
446450
/// - Warning: This function is used to implement the `#expect()` and
447451
/// `#require()` macros. Do not call it directly.
448-
public mutating func callAsFunction<P, T>(_ value: P, _ id: __ExpressionID) -> UnsafePointer<T> where P: _Pointer {
449-
self(value as P?, id)!
450-
}
451-
452-
/// Convert some pointer to an immutable one and capture information about it
453-
/// for use if the expectation currently being evaluated fails.
454-
///
455-
/// - Parameters:
456-
/// - value: The pointer to make immutable.
457-
/// - id: A value that uniquely identifies the represented expression in the
458-
/// context of the expectation currently being evaluated.
459-
///
460-
/// - Returns: `value`, cast to an immutable pointer.
461-
///
462-
/// This overload of `callAsFunction(_:_:)` handles the implicit conversions
463-
/// between various pointer types that are normally provided by the compiler.
464-
///
465-
/// - Warning: This function is used to implement the `#expect()` and
466-
/// `#require()` macros. Do not call it directly.
467-
public mutating func callAsFunction<P, T>(_ value: P?, _ id: __ExpressionID) -> UnsafePointer<T>? where P: _Pointer {
468-
UnsafePointer(bitPattern: Int(bitPattern: self(value, id) as P?))
469-
}
470-
471-
/// Convert some pointer to an immutable one and capture information about it
472-
/// for use if the expectation currently being evaluated fails.
473-
///
474-
/// - Parameters:
475-
/// - value: The pointer to make immutable.
476-
/// - id: A value that uniquely identifies the represented expression in the
477-
/// context of the expectation currently being evaluated.
478-
///
479-
/// - Returns: `value`, cast to an immutable pointer.
480-
///
481-
/// This overload of `callAsFunction(_:_:)` handles the implicit conversions
482-
/// between various pointer types that are normally provided by the compiler.
483-
///
484-
/// - Warning: This function is used to implement the `#expect()` and
485-
/// `#require()` macros. Do not call it directly.
486-
public mutating func callAsFunction<P>(_ value: P, _ id: __ExpressionID) -> UnsafeRawPointer where P: _Pointer {
487-
self(value as P?, id)!
452+
@_disfavoredOverload
453+
public mutating func callAsFunction<PFrom, PTo>(_ value: PFrom, _ id: __ExpressionID) -> PTo where PFrom: _Pointer, PTo: _Pointer {
454+
self(value as PFrom?, id) as! PTo
488455
}
489456

490457
/// Convert some pointer to an immutable one and capture information about it
@@ -502,8 +469,11 @@ extension __ExpectationContext {
502469
///
503470
/// - Warning: This function is used to implement the `#expect()` and
504471
/// `#require()` macros. Do not call it directly.
505-
public mutating func callAsFunction<P>(_ value: P?, _ id: __ExpressionID) -> UnsafeRawPointer? where P: _Pointer {
506-
UnsafeRawPointer(bitPattern: Int(bitPattern: self(value, id) as P?))
472+
@_disfavoredOverload
473+
public mutating func callAsFunction<PFrom, PTo>(_ value: PFrom?, _ id: __ExpressionID) -> PTo? where PFrom: _Pointer, PTo: _Pointer {
474+
value.flatMap { value in
475+
PTo(bitPattern: Int(bitPattern: self(value, id) as PFrom))
476+
}
507477
}
508478
}
509479

Sources/TestingMacros/Support/ConditionArgumentParsing.swift

Lines changed: 59 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -123,26 +123,35 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
123123
/// (or rather, its `callAsFunction(_:_:)` member).
124124
///
125125
/// - Parameters:
126+
/// - functionNameExpr: If not `nil`, the name of the function to call (as a
127+
/// member of the expression context.)
126128
/// - node: The node to rewrite.
127129
/// - originalNode: The original node in the original syntax tree, if `node`
128130
/// has already been partially rewritten or substituted. If `node` has not
129131
/// been rewritten, this argument should equal it.
130132
///
131133
/// - Returns: A rewritten copy of `node` that calls into the expression
132134
/// context when it is evaluated at runtime.
133-
private func _rewrite<E>(_ node: E, originalWas originalNode: some SyntaxProtocol) -> ExprSyntax where E: ExprSyntaxProtocol {
134-
if rewrittenNodes.contains(Syntax(originalNode)) {
135+
private func _rewrite<E>(_ node: E, originalWas originalNode: some SyntaxProtocol, calling functionName: TokenSyntax? = nil, passing additionalArguments: [Argument] = []) -> ExprSyntax where E: ExprSyntaxProtocol {
136+
guard rewrittenNodes.insert(Syntax(originalNode)).inserted else {
135137
// If this node has already been rewritten, we don't need to rewrite it
136138
// again. (Currently, this can only happen when expanding binary operators
137139
// which need a bit of extra help.)
138140
return ExprSyntax(node)
139141
}
140142

141-
rewrittenNodes.insert(Syntax(originalNode))
143+
let calledExpr: ExprSyntax = if let functionName {
144+
ExprSyntax(MemberAccessExprSyntax(base: expressionContextNameExpr, name: functionName))
145+
} else {
146+
ExprSyntax(expressionContextNameExpr)
147+
}
142148

143-
var result = FunctionCallExprSyntax(calledExpression: expressionContextNameExpr) {
149+
var result = FunctionCallExprSyntax(calledExpression: calledExpr) {
144150
LabeledExprSyntax(expression: node.trimmed)
145151
LabeledExprSyntax(expression: originalNode.expressionID(rootedAt: effectiveRootNode))
152+
for argument in additionalArguments {
153+
LabeledExprSyntax(argument)
154+
}
146155
}
147156

148157
result.leftParen = .leftParenToken()
@@ -331,41 +340,27 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
331340
if let op = node.operator.as(BinaryOperatorExprSyntax.self)?.operator.textWithoutBackticks,
332341
op == "==" || op == "!=" || op == "===" || op == "!==" {
333342

334-
rewrittenNodes.insert(Syntax(node))
335-
rewrittenNodes.insert(Syntax(node.leftOperand))
336-
rewrittenNodes.insert(Syntax(node.rightOperand))
337-
338-
var result = FunctionCallExprSyntax(
339-
calledExpression: MemberAccessExprSyntax(
340-
base: expressionContextNameExpr,
341-
name: .identifier("__cmp")
342-
)
343-
) {
344-
LabeledExprSyntax(expression: visit(node.leftOperand))
345-
LabeledExprSyntax(expression: node.leftOperand.expressionID(rootedAt: effectiveRootNode))
346-
LabeledExprSyntax(expression: visit(node.rightOperand))
347-
LabeledExprSyntax(expression: node.rightOperand.expressionID(rootedAt: effectiveRootNode))
348-
LabeledExprSyntax(
349-
expression: ClosureExprSyntax {
350-
InfixOperatorExprSyntax(
351-
leftOperand: DeclReferenceExprSyntax(
352-
baseName: .dollarIdentifier("$0")
353-
).with(\.trailingTrivia, .space),
354-
operator: BinaryOperatorExprSyntax(text: op),
355-
rightOperand: DeclReferenceExprSyntax(
356-
baseName: .dollarIdentifier("$1")
357-
).with(\.leadingTrivia, .space)
358-
)
359-
}
360-
)
361-
LabeledExprSyntax(expression: node.expressionID(rootedAt: effectiveRootNode))
362-
}
363-
result.leftParen = .leftParenToken()
364-
result.rightParen = .rightParenToken()
365-
result.leadingTrivia = node.leadingTrivia
366-
result.trailingTrivia = node.trailingTrivia
367-
368-
return ExprSyntax(result)
343+
return _rewrite(
344+
ClosureExprSyntax {
345+
InfixOperatorExprSyntax(
346+
leftOperand: DeclReferenceExprSyntax(
347+
baseName: .dollarIdentifier("$0")
348+
).with(\.trailingTrivia, .space),
349+
operator: BinaryOperatorExprSyntax(text: op),
350+
rightOperand: DeclReferenceExprSyntax(
351+
baseName: .dollarIdentifier("$1")
352+
).with(\.leadingTrivia, .space)
353+
)
354+
},
355+
originalWas: node,
356+
calling: .identifier("__cmp"),
357+
passing: [
358+
Argument(expression: visit(node.leftOperand)),
359+
Argument(expression: node.leftOperand.expressionID(rootedAt: effectiveRootNode)),
360+
Argument(expression: visit(node.rightOperand)),
361+
Argument(expression: node.rightOperand.expressionID(rootedAt: effectiveRootNode))
362+
]
363+
)
369364
}
370365

371366
return _rewrite(
@@ -384,48 +379,41 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
384379

385380
// MARK: - Casts
386381

387-
/// Create a function call that represents an `is` or `as?` cast.
382+
/// Rewrite an `is` or `as?` cast.
388383
///
389384
/// - Parameters:
390385
/// - valueExpr: The expression to cast.
391386
/// - isAsKeyword: The casting keyword (either `.is` or `.as`).
392387
/// - type: The type to cast `valueExpr` to.
388+
/// - originalNode: The original `IsExprSyntax` or `AsExprSyntax` node in
389+
/// the original syntax tree.
393390
///
394391
/// - Returns: A function call expression equivalent to the described cast.
395-
private func _makeCastCall(_ valueExpr: ExprSyntax, _ isAsKeyword: Keyword, _ type: TypeSyntax) -> FunctionCallExprSyntax {
396-
var result = FunctionCallExprSyntax(
397-
calledExpression: MemberAccessExprSyntax(
398-
base: expressionContextNameExpr,
399-
name: .identifier("__\(isAsKeyword)")
400-
)
401-
) {
402-
LabeledExprSyntax(expression: visit(valueExpr).trimmed)
403-
LabeledExprSyntax(
404-
expression: _rewrite(
405-
MemberAccessExprSyntax(
392+
private func _rewriteAsCast(_ valueExpr: ExprSyntax, _ isAsKeyword: Keyword, _ type: TypeSyntax, originalWas originalNode: some SyntaxProtocol) -> ExprSyntax {
393+
rewrittenNodes.insert(Syntax(type))
394+
395+
return _rewrite(
396+
visit(valueExpr).trimmed,
397+
originalWas: originalNode,
398+
calling: .identifier("__\(isAsKeyword)"),
399+
passing: [
400+
Argument(
401+
expression: MemberAccessExprSyntax(
406402
base: TupleExprSyntax {
407403
LabeledExprSyntax(expression: TypeExprSyntax(type: type.trimmed))
408404
},
409405
declName: DeclReferenceExprSyntax(baseName: .keyword(.self))
410-
),
411-
originalWas: type
412-
)
413-
)
414-
LabeledExprSyntax(expression: type.expressionID(rootedAt: effectiveRootNode))
415-
}
416-
result.leftParen = .leftParenToken()
417-
result.rightParen = .rightParenToken()
418-
419-
return result
406+
)
407+
),
408+
Argument(expression: type.expressionID(rootedAt: effectiveRootNode))
409+
]
410+
)
420411
}
421412

422413
override func visit(_ node: AsExprSyntax) -> ExprSyntax {
423414
switch node.questionOrExclamationMark?.tokenKind {
424415
case .postfixQuestionMark:
425-
return _rewrite(
426-
_makeCastCall(node.expression, .as, node.type),
427-
originalWas: node
428-
)
416+
return _rewriteAsCast(node.expression, .as, node.type, originalWas: node)
429417

430418
case .exclamationMark where !node.type.isNamed("Bool", inModuleNamed: "Swift") && !node.type.isOptional:
431419
// Warn that as! will be evaluated before #expect() or #require(), which is
@@ -436,21 +424,19 @@ private final class _ContextInserter<C, M>: SyntaxRewriter where C: MacroExpansi
436424
context.diagnose(.asExclamationMarkIsEvaluatedEarly(node, in: macro))
437425
return _rewrite(node)
438426

439-
default:
427+
case .exclamationMark:
440428
// Only diagnose for `x as! T`. `x as T` is perfectly fine if it otherwise
441429
// compiles. For example, `#require(x as Int?)` should compile.
442-
//
443-
// If the token after "as" is something else entirely and got through the
444-
// type checker, just leave it alone as we don't recognize it.
445430
return _rewrite(node)
431+
432+
default:
433+
// This is an "escape hatch" cast. Do not attempt to process the cast.
434+
return ExprSyntax(node)
446435
}
447436
}
448437

449438
override func visit(_ node: IsExprSyntax) -> ExprSyntax {
450-
_rewrite(
451-
_makeCastCall(node.expression, .is, node.type),
452-
originalWas: node
453-
)
439+
_rewriteAsCast(node.expression, .is, node.type, originalWas: node)
454440
}
455441

456442
// MARK: - Literals

0 commit comments

Comments
 (0)