From 981a0a808258dda417902deb2c0f3283d38ce1e0 Mon Sep 17 00:00:00 2001 From: Yashash H L Date: Thu, 26 Sep 2024 22:22:32 +0530 Subject: [PATCH 1/2] feat: bidirectional streaming for udsink (#90) Signed-off-by: Yashash H L Signed-off-by: Vigith Maurice Co-authored-by: Vigith Maurice --- examples/simple-source/Makefile | 5 +- examples/sink-log/Dockerfile | 2 +- proto/sink.proto | 55 +++-- src/sink.rs | 389 ++++++++++++++++++++++++-------- src/source.rs | 112 +++++---- 5 files changed, 405 insertions(+), 158 deletions(-) diff --git a/examples/simple-source/Makefile b/examples/simple-source/Makefile index 80ba511..b7769e8 100644 --- a/examples/simple-source/Makefile +++ b/examples/simple-source/Makefile @@ -10,9 +10,10 @@ update: .PHONY: image image: update - cd ../../ && docker buildx build \ + cd ../../ && docker build \ -f ${DOCKER_FILE_PATH} \ - -t ${IMAGE_REGISTRY} . --platform linux/amd64,linux/arm64 --push + -t ${IMAGE_REGISTRY} . + @if [ "$(PUSH)" = "true" ]; then docker push ${IMAGE_REGISTRY}; fi .PHONY: clean clean: diff --git a/examples/sink-log/Dockerfile b/examples/sink-log/Dockerfile index a7aa942..334b445 100644 --- a/examples/sink-log/Dockerfile +++ b/examples/sink-log/Dockerfile @@ -11,7 +11,7 @@ WORKDIR /numaflow-rs/examples/sink-log RUN cargo build --release # our final base -FROM debian:bookworm AS simple-source +FROM debian:bookworm AS sink-log # copy the build artifact from the build stage COPY --from=build /numaflow-rs/examples/sink-log/target/release/server . diff --git a/proto/sink.proto b/proto/sink.proto index c413ea8..300e570 100644 --- a/proto/sink.proto +++ b/proto/sink.proto @@ -7,7 +7,7 @@ package sink.v1; service Sink { // SinkFn writes the request to a user defined sink. - rpc SinkFn(stream SinkRequest) returns (SinkResponse); + rpc SinkFn(stream SinkRequest) returns (stream SinkResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); @@ -17,12 +17,32 @@ service Sink { * SinkRequest represents a request element. */ message SinkRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - string id = 5; - map headers = 6; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + string id = 5; + map headers = 6; + } + message Status { + bool eot = 1; + } + // Required field indicating the request. + Request request = 1; + // Required field indicating the status of the request. + // If eot is set to true, it indicates the end of transmission. + Status status = 2; + // optional field indicating the handshake message. + optional Handshake handshake = 3; +} + +/* + * Handshake message between client and server to indicate the start of transmission. + */ +message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; } /** @@ -32,6 +52,15 @@ message ReadyResponse { bool ready = 1; } +/* + * Status is the status of the response. + */ +enum Status { + SUCCESS = 0; + FAILURE = 1; + FALLBACK = 2; +} + /** * SinkResponse is the individual response of each message written to the sink. */ @@ -44,14 +73,6 @@ message SinkResponse { // err_msg is the error message, set it if success is set to false. string err_msg = 3; } - repeated Result results = 1; -} - -/* - * Status is the status of the response. - */ -enum Status { - SUCCESS = 0; - FAILURE = 1; - FALLBACK = 2; + Result result = 1; + optional Handshake handshake = 2; } \ No newline at end of file diff --git a/src/sink.rs b/src/sink.rs index 2e2f3ba..43ff530 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -1,14 +1,19 @@ +use crate::error::Error; use crate::error::Error::SinkError; use crate::error::ErrorKind::{InternalError, UserDefinedError}; use crate::shared; +use crate::sink::sink_pb::SinkResponse; use chrono::{DateTime, Utc}; use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; use std::{env, fs}; use tokio::sync::{mpsc, oneshot}; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; use tonic::{Request, Status, Streaming}; +use tracing::{debug, info}; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sink.sock"; @@ -21,7 +26,7 @@ const UD_CONTAINER_FB_SINK: &str = "fb-udsink"; // TODO: use batch-size, blocked by https://github.com/numaproj/numaflow/issues/2026 const DEFAULT_CHANNEL_SIZE: usize = 1000; /// Numaflow Sink Proto definitions. -pub mod proto { +pub mod sink_pb { tonic::include_proto!("sink.v1"); } @@ -98,8 +103,8 @@ pub struct SinkRequest { pub headers: HashMap, } -impl From for SinkRequest { - fn from(sr: proto::SinkRequest) -> Self { +impl From for SinkRequest { + fn from(sr: sink_pb::sink_request::Request) -> Self { Self { keys: sr.keys, value: sr.value, @@ -157,16 +162,16 @@ impl Response { } } -impl From for proto::sink_response::Result { +impl From for sink_pb::sink_response::Result { fn from(r: Response) -> Self { Self { id: r.id, status: if r.fallback { - proto::Status::Fallback as i32 + sink_pb::Status::Fallback as i32 } else if r.success { - proto::Status::Success as i32 + sink_pb::Status::Success as i32 } else { - proto::Status::Failure as i32 + sink_pb::Status::Failure as i32 }, err_msg: r.err.unwrap_or_default(), } @@ -174,82 +179,207 @@ impl From for proto::sink_response::Result { } #[tonic::async_trait] -impl proto::sink_server::Sink for SinkService +impl sink_pb::sink_server::Sink for SinkService where T: Sinker + Send + Sync + 'static, { + type SinkFnStream = ReceiverStream>; + async fn sink_fn( &self, - request: Request>, - ) -> Result, Status> { - let mut stream = request.into_inner(); + request: Request>, + ) -> Result, Status> { + let mut sink_stream = request.into_inner(); let sink_handle = self.handler.clone(); - let cancellation_token = self.cancellation_token.clone(); let shutdown_tx = self.shutdown_tx.clone(); - // FIXME: we should be using the batch size as the channel size + let cln_token = self.cancellation_token.clone(); + let (resp_tx, resp_rx) = + mpsc::channel::>(DEFAULT_CHANNEL_SIZE); + + self.perform_handshake(&mut sink_stream, &resp_tx).await?; + + let grpc_resp_tx = resp_tx.clone(); + let handle: JoinHandle> = tokio::spawn(async move { + Self::process_sink_stream(sink_handle, sink_stream, grpc_resp_tx).await + }); + + tokio::spawn(Self::handle_sink_errors( + handle, + resp_tx, + shutdown_tx, + cln_token, + )); + + Ok(tonic::Response::new(ReceiverStream::new(resp_rx))) + } + + async fn is_ready( + &self, + _: Request<()>, + ) -> Result, Status> { + Ok(tonic::Response::new(sink_pb::ReadyResponse { ready: true })) + } +} + +impl SinkService +where + T: Sinker + Send + Sync + 'static, +{ + /// processes the stream of requests from the client + async fn process_sink_stream( + sink_handle: Arc, + mut sink_stream: Streaming, + grpc_resp_tx: mpsc::Sender>, + ) -> Result<(), Error> { + loop { + let done = Self::process_sink_batch( + sink_handle.clone(), + &mut sink_stream, + grpc_resp_tx.clone(), + ) + .await?; + if done { + break; + } + } + Ok(()) + } + + /// processes a batch of messages from the client, sends them to the sink handler and sends the + /// responses back to the client batches are separated by an EOT message + async fn process_sink_batch( + sink_handle: Arc, + sink_stream: &mut Streaming, + grpc_resp_tx: mpsc::Sender>, + ) -> Result { let (tx, rx) = mpsc::channel::(DEFAULT_CHANNEL_SIZE); + let resp_tx = grpc_resp_tx.clone(); + let sink_handle = sink_handle.clone(); + + // spawn the UDF + let sinker_handle = tokio::spawn(async move { + let responses = sink_handle.sink(rx).await; + for response in responses { + resp_tx + .send(Ok(SinkResponse { + result: Some(response.into()), + handshake: None, + })) + .await + .expect("Sending response to channel"); + } + }); - let reader_shutdown_tx = shutdown_tx.clone(); - // spawn a task to read messages from the stream and send them to the user's sink handle - let reader_handle = tokio::spawn(async move { - loop { - match stream.message().await { - Ok(Some(message)) => { - // If sending fails, it means the receiver is dropped, and we should stop the task. - if let Err(e) = tx.send(message.into()).await { - tracing::error!("Failed to send message: {}", e); - break; - } - } - // If there's an error or the stream ends, break the loop to stop the task. - Ok(None) => break, - Err(e) => { - tracing::error!("Error reading message from stream: {}", e); - reader_shutdown_tx - .send(()) - .await - .expect("Sending shutdown signal to gRPC server"); - break; - } + loop { + let message = match sink_stream.message().await { + Ok(Some(m)) => m, + Ok(None) => { + info!("global bidi stream ended"); + return Ok(true); // bidi stream ended + } + Err(e) => { + return Err(SinkError(InternalError(format!( + "Error reading message from stream: {}", + e + )))) } + }; + + // we are done with this batch because eot=true + if message.status.map_or(false, |status| status.eot) { + debug!("Batch Ended, received an EOT message"); + break; } - }); - // call the user's sink handle - let handle = tokio::spawn(async move { sink_handle.sink(rx).await }); + // message.request cannot be none + let request = message.request.ok_or_else(|| { + SinkError(InternalError( + "Invalid argument, request can't be None".to_string(), + )) + })?; + + // write to the UDF's tx + tx.send(request.into()).await.map_err(|e| { + SinkError(InternalError(format!( + "Error sending message to sink handler: {}", + e + ))) + })?; + } + + // drop the sender to signal the sink handler that the batch has ended + drop(tx); - // Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled), - // then return an error. + // wait for UDF task to return + sinker_handle + .await + .map_err(|e| SinkError(UserDefinedError(e.to_string())))?; + + Ok(false) + } + + /// handles errors from the sink handler and sends them to the client via the response channel + async fn handle_sink_errors( + handle: JoinHandle>, + resp_tx: mpsc::Sender>, + shutdown_tx: mpsc::Sender<()>, + cln_token: CancellationToken, + ) { tokio::select! { - result = handle => { - match result { - Ok(responses) => { - Ok(tonic::Response::new(proto::SinkResponse { - results: responses.into_iter().map(|r| r.into()).collect(), - })) + resp = handle => { + match resp { + Ok(Ok(_)) => {}, + Ok(Err(e)) => { + resp_tx + .send(Err(Status::internal(e.to_string()))) + .await + .expect("Sending error to response channel"); + shutdown_tx.send(()).await.expect("Sending shutdown signal"); } Err(e) => { - // Send a shutdown signal to the server to do a graceful shutdown because there was - // a panic in the handler. - shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server"); - Err(Status::internal(SinkError(UserDefinedError(e.to_string())).to_string())) + resp_tx + .send(Err(Status::internal(format!("Sink handler aborted: {}", e)))) + .await + .expect("Sending error to response channel"); + shutdown_tx.send(()).await.expect("Sending shutdown signal"); } } }, - - _ = cancellation_token.cancelled() => { - // abort the reader task to stop reading messages from the stream - reader_handle.abort(); - Err(Status::cancelled(SinkError(InternalError("Server is shutting down".to_string())).to_string())) + _ = cln_token.cancelled() => { + resp_tx + .send(Err(Status::cancelled("Sink handler cancelled"))) + .await + .expect("Sending error to response channel"); } } } - async fn is_ready( + // performs handshake with the client + async fn perform_handshake( &self, - _: Request<()>, - ) -> Result, Status> { - Ok(tonic::Response::new(proto::ReadyResponse { ready: true })) + sink_stream: &mut Streaming, + resp_tx: &mpsc::Sender>, + ) -> Result<(), Status> { + let handshake_request = sink_stream + .message() + .await + .map_err(|e| Status::internal(format!("handshake failed {}", e)))? + .ok_or_else(|| Status::internal("stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + resp_tx + .send(Ok(SinkResponse { + result: None, + handshake: Some(handshake), + })) + .await + .map_err(|e| { + Status::internal(format!("failed to send handshake response {}", e)) + })?; + Ok(()) + } else { + Err(Status::invalid_argument("Handshake not present")) + } } } @@ -339,7 +469,7 @@ impl Server { cancellation_token: cln_token.clone(), }; - let svc = proto::sink_server::SinkServer::new(svc) + let svc = sink_pb::sink_server::SinkServer::new(svc) .max_encoding_message_size(self.max_message_size) .max_decoding_message_size(self.max_message_size); @@ -385,7 +515,9 @@ mod tests { use tower::service_fn; use crate::sink; - use crate::sink::proto::sink_client::SinkClient; + use crate::sink::sink_pb::sink_client::SinkClient; + use crate::sink::sink_pb::sink_request::{Request, Status}; + use crate::sink::sink_pb::Handshake; #[tokio::test] async fn sink_server() -> Result<(), Box> { @@ -397,27 +529,16 @@ mod tests { mut input: tokio::sync::mpsc::Receiver, ) -> Vec { let mut responses: Vec = Vec::new(); - while let Some(datum) = input.recv().await { - // do something better, but for now let's just log it. - // please note that `from_utf8` is working because the input in this - // example uses utf-8 data. let response = match std::str::from_utf8(&datum.value) { - Ok(v) => { - println!("{}", v); - // record the response - sink::Response::ok(datum.id) - } + Ok(_) => sink::Response::ok(datum.id), Err(e) => sink::Response::failure( datum.id, format!("Invalid UTF-8 sequence: {}", e), ), }; - - // return the responses responses.push(response); } - responses } } @@ -454,22 +575,72 @@ mod tests { .await?; let mut client = SinkClient::new(channel); - let request = sink::proto::SinkRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - id: "1".to_string(), - headers: Default::default(), + // Send handshake request + let handshake_request = sink::sink_pb::SinkRequest { + request: None, + status: None, + handshake: Some(Handshake { sot: true }), + }; + let request = sink::sink_pb::SinkRequest { + request: Some(Request { + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: "1".to_string(), + headers: Default::default(), + }), + status: None, + handshake: None, + }; + + let eot_request = sink::sink_pb::SinkRequest { + request: None, + status: Some(Status { eot: true }), + handshake: None, }; - let resp = client.sink_fn(tokio_stream::iter(vec![request])).await?; - let resp = resp.into_inner(); - assert_eq!(resp.results.len(), 1, "Expected single message from server"); - let msg = &resp.results[0]; + let request_two = sink::sink_pb::SinkRequest { + request: Some(Request { + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: "2".to_string(), + headers: Default::default(), + }), + status: None, + handshake: None, + }; + + let resp = client + .sink_fn(tokio_stream::iter(vec![ + handshake_request, + request, + eot_request, + request_two, + ])) + .await?; + + let mut resp_stream = resp.into_inner(); + // handshake response + let resp = resp_stream.message().await.unwrap().unwrap(); + assert!(resp.result.is_none()); + assert!(resp.handshake.is_some()); + + let resp = resp_stream.message().await.unwrap().unwrap(); + assert!(resp.result.is_some()); + let msg = &resp.result.unwrap(); assert_eq!(msg.err_msg, ""); assert_eq!(msg.id, "1"); + let resp = resp_stream.message().await.unwrap().unwrap(); + assert!(resp.result.is_some()); + assert!(resp.handshake.is_none()); + let msg = &resp.result.unwrap(); + assert_eq!(msg.err_msg, ""); + assert_eq!(msg.id, "2"); + shutdown_tx .send(()) .expect("Sending shutdown signal to gRPC server"); @@ -533,24 +704,52 @@ mod tests { .await?; let mut client = SinkClient::new(channel); - let mut requests = Vec::new(); + // Send handshake request + let handshake_request = sink::sink_pb::SinkRequest { + request: None, + status: None, + handshake: Some(Handshake { sot: true }), + }; + + let mut requests = vec![handshake_request]; for i in 0..10 { - let request = sink::proto::SinkRequest { - keys: vec!["first".into(), "second".into()], - value: format!("hello {}", i).into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - id: i.to_string(), - headers: Default::default(), + let request = sink::sink_pb::SinkRequest { + request: Some(Request { + keys: vec!["first".into(), "second".into()], + value: format!("hello {}", i).into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + id: i.to_string(), + headers: Default::default(), + }), + status: None, + handshake: None, }; requests.push(request); } - let resp = client.sink_fn(tokio_stream::iter(requests)).await; - assert!(resp.is_err(), "Expected error from server"); + requests.push(sink::sink_pb::SinkRequest { + request: None, + status: Some(Status { eot: true }), + handshake: None, + }); + + let mut resp_stream = client + .sink_fn(tokio_stream::iter(requests)) + .await + .unwrap() + .into_inner(); + + // handshake response + let resp = resp_stream.message().await.unwrap().unwrap(); + assert!(resp.result.is_none()); + assert!(resp.handshake.is_some()); + + let err_resp = resp_stream.message().await; + assert!(err_resp.is_err()); - if let Err(e) = resp { + if let Err(e) = err_resp { assert_eq!(e.code(), tonic::Code::Internal); assert!(e.message().contains("User Defined Error")); } diff --git a/src/source.rs b/src/source.rs index a36ca96..59c0b80 100644 --- a/src/source.rs +++ b/src/source.rs @@ -178,7 +178,7 @@ where &self, request: Request>, ) -> Result, Status> { - let mut sr = request.into_inner(); + let mut req_stream = request.into_inner(); // we have to call the handler over and over for each ReadRequest let handler_fn = Arc::clone(&self.handler); @@ -191,26 +191,8 @@ where let cln_token = self.cancellation_token.clone(); // do the handshake first to let the client know that we are ready to receive read requests. - let handshake_request = sr - .message() - .await - .map_err(|e| Status::internal(format!("handshake failed {}", e)))? - .ok_or_else(|| Status::internal("stream closed before handshake"))?; - - if let Some(handshake) = handshake_request.handshake { - grpc_tx - .send(Ok(ReadResponse { - result: None, - status: None, - handshake: Some(handshake), - })) - .await - .map_err(|e| { - Status::internal(format!("failed to send handshake response {}", e)) - })?; - } else { - return Err(Status::invalid_argument("Handshake not present")); - } + self.perform_read_handshake(&mut req_stream, &grpc_tx) + .await?; // this is the top-level stream consumer and this task will only exit when stream is closed (which // will happen when server and client are shutting down). @@ -219,7 +201,7 @@ where tokio::select! { // for each ReadRequest message, the handler will be called and a batch of messages // will be sent over to the client. - read_request = sr.message() => { + read_request = req_stream.message() => { let read_request = read_request .map_err(|e| SourceError(ErrorKind::InternalError(e.to_string())))? .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; @@ -277,27 +259,9 @@ where let handler_fn = Arc::clone(&self.handler); // do the handshake first to let the client know that we are ready to receive ack requests. - let handshake_request = ack_stream - .message() - .await - .map_err(|e| Status::internal(format!("handshake failed {}", e)))? - .ok_or_else(|| Status::internal("stream closed before handshake"))?; + self.perform_ack_handshake(&mut ack_stream, &ack_tx).await?; let ack_resp_tx = ack_tx.clone(); - if let Some(handshake) = handshake_request.handshake { - ack_resp_tx - .send(Ok(AckResponse { - result: None, - handshake: Some(handshake), - })) - .await - .map_err(|e| { - Status::internal(format!("failed to send handshake response {}", e)) - })?; - } else { - return Err(Status::invalid_argument("Handshake not present")); - } - let cln_token = self.cancellation_token.clone(); let grpc_read_handle: JoinHandle> = tokio::spawn(async move { loop { @@ -312,10 +276,10 @@ where .ok_or_else(|| SourceError(ErrorKind::InternalError("Stream closed".to_string())))?; let request = ack_request.request - .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, request is empty".to_string())))?; + .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, request can't be empty".to_string())))?; let offset = request.offset - .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, offset is empty".to_string())))?; + .ok_or_else(|| SourceError(ErrorKind::InternalError("Invalid request, offset can't be empty".to_string())))?; handler_fn .ack(Offset { @@ -390,6 +354,68 @@ where } } +impl SourceService +where + T: Sourcer + Send + Sync + 'static, +{ + // performs the read handshake with the client + async fn perform_read_handshake( + &self, + read_stream: &mut Streaming, + resp_tx: &Sender>, + ) -> Result<(), Status> { + let handshake_request = read_stream + .message() + .await + .map_err(|e| Status::internal(format!("read handshake failed {}", e)))? + .ok_or_else(|| Status::internal("read stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + resp_tx + .send(Ok(ReadResponse { + result: None, + status: None, + handshake: Some(handshake), + })) + .await + .map_err(|e| { + Status::internal(format!("failed to send read handshake response {}", e)) + })?; + Ok(()) + } else { + Err(Status::invalid_argument("Read handshake not present")) + } + } + + // performs the ack handshake with the client + async fn perform_ack_handshake( + &self, + ack_stream: &mut Streaming, + resp_tx: &Sender>, + ) -> Result<(), Status> { + let handshake_request = ack_stream + .message() + .await + .map_err(|e| Status::internal(format!("ack handshake failed {}", e)))? + .ok_or_else(|| Status::internal("ack stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + resp_tx + .send(Ok(AckResponse { + result: None, + handshake: Some(handshake), + })) + .await + .map_err(|e| { + Status::internal(format!("failed to send ack handshake response {}", e)) + })?; + Ok(()) + } else { + Err(Status::invalid_argument("Ack handshake not present")) + } + } +} + /// Message is the response from the user's [`Sourcer::read`] pub struct Message { /// The value passed to the next vertex. From d998ab2f82a074a7f7fab7123fb9c695f4b8843b Mon Sep 17 00:00:00 2001 From: Vigith Maurice Date: Thu, 26 Sep 2024 19:31:29 -0700 Subject: [PATCH 2/2] fix: invariant on global stream ending is upheld (#92) Signed-off-by: Vigith Maurice --- src/sink.rs | 24 ++++++++++++++++++------ 1 file changed, 18 insertions(+), 6 deletions(-) diff --git a/src/sink.rs b/src/sink.rs index 43ff530..160a374 100644 --- a/src/sink.rs +++ b/src/sink.rs @@ -231,14 +231,19 @@ where mut sink_stream: Streaming, grpc_resp_tx: mpsc::Sender>, ) -> Result<(), Error> { + // loop until the global stream has been shutdown. loop { - let done = Self::process_sink_batch( + // for every batch, we need to read from the stream. The end-of-batch is + // encoded in the request. + let stream_ended = Self::process_sink_batch( sink_handle.clone(), &mut sink_stream, grpc_resp_tx.clone(), ) .await?; - if done { + + if stream_ended { + // shutting down, hence exiting the loop break; } } @@ -246,7 +251,8 @@ where } /// processes a batch of messages from the client, sends them to the sink handler and sends the - /// responses back to the client batches are separated by an EOT message + /// responses back to the client batches are separated by an EOT message. + /// Returns true if the global bidi-stream has ended, otherwise false. async fn process_sink_batch( sink_handle: Arc, sink_stream: &mut Streaming, @@ -270,12 +276,18 @@ where } }); + let mut global_stream_ended = false; + + // loop until eot happens on stream is closed. loop { let message = match sink_stream.message().await { Ok(Some(m)) => m, Ok(None) => { info!("global bidi stream ended"); - return Ok(true); // bidi stream ended + // NOTE: this will only happen during shutdown. We can be certain that there + // are no messages left hanging in the UDF. + global_stream_ended = true; + break; // bidi stream ended } Err(e) => { return Err(SinkError(InternalError(format!( @@ -315,7 +327,7 @@ where .await .map_err(|e| SinkError(UserDefinedError(e.to_string())))?; - Ok(false) + Ok(global_stream_ended) } /// handles errors from the sink handler and sends them to the client via the response channel @@ -354,7 +366,7 @@ where } } - // performs handshake with the client + /// performs handshake with the client async fn perform_handshake( &self, sink_stream: &mut Streaming,