Skip to content
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 68 additions & 3 deletions Sources/Testing/Support/Locked.swift
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,14 @@ struct Locked<T> {
/// A type providing storage for the underlying lock and wrapped value.
#if SWT_TARGET_OS_APPLE && canImport(os)
private typealias _Storage = ManagedBuffer<T, os_unfair_lock_s>
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
private final class _Storage: ManagedBuffer<T, pthread_mutex_t> {
deinit {
withUnsafeMutablePointerToElements { lock in
_ = pthread_mutex_destroy(lock)
}
}
}
#else
private final class _Storage {
let mutex: Mutex<T>
Expand All @@ -49,6 +57,11 @@ extension Locked: RawRepresentable {
_storage.withUnsafeMutablePointerToElements { lock in
lock.initialize(to: .init())
}
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
_storage = _Storage.create(minimumCapacity: 1, makingHeaderWith: { _ in rawValue }) as! _Storage
_storage.withUnsafeMutablePointerToElements { lock in
_ = pthread_mutex_init(lock, nil)
}
#else
nonisolated(unsafe) let rawValue = rawValue
_storage = _Storage(rawValue)
Expand Down Expand Up @@ -77,20 +90,72 @@ extension Locked {
/// synchronous caller. Wherever possible, use actor isolation or other Swift
/// concurrency tools.
func withLock<R>(_ body: (inout T) throws -> sending R) rethrows -> sending R where R: ~Copyable {
nonisolated(unsafe) let result: R
#if SWT_TARGET_OS_APPLE && canImport(os)
nonisolated(unsafe) let result = try _storage.withUnsafeMutablePointers { rawValue, lock in
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
os_unfair_lock_lock(lock)
defer {
os_unfair_lock_unlock(lock)
}
return try body(&rawValue.pointee)
}
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
pthread_mutex_lock(lock)
defer {
pthread_mutex_unlock(lock)
}
return try body(&rawValue.pointee)
}
#else
result = try _storage.mutex.withLock { rawValue in
return try body(&rawValue)
}
#endif
return result
}

/// Try to acquire the lock and invoke a function while it is held.
///
/// - Parameters:
/// - body: A closure to invoke while the lock is held.
///
/// - Returns: Whatever is returned by `body`, or `nil` if the lock could not
/// be acquired.
///
/// - Throws: Whatever is thrown by `body`.
///
/// This function can be used to synchronize access to shared data from a
/// synchronous caller. Wherever possible, use actor isolation or other Swift
/// concurrency tools.
func withLockIfAvailable<R>(_ body: (inout T) throws -> sending R) rethrows -> sending R? where R: ~Copyable {
nonisolated(unsafe) let result: R?
#if SWT_TARGET_OS_APPLE && canImport(os)
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
guard os_unfair_lock_trylock(lock) else {
return nil
}
defer {
os_unfair_lock_unlock(lock)
}
return try body(&rawValue.pointee)
}
#elseif !SWT_FIXED_85448 && (os(Linux) || os(Android))
result = try _storage.withUnsafeMutablePointers { rawValue, lock in
guard 0 == pthread_mutex_trylock(lock) else {
return nil
}
defer {
pthread_mutex_unlock(lock)
}
return try body(&rawValue.pointee)
}
#else
try _storage.mutex.withLock { rawValue in
try body(&rawValue)
result = try _storage.mutex.withLockIfAvailable { rawValue in
return try body(&rawValue)
}
#endif
return result
}
}

Expand Down
82 changes: 47 additions & 35 deletions Sources/Testing/Test+Cancellation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,8 @@ protocol TestCancellable: Sendable {

// MARK: - Tracking the current task

/// A structure describing a reference to a task that is associated with some
/// ``TestCancellable`` value.
private struct _TaskReference: Sendable {
/// A structure that is able to cancel a task.
private struct _TaskCanceller: Sendable {
/// The unsafe underlying reference to the associated task.
private nonisolated(unsafe) var _unsafeCurrentTask = Locked<UnsafeCurrentTask?>()

Expand All @@ -45,25 +44,46 @@ private struct _TaskReference: Sendable {
_unsafeCurrentTask = withUnsafeCurrentTask { Locked(rawValue: $0) }
}

/// Take this instance's reference to its associated task.
///
/// - Returns: An `UnsafeCurrentTask` instance, or `nil` if it was already
/// taken or if it was never available.
///
/// This function consumes the reference to the task. After the first call,
/// subsequent calls on the same instance return `nil`.
func takeUnsafeCurrentTask() -> UnsafeCurrentTask? {
/// Clear this instance's reference to its associated task without first
/// cancelling it.
func clear() {
_unsafeCurrentTask.withLock { unsafeCurrentTask in
let result = unsafeCurrentTask
unsafeCurrentTask = nil
return result
}
}

/// Cancel this instance's associated task and clear the reference to it.
///
/// - Returns: Whether or not this instance's task was cancelled.
///
/// After the first call to this function _starts_, subsequent calls on the
/// same instance return `false`. In other words, if another thread calls this
/// function before it has returned (or the same thread calls it recursively),
/// it returns `false` without cancelling the task a second time.
func cancel(with skipInfo: SkipInfo) -> Bool {
// trylock means a recursive call to this function won't ruin our day, nor
// should interleaving locks.
_unsafeCurrentTask.withLockIfAvailable { unsafeCurrentTask in
defer {
unsafeCurrentTask = nil
}
if let unsafeCurrentTask {
// The task is still valid, so we'll cancel it.
$_currentSkipInfo.withValue(skipInfo) {
unsafeCurrentTask.cancel()
}
return true
}

// The task has already been cancelled and/or cleared.
return false
} ?? false
}
}

/// A dictionary of tracked tasks, keyed by types that conform to
/// A dictionary of cancellable tasks keyed by types that conform to
/// ``TestCancellable``.
@TaskLocal private var _currentTaskReferences = [ObjectIdentifier: _TaskReference]()
@TaskLocal private var _currentTaskCancellers = [ObjectIdentifier: _TaskCanceller]()

/// The instance of ``SkipInfo`` to propagate to children of the current task.
///
Expand All @@ -87,16 +107,15 @@ extension TestCancellable {
/// the current task, test, or test case is cancelled, it records a
/// corresponding cancellation event.
func withCancellationHandling<R>(_ body: () async throws -> R) async rethrows -> R {
let taskReference = _TaskReference()
var currentTaskReferences = _currentTaskReferences
currentTaskReferences[ObjectIdentifier(Self.self)] = taskReference
return try await $_currentTaskReferences.withValue(currentTaskReferences) {
// Before returning, explicitly clear the stored task. This minimizes
// the potential race condition that can occur if test code creates an
// unstructured task and calls `Test.cancel()` in it after the test body
// has finished.
let taskCanceller = _TaskCanceller()
var currentTaskCancellers = _currentTaskCancellers
currentTaskCancellers[ObjectIdentifier(Self.self)] = taskCanceller
return try await $_currentTaskCancellers.withValue(currentTaskCancellers) {
// Before returning, explicitly clear the stored task so that an
// unstructured task that inherits the task local isn't able to
// accidentally cancel the task after it has been deallocated.
defer {
_ = taskReference.takeUnsafeCurrentTask()
taskCanceller.clear()
}

return try await withTaskCancellationHandler {
Expand All @@ -121,17 +140,10 @@ extension TestCancellable {
/// - testAndTestCase: The test and test case to use when posting an event.
/// - skipInfo: Information about the cancellation event.
private func _cancel<T>(_ cancellableValue: T?, for testAndTestCase: (Test?, Test.Case?), skipInfo: SkipInfo) where T: TestCancellable {
if cancellableValue != nil {
// If the current test case is still running, take its task property (which
// signals to subsequent callers that it has been cancelled.)
let task = _currentTaskReferences[ObjectIdentifier(T.self)]?.takeUnsafeCurrentTask()

// If we just cancelled the current test case's task, post a corresponding
// event with the relevant skip info.
if let task {
$_currentSkipInfo.withValue(skipInfo) {
task.cancel()
}
if cancellableValue != nil, let taskCanceller = _currentTaskCancellers[ObjectIdentifier(T.self)] {
// Try to cancel the task associated with `T`, if any. If we succeed, post a
// corresponding event with the relevant skip info.
if taskCanceller.cancel(with: skipInfo) {
Event.post(T.makeCancelledEventKind(with: skipInfo), for: testAndTestCase)
}
} else {
Expand Down