Skip to content

Commit 2c8529f

Browse files
committed
Add support for local server graceful shutdown
1 parent b3fbee6 commit 2c8529f

File tree

4 files changed

+124
-54
lines changed

4 files changed

+124
-54
lines changed

Sources/AWSLambdaRuntime/Lambda+LocalServer.swift

Lines changed: 83 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,21 @@ internal struct LambdaHTTPServer {
166166
// consumed by iterating the group or by exiting the group. Since, we are never consuming
167167
// the results of the group we need the group to automatically discard them; otherwise, this
168168
// would result in a memory leak over time.
169-
try await withThrowingDiscardingTaskGroup { taskGroup in
170-
try await channel.executeThenClose { inbound in
171-
for try await connectionChannel in inbound {
172-
173-
taskGroup.addTask {
174-
logger.trace("Handling a new connection")
175-
await server.handleConnection(channel: connectionChannel, logger: logger)
176-
logger.trace("Done handling the connection")
169+
try await withTaskCancellationHandler {
170+
try await withThrowingDiscardingTaskGroup { taskGroup in
171+
try await channel.executeThenClose { inbound in
172+
for try await connectionChannel in inbound {
173+
174+
taskGroup.addTask {
175+
logger.trace("Handling a new connection")
176+
await server.handleConnection(channel: connectionChannel, logger: logger)
177+
logger.trace("Done handling the connection")
178+
}
177179
}
178180
}
179181
}
182+
} onCancel: {
183+
channel.channel.close(promise: nil)
180184
}
181185
return .serverReturned(.success(()))
182186
} catch {
@@ -230,38 +234,42 @@ internal struct LambdaHTTPServer {
230234
// Note that this method is non-throwing and we are catching any error.
231235
// We do this since we don't want to tear down the whole server when a single connection
232236
// encounters an error.
233-
do {
234-
try await channel.executeThenClose { inbound, outbound in
235-
for try await inboundData in inbound {
236-
switch inboundData {
237-
case .head(let head):
238-
requestHead = head
239-
240-
case .body(let body):
241-
requestBody.setOrWriteImmutableBuffer(body)
242-
243-
case .end:
244-
precondition(requestHead != nil, "Received .end without .head")
245-
// process the request
246-
let response = try await self.processRequest(
247-
head: requestHead,
248-
body: requestBody,
249-
logger: logger
250-
)
251-
// send the responses
252-
try await self.sendResponse(
253-
response: response,
254-
outbound: outbound,
255-
logger: logger
256-
)
257-
258-
requestHead = nil
259-
requestBody = nil
237+
await withTaskCancellationHandler {
238+
do {
239+
try await channel.executeThenClose { inbound, outbound in
240+
for try await inboundData in inbound {
241+
switch inboundData {
242+
case .head(let head):
243+
requestHead = head
244+
245+
case .body(let body):
246+
requestBody.setOrWriteImmutableBuffer(body)
247+
248+
case .end:
249+
precondition(requestHead != nil, "Received .end without .head")
250+
// process the request
251+
let response = try await self.processRequest(
252+
head: requestHead,
253+
body: requestBody,
254+
logger: logger
255+
)
256+
// send the responses
257+
try await self.sendResponse(
258+
response: response,
259+
outbound: outbound,
260+
logger: logger
261+
)
262+
263+
requestHead = nil
264+
requestBody = nil
265+
}
260266
}
261267
}
268+
} catch {
269+
logger.error("Hit error: \(error)")
262270
}
263-
} catch {
264-
logger.error("Hit error: \(error)")
271+
} onCancel: {
272+
channel.channel.close(promise: nil)
265273
}
266274
}
267275

@@ -432,6 +440,7 @@ internal struct LambdaHTTPServer {
432440
enum State: ~Copyable {
433441
case buffer(Deque<T>)
434442
case continuation(CheckedContinuation<T, any Error>?)
443+
case cancelled
435444
}
436445

437446
private let lock = Mutex<State>(.buffer([]))
@@ -450,6 +459,10 @@ internal struct LambdaHTTPServer {
450459
buffer.append(invocation)
451460
state = .buffer(buffer)
452461
return nil
462+
463+
case .cancelled:
464+
state = .cancelled
465+
return nil
453466
}
454467
}
455468

@@ -462,26 +475,44 @@ internal struct LambdaHTTPServer {
462475
return nil
463476
}
464477

465-
return try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
466-
let nextAction = self.lock.withLock { state -> T? in
467-
switch consume state {
468-
case .buffer(var buffer):
469-
if let first = buffer.popFirst() {
470-
state = .buffer(buffer)
471-
return first
472-
} else {
473-
state = .continuation(continuation)
478+
return try await withTaskCancellationHandler {
479+
try await withCheckedThrowingContinuation { (continuation: CheckedContinuation<T, any Error>) in
480+
let nextAction = self.lock.withLock { state -> T? in
481+
switch consume state {
482+
case .buffer(var buffer):
483+
if let first = buffer.popFirst() {
484+
state = .buffer(buffer)
485+
return first
486+
} else {
487+
state = .continuation(continuation)
488+
return nil
489+
}
490+
491+
case .continuation:
492+
fatalError("Concurrent invocations to next(). This is illegal.")
493+
494+
case .cancelled:
495+
state = .cancelled
474496
return nil
475497
}
476-
477-
case .continuation:
478-
fatalError("Concurrent invocations to next(). This is illegal.")
479498
}
480-
}
481499

482-
guard let nextAction else { return }
500+
guard let nextAction else { return }
483501

484-
continuation.resume(returning: nextAction)
502+
continuation.resume(returning: nextAction)
503+
}
504+
} onCancel: {
505+
self.lock.withLock { state in
506+
switch consume state {
507+
case .buffer(let buffer):
508+
state = .buffer(buffer)
509+
case .continuation(let continuation):
510+
continuation?.resume(throwing: CancellationError())
511+
state = .continuation(continuation)
512+
case .cancelled:
513+
state = .cancelled
514+
}
515+
}
485516
}
486517
}
487518

Sources/AWSLambdaRuntime/LambdaRuntime+ServiceLifecycle.swift

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,5 +15,11 @@
1515
#if ServiceLifecycleSupport
1616
import ServiceLifecycle
1717

18-
extension LambdaRuntime: Service {}
18+
extension LambdaRuntime: Service {
19+
public func run() async throws {
20+
try await cancelWhenGracefulShutdown {
21+
try await self._run()
22+
}
23+
}
24+
}
1925
#endif

Sources/AWSLambdaRuntime/LambdaRuntime.swift

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,8 +51,15 @@ public final class LambdaRuntime<Handler>: @unchecked Sendable where Handler: St
5151
self.logger.debug("LambdaRuntime initialized")
5252
}
5353

54+
#if !ServiceLifecycleSupport
5455
@inlinable
55-
public func run() async throws {
56+
internal func run() async throws {
57+
try await _run()
58+
}
59+
#endif
60+
61+
@inlinable
62+
internal func _run() async throws {
5663
let handler = self.handlerMutex.withLockedValue { handler in
5764
let result = handler
5865
handler = nil

Tests/AWSLambdaRuntimeTests/LambdaRuntimeClientTests.swift

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
import Logging
1616
import NIOCore
1717
import NIOPosix
18+
import ServiceLifecycle
1819
import Testing
1920

2021
import struct Foundation.UUID
@@ -139,4 +140,29 @@ struct LambdaRuntimeClientTests {
139140
}
140141
}
141142
}
143+
#if ServiceLifecycleSupport
144+
@Test
145+
func testLambdaRuntimeGracefulShutdown() async throws {
146+
let runtime = LambdaRuntime {
147+
(event: String, context: LambdaContext) in
148+
"Hello \(event)"
149+
}
150+
var logger = Logger(label: "LambdaRuntime")
151+
logger.logLevel = .debug
152+
let serviceGroup = ServiceGroup(
153+
services: [runtime],
154+
gracefulShutdownSignals: [.sigterm, .sigint],
155+
logger: logger
156+
)
157+
try await withThrowingTaskGroup(of: Void.self) { group in
158+
group.addTask {
159+
try await serviceGroup.run()
160+
}
161+
// wait a small amount to ensure we are waiting for continuation
162+
try await Task.sleep(for: .milliseconds(100))
163+
164+
await serviceGroup.triggerGracefulShutdown()
165+
}
166+
}
167+
#endif
142168
}

0 commit comments

Comments
 (0)