Skip to content

Commit 5a600fa

Browse files
lxbndrandriydruk
authored andcommitted
Replace DispatchSource in URLSession on Windows with custom event listener (swiftlang#4791)
1 parent c69a19e commit 5a600fa

File tree

2 files changed

+228
-1
lines changed

2 files changed

+228
-1
lines changed

Sources/FoundationNetworking/URLSession/libcurl/MultiHandle.swift

+174-1
Original file line numberDiff line numberDiff line change
@@ -516,7 +516,7 @@ fileprivate extension URLSession._MultiHandle._Timeout {
516516
}
517517
}
518518

519-
519+
#if !os(Windows)
520520
/// Read and write libdispatch sources for a specific socket.
521521
///
522522
/// A simple helper that combines two sources -- both being optional.
@@ -605,6 +605,179 @@ extension _SocketSources {
605605
}
606606
}
607607
}
608+
609+
#else
610+
611+
private let threadpoolWaitCallback: PTP_WAIT_CALLBACK = { (inst, context, pwa, res) in
612+
guard let sources = _SocketSources.from(socketSourcePtr: context) else {
613+
fatalError("Context is not set in socket callback")
614+
}
615+
616+
sources.socketCallback()
617+
}
618+
619+
private class _SocketSources {
620+
struct SocketEvents: OptionSet {
621+
let rawValue: CLong
622+
623+
static let read = SocketEvents(rawValue: CLong(FD_READ))
624+
static let write = SocketEvents(rawValue: CLong(FD_WRITE))
625+
}
626+
627+
private var socket: SOCKET = INVALID_SOCKET
628+
private var queue: DispatchQueue?
629+
private var handler: DispatchWorkItem?
630+
631+
// Only the handlerCallout and callback properties are
632+
// accessed concurrently (from queue thread and ThreadpoolWait thread).
633+
// While callback property should not be raced due to specific
634+
// disarm logic, it is still guarded with lock for safety.
635+
private var handlerCallout: DispatchWorkItem?
636+
private var callback: (event: HANDLE, threadpoolWait: PTP_WAIT)?
637+
private let lock = NSLock()
638+
639+
private var networkEvents: CLong = 0
640+
private var events: SocketEvents = [] {
641+
didSet {
642+
guard oldValue != events else {
643+
return
644+
}
645+
triggerIO()
646+
}
647+
}
648+
649+
func triggerIO() {
650+
// Decide which network events we're interested in,
651+
// initialize callback lazily.
652+
let (networkEvents, event) = { () -> (CLong, HANDLE?) in
653+
guard !events.isEmpty else {
654+
return (0, nil)
655+
}
656+
let event = {
657+
if let callback = callback {
658+
return callback.event
659+
}
660+
guard let event = CreateEventW(nil, /* bManualReset */ false, /* bInitialState */ false, nil) else {
661+
fatalError("CreateEventW \(GetLastError())")
662+
}
663+
guard let threadpoolWait = CreateThreadpoolWait(threadpoolWaitCallback, Unmanaged.passUnretained(self).toOpaque(), /* PTP_CALLBACK_ENVIRON */ nil) else {
664+
fatalError("CreateThreadpoolWait \(GetLastError())")
665+
}
666+
SetThreadpoolWait(threadpoolWait, event, /* pftTimeout */ nil)
667+
callback = (event, threadpoolWait)
668+
return event
669+
}()
670+
return (CLong(FD_CLOSE) | events.rawValue, event)
671+
}()
672+
673+
if self.networkEvents != networkEvents {
674+
guard WSAEventSelect(socket, event, networkEvents) == 0 else {
675+
fatalError("WSAEventSelect \(WSAGetLastError())")
676+
}
677+
self.networkEvents = networkEvents
678+
}
679+
680+
if events.contains(.write) {
681+
// FD_WRITE will only be signaled if the socket becomes writable after
682+
// a send() fails with WSAEWOULDBLOCK. If shis zero-byte send() doesn't fail,
683+
// we could immediately schedule the handler callout.
684+
if send(socket, "", 0, 0) == 0 {
685+
performHandler()
686+
}
687+
} else if events.isEmpty, let callback = callback {
688+
SetThreadpoolWait(callback.threadpoolWait, nil, nil)
689+
WaitForThreadpoolWaitCallbacks(callback.threadpoolWait, /* fCancelPendingCallbacks */ true)
690+
CloseThreadpoolWait(callback.threadpoolWait)
691+
CloseHandle(callback.event)
692+
693+
lock.lock()
694+
self.callback = nil
695+
handlerCallout?.cancel()
696+
handlerCallout = nil
697+
lock.unlock()
698+
699+
handler = nil
700+
}
701+
}
702+
703+
func createSources(with action: URLSession._MultiHandle._SocketRegisterAction, socket: CFURLSession_socket_t, queue: DispatchQueue, handler: DispatchWorkItem) {
704+
precondition(self.socket == INVALID_SOCKET || self.socket == socket, "Socket value changed")
705+
precondition(self.queue == nil || self.queue === queue, "Queue changed")
706+
707+
self.socket = socket
708+
self.queue = queue
709+
self.handler = handler
710+
711+
events = action.socketEvents
712+
}
713+
714+
func tearDown() {
715+
events = []
716+
}
717+
718+
func socketCallback() {
719+
// Note: this called on ThreadpoolWait thread.
720+
lock.lock()
721+
if let callback = callback {
722+
ResetEvent(callback.event)
723+
SetThreadpoolWait(callback.threadpoolWait, callback.event, /* pftTimeout */ nil)
724+
}
725+
lock.unlock()
726+
727+
performHandler()
728+
}
729+
730+
private func performHandler() {
731+
guard let queue = queue else {
732+
fatalError("Attempting callout without queue set")
733+
}
734+
735+
let handlerCallout = DispatchWorkItem {
736+
self.lock.lock()
737+
self.handlerCallout = nil
738+
self.lock.unlock()
739+
740+
if let handler = self.handler, !handler.isCancelled {
741+
handler.perform()
742+
}
743+
744+
// Check if new callout was scheduled while we were performing the handler.
745+
self.lock.lock()
746+
let hasCallout = self.handlerCallout != nil
747+
self.lock.unlock()
748+
guard !hasCallout, !self.events.isEmpty else {
749+
return
750+
}
751+
752+
self.triggerIO()
753+
}
754+
755+
// Simple callout merge implementation.
756+
// Just do not schedule additional work if there is pending item.
757+
lock.lock()
758+
if self.handlerCallout == nil {
759+
self.handlerCallout = handlerCallout
760+
queue.async(execute: handlerCallout)
761+
}
762+
lock.unlock()
763+
}
764+
765+
}
766+
767+
private extension URLSession._MultiHandle._SocketRegisterAction {
768+
var socketEvents: _SocketSources.SocketEvents {
769+
switch self {
770+
case .none: return []
771+
case .registerRead: return [.read]
772+
case .registerWrite: return [.write]
773+
case .registerReadAndWrite: return [.read, .write]
774+
case .unregister: return []
775+
}
776+
}
777+
}
778+
779+
#endif
780+
608781
extension _SocketSources {
609782
/// Unwraps the `SocketSources`
610783
///

Tests/Foundation/TestURLSession.swift

+54
Original file line numberDiff line numberDiff line change
@@ -718,6 +718,60 @@ final class TestURLSession: LoopbackServerTest, @unchecked Sendable {
718718
waitForExpectations(timeout: 30)
719719
}
720720

721+
func test_largePost() throws {
722+
let session = URLSession(configuration: URLSessionConfiguration.default)
723+
var dataTask: URLSessionDataTask? = nil
724+
725+
let data = Data((0 ..< 131076).map { _ in UInt8.random(in: UInt8.min ... UInt8.max) })
726+
var req = URLRequest(url: URL(string: "http://127.0.0.1:\(TestURLSession.serverPort)/POST")!)
727+
req.httpMethod = "POST"
728+
req.httpBody = data
729+
730+
let e = expectation(description: "POST completed")
731+
dataTask = session.uploadTask(with: req, from: data) { data, response, error in
732+
e.fulfill()
733+
}
734+
dataTask?.resume()
735+
736+
waitForExpectations(timeout: 5)
737+
}
738+
739+
func test_slowPost() throws {
740+
class DrippingInputStream: InputStream {
741+
private var data: Data
742+
override public var hasBytesAvailable: Bool {
743+
return !data.isEmpty
744+
}
745+
override public init(data: Data) {
746+
self.data = data
747+
super.init(data: data)
748+
}
749+
override public func read(_ buffer: UnsafeMutablePointer<UInt8>, maxLength len: Int) -> Int {
750+
let readCount = min(min(len, data.count), 42)
751+
data.copyBytes(to: buffer, count: readCount)
752+
data = data.advanced(by: readCount)
753+
return readCount
754+
}
755+
}
756+
757+
let session = URLSession(configuration: URLSessionConfiguration.default)
758+
var dataTask: URLSessionDataTask? = nil
759+
760+
let data = Data((0 ..< 2048).map { _ in UInt8.random(in: UInt8.min ... UInt8.max) })
761+
var req = URLRequest(url: URL(string: "http://127.0.0.1:\(TestURLSession.serverPort)/POST")!)
762+
req.httpMethod = "POST"
763+
req.httpBodyStream = DrippingInputStream(data: data)
764+
765+
let e = expectation(description: "POST completed")
766+
dataTask = session.uploadTask(with: req, from: data) { data, response, error in
767+
XCTAssertNil(error)
768+
e.fulfill()
769+
}
770+
dataTask?.resume()
771+
772+
waitForExpectations(timeout: 5)
773+
}
774+
721775
func test_httpRedirectionWithCode300() async throws {
722776
let statusCode = 300
723777
for method in httpMethods {

0 commit comments

Comments
 (0)