Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
190 changes: 127 additions & 63 deletions Tests/AsyncHTTPClientTests/AsyncTestHelpers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,11 @@ final class AsyncSequenceWriter<Element: Sendable>: AsyncSequence, @unchecked Se
case failed(Error, CheckedContinuation<Void, Never>?)
}

private var _state = State.buffering(.init(), nil)
private let lock = NIOLock()
private let state = NIOLockedValueBox<State>(.buffering([], nil))

public var hasDemand: Bool {
self.lock.withLock {
switch self._state {
self.state.withLockedValue { state in
switch state {
case .failed, .finished, .buffering:
return false
case .waiting:
Expand All @@ -59,67 +58,132 @@ final class AsyncSequenceWriter<Element: Sendable>: AsyncSequence, @unchecked Se

/// Wait until a downstream consumer has issued more demand by calling `next`.
public func demand() async {
self.lock.lock()
let shouldBuffer = self.state.withLockedValue { state in
switch state {
case .buffering(_, .none):
return true
case .waiting:
return false
case .buffering(_, .some), .failed(_, .some):
preconditionFailure("Already waiting for demand. Invalid state: \(state)")
case .finished, .failed:
preconditionFailure("Invalid state: \(state)")
}
}

switch self._state {
case .buffering(let buffer, .none):
if shouldBuffer {
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
self._state = .buffering(buffer, continuation)
self.lock.unlock()
let shouldResumeContinuation = self.state.withLockedValue { state in
switch state {
case .buffering(let buffer, .none):
state = .buffering(buffer, continuation)
return false
case .waiting:
return true
case .buffering(_, .some), .failed(_, .some):
preconditionFailure("Already waiting for demand. Invalid state: \(state)")
case .finished, .failed:
preconditionFailure("Invalid state: \(state)")
}
}

if shouldResumeContinuation {
continuation.resume()
}
}

case .waiting:
self.lock.unlock()
return

case .buffering(_, .some), .failed(_, .some):
let state = self._state
self.lock.unlock()
preconditionFailure("Already waiting for demand. Invalid state: \(state)")

case .finished, .failed:
let state = self._state
self.lock.unlock()
preconditionFailure("Invalid state: \(state)")
}
}

private enum NextAction {
/// Resume the continuation if present, and return the result if present.
case resumeAndReturn(CheckedContinuation<Void, Never>?, Result<Element?, Error>?)
/// Suspend the current task and wait for the next value.
case suspend
}

private func next() async throws -> Element? {
self.lock.lock()
switch self._state {
case .buffering(let buffer, let demandContinuation) where buffer.isEmpty:
return try await withCheckedThrowingContinuation { continuation in
self._state = .waiting(continuation)
self.lock.unlock()
demandContinuation?.resume(returning: ())
}
let action: NextAction = self.state.withLockedValue { state in
switch state {
case .buffering(var buffer, let demandContinuation):
if buffer.isEmpty {
return .suspend
} else {
let first = buffer.removeFirst()
if first != nil {
state = .buffering(buffer, demandContinuation)
} else {
state = .finished
}
return .resumeAndReturn(nil, .success(first))
}

case .failed(let error, let demandContinuation):
state = .finished
return .resumeAndReturn(demandContinuation, .failure(error))

case .finished:
return .resumeAndReturn(nil, .success(nil))

case .buffering(var buffer, let demandContinuation):
let first = buffer.removeFirst()
if first != nil {
self._state = .buffering(buffer, demandContinuation)
} else {
self._state = .finished
case .waiting:
preconditionFailure(
"Expected that there is always only one concurrent call to next. Invalid state: \(state)"
)
}
self.lock.unlock()
return first
}

case .failed(let error, let demandContinuation):
self._state = .finished
self.lock.unlock()
switch action {
case .resumeAndReturn(let demandContinuation, let result):
demandContinuation?.resume()
throw error

case .finished:
self.lock.unlock()
return nil

case .waiting:
let state = self._state
self.lock.unlock()
preconditionFailure(
"Expected that there is always only one concurrent call to next. Invalid state: \(state)"
)
return try result?.get()

case .suspend:
// Holding the lock here *should* be safe but because of a bug in the runtime
// it isn't, so drop the lock, create the continuation and then try again.
//
// See https://github.com/swiftlang/swift/issues/85668
return try await withCheckedThrowingContinuation {
(continuation: CheckedContinuation<Element?, any Error>) in
let action: NextAction = self.state.withLockedValue { state in
switch state {
case .buffering(var buffer, let demandContinuation):
if buffer.isEmpty {
state = .waiting(continuation)
return .resumeAndReturn(demandContinuation, nil)
} else {
let first = buffer.removeFirst()
if first != nil {
state = .buffering(buffer, demandContinuation)
} else {
state = .finished
}
return .resumeAndReturn(nil, .success(first))
}

case .failed(let error, let demandContinuation):
state = .finished
return .resumeAndReturn(demandContinuation, .failure(error))

case .finished:
return .resumeAndReturn(nil, .success(nil))

case .waiting:
preconditionFailure(
"Expected that there is always only one concurrent call to next. Invalid state: \(state)"
)
}
}

switch action {
case .resumeAndReturn(let demandContinuation, let result):
demandContinuation?.resume()
// Resume the continuation rather than returning th result.
if let result {
continuation.resume(with: result)
}
case .suspend:
preconditionFailure() // Not returned from the code above.
}
}
}
}

Expand All @@ -137,19 +201,19 @@ final class AsyncSequenceWriter<Element: Sendable>: AsyncSequence, @unchecked Se
}

private func writeBufferOrEnd(_ element: Element?) {
let writeAction = self.lock.withLock { () -> WriteAction in
switch self._state {
let writeAction = self.state.withLockedValue { state -> WriteAction in
switch state {
case .buffering(var buffer, let continuation):
buffer.append(element)
self._state = .buffering(buffer, continuation)
state = .buffering(buffer, continuation)
return .none

case .waiting(let continuation):
self._state = .buffering(.init(), nil)
state = .buffering(.init(), nil)
return .succeedContinuation(continuation, element)

case .finished, .failed:
preconditionFailure("Invalid state: \(self._state)")
preconditionFailure("Invalid state: \(state)")
}
}

Expand All @@ -170,17 +234,17 @@ final class AsyncSequenceWriter<Element: Sendable>: AsyncSequence, @unchecked Se
/// Drops all buffered writes and emits an error on the waiting `next`. If there is no call to `next`
/// waiting, will emit the error on the next call to `next`.
public func fail(_ error: Error) {
let errorAction = self.lock.withLock { () -> ErrorAction in
switch self._state {
let errorAction = self.state.withLockedValue { state -> ErrorAction in
switch state {
case .buffering(_, let demandContinuation):
self._state = .failed(error, demandContinuation)
state = .failed(error, demandContinuation)
return .none

case .failed, .finished:
return .none

case .waiting(let continuation):
self._state = .finished
state = .finished
return .failContinuation(continuation, error)
}
}
Expand Down