Skip to content

Commit b280455

Browse files
committed
implement streaming on the server side
1 parent 5a4415c commit b280455

File tree

1 file changed

+101
-62
lines changed

1 file changed

+101
-62
lines changed

Sources/AWSLambdaRuntime/Lambda+LocalServer.swift

Lines changed: 101 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,7 @@ internal struct LambdaHTTPServer {
235235

236236
var requestHead: HTTPRequestHead!
237237
var requestBody: ByteBuffer?
238+
var requestId: String?
238239

239240
// Note that this method is non-throwing and we are catching any error.
240241
// We do this since we don't want to tear down the whole server when a single connection
@@ -246,16 +247,27 @@ internal struct LambdaHTTPServer {
246247
switch inboundData {
247248
case .head(let head):
248249
requestHead = head
250+
requestId = getRequestId(from: requestHead)
251+
252+
// for streaming requests, push a partial head response
253+
if self.isStreamingResponse(requestHead) {
254+
await self.responsePool.push(
255+
LocalServerResponse(
256+
id: requestId,
257+
status: .ok
258+
)
259+
)
260+
}
249261

250262
case .body(let body):
251263
precondition(requestHead != nil, "Received .body without .head")
252264

253265
// if this is a request from a Streaming Lambda Handler,
254266
// stream the response instead of buffering it
255267
if self.isStreamingResponse(requestHead) {
256-
// we are receiving a chunked body,
257-
// we can stream the response and not accumulate the chunks
258-
print(String(buffer: body))
268+
await self.responsePool.push(
269+
LocalServerResponse(id: requestId, body: body)
270+
)
259271
} else {
260272
requestBody.setOrWriteImmutableBuffer(body)
261273
}
@@ -265,22 +277,23 @@ internal struct LambdaHTTPServer {
265277

266278
// process the buffered response for non streaming requests
267279
if !self.isStreamingResponse(requestHead) {
268-
// process the complete request
269-
let response = try await self.processCompleteRequest(
280+
// process the request and send the response
281+
try await self.processRequestAndSendResponse(
270282
head: requestHead,
271283
body: requestBody,
272-
logger: logger
273-
)
274-
// send the responses
275-
try await self.sendCompleteResponse(
276-
response: response,
277284
outbound: outbound,
278285
logger: logger
279286
)
287+
} else {
288+
await self.responsePool.push(
289+
LocalServerResponse(id: requestId, final: true)
290+
)
291+
280292
}
281293

282294
requestHead = nil
283295
requestBody = nil
296+
requestId = nil
284297
}
285298
}
286299
}
@@ -304,6 +317,11 @@ internal struct LambdaHTTPServer {
304317
requestHead.headers["Transfer-Encoding"].contains("chunked")
305318
}
306319

320+
/// This function parses and returns the requestId or nil if the request is malformed
321+
private func getRequestId(from head: HTTPRequestHead) -> String? {
322+
let parts = head.uri.split(separator: "/")
323+
return parts.count > 2 ? String(parts[parts.count - 2]) : nil
324+
}
307325
/// This function processes the URI request sent by the client and by the Lambda function
308326
///
309327
/// It enqueues the client invocation and iterates over the invocation queue when the Lambda function sends the /next request
@@ -314,19 +332,22 @@ internal struct LambdaHTTPServer {
314332
/// - body: the HTTP request body
315333
/// - Throws:
316334
/// - Returns: the response to send back to the client or the Lambda function
317-
private func processCompleteRequest(
335+
private func processRequestAndSendResponse(
318336
head: HTTPRequestHead,
319337
body: ByteBuffer?,
338+
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
320339
logger: Logger
321-
) async throws -> LocalServerResponse {
340+
) async throws {
322341

342+
var logger = logger
343+
logger[metadataKey: "URI"] = "\(head.method) \(head.uri)"
323344
if let body {
324345
logger.trace(
325346
"Processing request",
326-
metadata: ["URI": "\(head.method) \(head.uri)", "Body": "\(String(buffer: body))"]
347+
metadata: ["Body": "\(String(buffer: body))"]
327348
)
328349
} else {
329-
logger.trace("Processing request", metadata: ["URI": "\(head.method) \(head.uri)"])
350+
logger.trace("Processing request")
330351
}
331352

332353
switch (head.method, head.uri) {
@@ -337,27 +358,32 @@ internal struct LambdaHTTPServer {
337358
// client POST /invoke
338359
case (.POST, let url) where url.hasSuffix(self.invocationEndpoint):
339360
guard let body else {
340-
return .init(status: .badRequest, headers: [], body: nil)
361+
return try await sendResponse(.init(status: .badRequest), outbound: outbound, logger: logger)
341362
}
342363
// we always accept the /invoke request and push them to the pool
343364
let requestId = "\(DispatchTime.now().uptimeNanoseconds)"
344-
var logger = logger
345365
logger[metadataKey: "requestID"] = "\(requestId)"
346-
logger.trace("/invoke received invocation")
366+
logger.trace("/invoke received invocation, pushing it to the stack")
347367
await self.invocationPool.push(LocalServerInvocation(requestId: requestId, request: body))
348368

349369
// wait for the lambda function to process the request
350370
for try await response in self.responsePool {
351-
logger.trace(
352-
"Received response to return to client",
353-
metadata: ["requestId": "\(response.requestId ?? "")"]
354-
)
371+
logger[metadataKey: "requestID"] = "\(requestId)"
372+
logger.trace("Received response to return to client")
355373
if response.requestId == requestId {
356-
return response
374+
logger.trace("/invoke requestId is valid, sending the response")
375+
// send the response to the client
376+
// if the response is final, we can send it and return
377+
// if the response is not final, we can send it and wait for the next response
378+
try await self.sendResponse(response, outbound: outbound, logger: logger)
379+
if response.final == true {
380+
logger.trace("/invoke returning")
381+
return // if the response is final, we can return and close the connection
382+
}
357383
} else {
358384
logger.error(
359385
"Received response for a different request id",
360-
metadata: ["response requestId": "\(response.requestId ?? "")", "requestId": "\(requestId)"]
386+
metadata: ["response requestId": "\(response.requestId ?? "")"]
361387
)
362388
// should we return an error here ? Or crash as this is probably a programming error?
363389
}
@@ -368,7 +394,7 @@ internal struct LambdaHTTPServer {
368394

369395
// client uses incorrect HTTP method
370396
case (_, let url) where url.hasSuffix(self.invocationEndpoint):
371-
return .init(status: .methodNotAllowed)
397+
return try await sendResponse(.init(status: .methodNotAllowed), outbound: outbound, logger: logger)
372398

373399
//
374400
// lambda invocations
@@ -381,85 +407,97 @@ internal struct LambdaHTTPServer {
381407
// pop the tasks from the queue
382408
logger.trace("/next waiting for /invoke")
383409
for try await invocation in self.invocationPool {
384-
logger.trace("/next retrieved invocation", metadata: ["requestId": "\(invocation.requestId)"])
385-
// this call also stores the invocation requestId into the response
386-
return invocation.makeResponse(status: .accepted)
410+
logger[metadataKey: "requestId"] = "\(invocation.requestId)"
411+
logger.trace("/next retrieved invocation")
412+
// tell the lambda function we accepted the invocation
413+
return try await sendResponse(invocation.acceptedResponse(), outbound: outbound, logger: logger)
387414
}
388415
// What to do when there are no more tasks to process?
389416
// This should not happen as the async iterator blocks until there is a task to process
390417
fatalError("No more invocations to process - the async for loop should not return")
391418

392419
// :requestID/response endpoint is called by the lambda posting the response
393420
case (.POST, let url) where url.hasSuffix(Consts.postResponseURLSuffix):
394-
let parts = head.uri.split(separator: "/")
395-
guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else {
421+
guard let requestID = getRequestId(from: head) else {
396422
// the request is malformed, since we were expecting a requestId in the path
397-
return .init(status: .badRequest)
423+
return try await sendResponse(.init(status: .badRequest), outbound: outbound, logger: logger)
398424
}
399425
// enqueue the lambda function response to be served as response to the client /invoke
400426
logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestID)"])
401427
await self.responsePool.push(
402428
LocalServerResponse(
403429
id: requestID,
404430
status: .ok,
405-
headers: [("Content-Type", "application/json")],
431+
headers: HTTPHeaders([("Content-Type", "application/json")]),
406432
body: body
407433
)
408434
)
409435

410436
// tell the Lambda function we accepted the response
411-
return .init(id: requestID, status: .accepted)
437+
return try await sendResponse(.init(id: requestID, status: .accepted), outbound: outbound, logger: logger)
412438

413439
// :requestID/error endpoint is called by the lambda posting an error response
414440
// we accept all requestID and we do not handle the body, we just acknowledge the request
415441
case (.POST, let url) where url.hasSuffix(Consts.postErrorURLSuffix):
416-
let parts = head.uri.split(separator: "/")
417-
guard let requestID = parts.count > 2 ? String(parts[parts.count - 2]) : nil else {
442+
guard let requestID = getRequestId(from: head) else {
418443
// the request is malformed, since we were expecting a requestId in the path
419-
return .init(status: .badRequest)
444+
return try await sendResponse(.init(status: .badRequest), outbound: outbound, logger: logger)
420445
}
421446
// enqueue the lambda function response to be served as response to the client /invoke
422447
logger.trace("/:requestID/response received response", metadata: ["requestId": "\(requestID)"])
423448
await self.responsePool.push(
424449
LocalServerResponse(
425450
id: requestID,
426451
status: .internalServerError,
427-
headers: [("Content-Type", "application/json")],
452+
headers: HTTPHeaders([("Content-Type", "application/json")]),
428453
body: body
429454
)
430455
)
431456

432-
return .init(status: .accepted)
457+
return try await sendResponse(.init(status: .accepted), outbound: outbound, logger: logger)
433458

434459
// unknown call
435460
default:
436-
return .init(status: .notFound)
461+
return try await sendResponse(.init(status: .notFound), outbound: outbound, logger: logger)
437462
}
438463
}
439464

440-
private func sendCompleteResponse(
441-
response: LocalServerResponse,
465+
private func sendResponse(
466+
_ response: LocalServerResponse,
442467
outbound: NIOAsyncChannelOutboundWriter<HTTPServerResponsePart>,
443468
logger: Logger
444469
) async throws {
445-
var headers = HTTPHeaders(response.headers ?? [])
446-
headers.add(name: "Content-Length", value: "\(response.body?.readableBytes ?? 0)")
447-
448-
logger.trace("Writing response", metadata: ["requestId": "\(response.requestId ?? "")"])
449-
try await outbound.write(
450-
HTTPServerResponsePart.head(
451-
HTTPResponseHead(
452-
version: .init(major: 1, minor: 1),
453-
status: response.status,
454-
headers: headers
470+
var logger = logger
471+
logger[metadataKey: "requestId"] = "\(response.requestId ?? "nil")"
472+
logger.trace("Writing response")
473+
474+
var headers = response.headers ?? HTTPHeaders()
475+
if let body = response.body {
476+
headers.add(name: "Content-Length", value: "\(body.readableBytes)")
477+
}
478+
479+
if let status = response.status {
480+
logger.trace("Sending status and headers")
481+
try await outbound.write(
482+
HTTPServerResponsePart.head(
483+
HTTPResponseHead(
484+
version: .init(major: 1, minor: 1),
485+
status: status,
486+
headers: headers
487+
)
455488
)
456489
)
457-
)
490+
}
491+
458492
if let body = response.body {
493+
logger.trace("Sending body")
459494
try await outbound.write(HTTPServerResponsePart.body(.byteBuffer(body)))
460495
}
461496

462-
try await outbound.write(HTTPServerResponsePart.end(nil))
497+
if response.final {
498+
logger.trace("Sending end")
499+
try await outbound.write(HTTPServerResponsePart.end(nil))
500+
}
463501
}
464502

465503
/// A shared data structure to store the current invocation or response requests and the continuation objects.
@@ -543,36 +581,37 @@ internal struct LambdaHTTPServer {
543581

544582
private struct LocalServerResponse: Sendable {
545583
let requestId: String?
546-
let status: HTTPResponseStatus
547-
let headers: [(String, String)]?
584+
let status: HTTPResponseStatus?
585+
let headers: HTTPHeaders?
548586
let body: ByteBuffer?
549-
init(id: String? = nil, status: HTTPResponseStatus, headers: [(String, String)]? = nil, body: ByteBuffer? = nil)
550-
{
587+
let final: Bool
588+
init(id: String? = nil, status: HTTPResponseStatus? = nil, headers: HTTPHeaders? = nil, body: ByteBuffer? = nil, final: Bool = false) {
551589
self.requestId = id
552590
self.status = status
553591
self.headers = headers
554592
self.body = body
593+
self.final = final
555594
}
556595
}
557596

558597
private struct LocalServerInvocation: Sendable {
559598
let requestId: String
560599
let request: ByteBuffer
561600

562-
func makeResponse(status: HTTPResponseStatus) -> LocalServerResponse {
601+
func acceptedResponse() -> LocalServerResponse {
563602

564603
// required headers
565-
let headers = [
604+
let headers = HTTPHeaders([
566605
(AmazonHeaders.requestID, self.requestId),
567606
(
568-
AmazonHeaders.invokedFunctionARN,
569-
"arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"
607+
AmazonHeaders.invokedFunctionARN,
608+
"arn:aws:lambda:us-east-1:\(Int16.random(in: Int16.min ... Int16.max)):function:custom-runtime"
570609
),
571610
(AmazonHeaders.traceID, "Root=\(AmazonHeaders.generateXRayTraceID());Sampled=1"),
572611
(AmazonHeaders.deadline, "\(DispatchWallTime.distantFuture.millisSinceEpoch)"),
573-
]
612+
])
574613

575-
return LocalServerResponse(id: self.requestId, status: status, headers: headers, body: self.request)
614+
return LocalServerResponse(id: self.requestId, status: .accepted, headers: headers, body: self.request, final: true)
576615
}
577616
}
578617
}

0 commit comments

Comments
 (0)