Skip to content

Commit b902583

Browse files
committed
Replace DispatchSource in URLSession on Windows with custom event listener (swiftlang#4791)
1 parent 4ac0c38 commit b902583

File tree

2 files changed

+213
-1
lines changed

2 files changed

+213
-1
lines changed

Sources/FoundationNetworking/URLSession/libcurl/MultiHandle.swift

+175-1
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ fileprivate extension URLSession._MultiHandle {
127127
if let opaque = socketSourcePtr {
128128
Unmanaged<_SocketSources>.fromOpaque(opaque).release()
129129
}
130+
socketSources?.tearDown()
130131
socketSources = nil
131132
}
132133
if let ss = socketSources {
@@ -416,7 +417,7 @@ fileprivate extension URLSession._MultiHandle._Timeout {
416417
}
417418
}
418419

419-
420+
#if !os(Windows)
420421
/// Read and write libdispatch sources for a specific socket.
421422
///
422423
/// A simple helper that combines two sources -- both being optional.
@@ -474,6 +475,179 @@ extension _SocketSources {
474475
}
475476
}
476477
}
478+
479+
#else
480+
481+
private let threadpoolWaitCallback: PTP_WAIT_CALLBACK = { (inst, context, pwa, res) in
482+
guard let sources = _SocketSources.from(socketSourcePtr: context) else {
483+
fatalError("Context is not set in socket callback")
484+
}
485+
486+
sources.socketCallback()
487+
}
488+
489+
private class _SocketSources {
490+
struct SocketEvents: OptionSet {
491+
let rawValue: CLong
492+
493+
static let read = SocketEvents(rawValue: CLong(FD_READ))
494+
static let write = SocketEvents(rawValue: CLong(FD_WRITE))
495+
}
496+
497+
private var socket: SOCKET = INVALID_SOCKET
498+
private var queue: DispatchQueue?
499+
private var handler: DispatchWorkItem?
500+
501+
// Only the handlerCallout and callback properties are
502+
// accessed concurrently (from queue thread and ThreadpoolWait thread).
503+
// While callback property should not be raced due to specific
504+
// disarm logic, it is still guarded with lock for safety.
505+
private var handlerCallout: DispatchWorkItem?
506+
private var callback: (event: HANDLE, threadpoolWait: PTP_WAIT)?
507+
private let lock = NSLock()
508+
509+
private var networkEvents: CLong = 0
510+
private var events: SocketEvents = [] {
511+
didSet {
512+
guard oldValue != events else {
513+
return
514+
}
515+
triggerIO()
516+
}
517+
}
518+
519+
func triggerIO() {
520+
// Decide which network events we're interested in,
521+
// initialize callback lazily.
522+
let (networkEvents, event) = { () -> (CLong, HANDLE?) in
523+
guard !events.isEmpty else {
524+
return (0, nil)
525+
}
526+
let event = {
527+
if let callback = callback {
528+
return callback.event
529+
}
530+
guard let event = CreateEventW(nil, /* bManualReset */ false, /* bInitialState */ false, nil) else {
531+
fatalError("CreateEventW \(GetLastError())")
532+
}
533+
guard let threadpoolWait = CreateThreadpoolWait(threadpoolWaitCallback, Unmanaged.passUnretained(self).toOpaque(), /* PTP_CALLBACK_ENVIRON */ nil) else {
534+
fatalError("CreateThreadpoolWait \(GetLastError())")
535+
}
536+
SetThreadpoolWait(threadpoolWait, event, /* pftTimeout */ nil)
537+
callback = (event, threadpoolWait)
538+
return event
539+
}()
540+
return (CLong(FD_CLOSE) | events.rawValue, event)
541+
}()
542+
543+
if self.networkEvents != networkEvents {
544+
guard WSAEventSelect(socket, event, networkEvents) == 0 else {
545+
fatalError("WSAEventSelect \(WSAGetLastError())")
546+
}
547+
self.networkEvents = networkEvents
548+
}
549+
550+
if events.contains(.write) {
551+
// FD_WRITE will only be signaled if the socket becomes writable after
552+
// a send() fails with WSAEWOULDBLOCK. If shis zero-byte send() doesn't fail,
553+
// we could immediately schedule the handler callout.
554+
if send(socket, "", 0, 0) == 0 {
555+
performHandler()
556+
}
557+
} else if events.isEmpty, let callback = callback {
558+
SetThreadpoolWait(callback.threadpoolWait, nil, nil)
559+
WaitForThreadpoolWaitCallbacks(callback.threadpoolWait, /* fCancelPendingCallbacks */ true)
560+
CloseThreadpoolWait(callback.threadpoolWait)
561+
CloseHandle(callback.event)
562+
563+
lock.lock()
564+
self.callback = nil
565+
handlerCallout?.cancel()
566+
handlerCallout = nil
567+
lock.unlock()
568+
569+
handler = nil
570+
}
571+
}
572+
573+
func createSources(with action: URLSession._MultiHandle._SocketRegisterAction, socket: CFURLSession_socket_t, queue: DispatchQueue, handler: DispatchWorkItem) {
574+
precondition(self.socket == INVALID_SOCKET || self.socket == socket, "Socket value changed")
575+
precondition(self.queue == nil || self.queue === queue, "Queue changed")
576+
577+
self.socket = socket
578+
self.queue = queue
579+
self.handler = handler
580+
581+
events = action.socketEvents
582+
}
583+
584+
func tearDown() {
585+
events = []
586+
}
587+
588+
func socketCallback() {
589+
// Note: this called on ThreadpoolWait thread.
590+
lock.lock()
591+
if let callback = callback {
592+
ResetEvent(callback.event)
593+
SetThreadpoolWait(callback.threadpoolWait, callback.event, /* pftTimeout */ nil)
594+
}
595+
lock.unlock()
596+
597+
performHandler()
598+
}
599+
600+
private func performHandler() {
601+
guard let queue = queue else {
602+
fatalError("Attempting callout without queue set")
603+
}
604+
605+
let handlerCallout = DispatchWorkItem {
606+
self.lock.lock()
607+
self.handlerCallout = nil
608+
self.lock.unlock()
609+
610+
if let handler = self.handler, !handler.isCancelled {
611+
handler.perform()
612+
}
613+
614+
// Check if new callout was scheduled while we were performing the handler.
615+
self.lock.lock()
616+
let hasCallout = self.handlerCallout != nil
617+
self.lock.unlock()
618+
guard !hasCallout, !self.events.isEmpty else {
619+
return
620+
}
621+
622+
self.triggerIO()
623+
}
624+
625+
// Simple callout merge implementation.
626+
// Just do not schedule additional work if there is pending item.
627+
lock.lock()
628+
if self.handlerCallout == nil {
629+
self.handlerCallout = handlerCallout
630+
queue.async(execute: handlerCallout)
631+
}
632+
lock.unlock()
633+
}
634+
635+
}
636+
637+
private extension URLSession._MultiHandle._SocketRegisterAction {
638+
var socketEvents: _SocketSources.SocketEvents {
639+
switch self {
640+
case .none: return []
641+
case .registerRead: return [.read]
642+
case .registerWrite: return [.write]
643+
case .registerReadAndWrite: return [.read, .write]
644+
case .unregister: return []
645+
}
646+
}
647+
}
648+
649+
#endif
650+
477651
extension _SocketSources {
478652
/// Unwraps the `SocketSources`
479653
///

Tests/Foundation/Tests/TestURLSession.swift

+38
Original file line numberDiff line numberDiff line change
@@ -637,6 +637,42 @@ class TestURLSession: LoopbackServerTest {
637637
waitForExpectations(timeout: 5)
638638
}
639639

640+
func test_slowPost() throws {
641+
class DrippingInputStream: InputStream {
642+
private var data: Data
643+
override public var hasBytesAvailable: Bool {
644+
return !data.isEmpty
645+
}
646+
override public init(data: Data) {
647+
self.data = data
648+
super.init(data: data)
649+
}
650+
override public func read(_ buffer: UnsafeMutablePointer<UInt8>, maxLength len: Int) -> Int {
651+
let readCount = min(min(len, data.count), 42)
652+
data.copyBytes(to: buffer, count: readCount)
653+
data = data.advanced(by: readCount)
654+
return readCount
655+
}
656+
}
657+
658+
let session = URLSession(configuration: URLSessionConfiguration.default)
659+
var dataTask: URLSessionDataTask? = nil
660+
661+
let data = Data((0 ..< 2048).map { _ in UInt8.random(in: UInt8.min ... UInt8.max) })
662+
var req = URLRequest(url: URL(string: "http://127.0.0.1:\(TestURLSession.serverPort)/POST")!)
663+
req.httpMethod = "POST"
664+
req.httpBodyStream = DrippingInputStream(data: data)
665+
666+
let e = expectation(description: "POST completed")
667+
dataTask = session.uploadTask(with: req, from: data) { data, response, error in
668+
XCTAssertNil(error)
669+
e.fulfill()
670+
}
671+
dataTask?.resume()
672+
673+
waitForExpectations(timeout: 5)
674+
}
675+
640676
func test_httpRedirectionWithCode300() throws {
641677
let statusCode = 300
642678
for method in httpMethods {
@@ -2233,6 +2269,8 @@ class TestURLSession: LoopbackServerTest {
22332269
("test_verifyHttpAdditionalHeaders", test_verifyHttpAdditionalHeaders),
22342270
("test_httpTimeout", test_httpTimeout),
22352271
("test_connectTimeout", test_connectTimeout),
2272+
("test_largePost", test_largePost),
2273+
("test_slowPost", test_slowPost),
22362274
("test_httpRedirectionWithCode300", test_httpRedirectionWithCode300),
22372275
("test_httpRedirectionWithCode301_302", test_httpRedirectionWithCode301_302),
22382276
("test_httpRedirectionWithCode303", test_httpRedirectionWithCode303),

0 commit comments

Comments
 (0)