From d98dd6dad4a2a0ebcf8a307b176e20c7f9eccf52 Mon Sep 17 00:00:00 2001 From: Sreekanth Date: Wed, 25 Sep 2024 20:10:30 +0530 Subject: [PATCH] Bidirectional streaming for source transformer Signed-off-by: Sreekanth --- proto/sourcetransform.proto | 31 ++++- src/sourcetransform.rs | 236 ++++++++++++++++++++++++++---------- 2 files changed, 198 insertions(+), 69 deletions(-) diff --git a/proto/sourcetransform.proto b/proto/sourcetransform.proto index 18e045c..90a2a64 100644 --- a/proto/sourcetransform.proto +++ b/proto/sourcetransform.proto @@ -9,23 +9,38 @@ service SourceTransform { // SourceTransformFn applies a function to each request element. // In addition to map function, SourceTransformFn also supports assigning a new event time to response. // SourceTransformFn can be used only at source vertex by source data transformer. - rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse); + rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse); // IsReady is the heartbeat endpoint for gRPC. rpc IsReady(google.protobuf.Empty) returns (ReadyResponse); } +/* + * Handshake message between client and server to indicate the start of transmission. + */ + message Handshake { + // Required field indicating the start of transmission. + bool sot = 1; +} + /** * SourceTransformerRequest represents a request element. */ message SourceTransformRequest { - repeated string keys = 1; - bytes value = 2; - google.protobuf.Timestamp event_time = 3; - google.protobuf.Timestamp watermark = 4; - map headers = 5; + message Request { + repeated string keys = 1; + bytes value = 2; + google.protobuf.Timestamp event_time = 3; + google.protobuf.Timestamp watermark = 4; + map headers = 5; + // This ID is used to uniquely identify a transform request + string id = 6; + } + Request request = 1; + optional Handshake handshake = 2; } + /** * SourceTransformerResponse represents a response element. */ @@ -37,6 +52,10 @@ message SourceTransformResponse { repeated string tags = 4; } repeated Result results = 1; + // This ID is used to refer the responses to the request it corresponds to. + string id = 2; + // Handshake message between client and server to indicate the start of transmission. + optional Handshake handshake = 3; } /** diff --git a/src/sourcetransform.rs b/src/sourcetransform.rs index 25f06c7..a78cc60 100644 --- a/src/sourcetransform.rs +++ b/src/sourcetransform.rs @@ -1,18 +1,23 @@ -use crate::error::Error::SourceTransformerError; -use crate::error::ErrorKind::UserDefinedError; -use crate::shared::{self, prost_timestamp_from_utc}; +use crate::error::Error::{self, SourceTransformerError}; +use crate::error::ErrorKind; +use crate::shared::{self, prost_timestamp_from_utc, utc_from_timestamp}; use chrono::{DateTime, Utc}; +use proto::SourceTransformResponse; use std::collections::HashMap; use std::fs; use std::path::PathBuf; use std::sync::Arc; use tokio::sync::{mpsc, oneshot}; +use tokio::task::JoinHandle; +use tokio_stream::wrappers::ReceiverStream; use tokio_util::sync::CancellationToken; -use tonic::{async_trait, Request, Response, Status}; +use tonic::{async_trait, Request, Response, Status, Streaming}; +use tracing::info; const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024; const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock"; const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info"; +const DEFAULT_CHANNEL_SIZE: usize = 1000; const DROP: &str = "U+005C__DROP__"; @@ -218,54 +223,100 @@ impl From for proto::source_transform_response::Result { } } -impl From for SourceTransformRequest { - fn from(value: proto::SourceTransformRequest) -> Self { - Self { - keys: value.keys, - value: value.value, - watermark: shared::utc_from_timestamp(value.watermark), - eventtime: shared::utc_from_timestamp(value.event_time), - headers: value.headers, - } - } -} - #[async_trait] impl proto::source_transform_server::SourceTransform for SourceTransformerService where T: SourceTransformer + Send + Sync + 'static, { + type SourceTransformFnStream = ReceiverStream>; + async fn source_transform_fn( &self, - request: Request, - ) -> Result, Status> { - let request = request.into_inner(); + request: Request>, + ) -> Result, Status> { + let mut stream = request.into_inner(); let handler = Arc::clone(&self.handler); - let handle = tokio::spawn(async move { handler.transform(request.into()).await }); + + // tx (read from client), rx (write to client) pair for gRPC response + let (tx, rx) = + mpsc::channel::>(DEFAULT_CHANNEL_SIZE); + + // do the handshake first to let the client know that we are ready to receive read requests. + let handshake_request = stream + .message() + .await + .map_err(|e| Status::internal(format!("handshake failed {}", e)))? + .ok_or_else(|| Status::internal("stream closed before handshake"))?; + + if let Some(handshake) = handshake_request.handshake { + tx.send(Ok(SourceTransformResponse { + results: vec![], + id: "".to_string(), + handshake: Some(handshake), + })) + .await + .map_err(|e| Status::internal(format!("failed to send handshake response {}", e)))?; + } else { + return Err(Status::invalid_argument("Handshake not present")); + } + let shutdown_tx = self.shutdown_tx.clone(); let cancellation_token = self.cancellation_token.clone(); - // Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled), - // then return an error. - tokio::select! { - result = handle => { - match result { - Ok(messages) => Ok(Response::new(proto::SourceTransformResponse { - results: messages.into_iter().map(|msg| msg.into()).collect(), - })), - Err(e) => { - tracing::error!("Error in source transform handler: {:?}", e); - // Send a shutdown signal to the server to do a graceful shutdown because there was - // a panic in the handler. - shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server"); - Err(Status::internal(SourceTransformerError(UserDefinedError(e.to_string())).to_string())) + let handle: JoinHandle> = tokio::spawn(async move { + loop { + tokio::select! { + transform_request = stream.message() => { + let transform_request = transform_request.map_err(|e| SourceTransformerError(ErrorKind::InternalError(e.to_string())))? + .ok_or_else(||SourceTransformerError(ErrorKind::InternalError("Stream closed".to_string())))?; + + let Some(request) = transform_request.request else { + return Err(SourceTransformerError(ErrorKind::InternalError("Transform request can not be none".to_string()))); + }; + + let message_id = request.id.clone(); + let handler_input = SourceTransformRequest{ + keys: request.keys, + value: request.value, + watermark: utc_from_timestamp(request.watermark), + eventtime: utc_from_timestamp(request.event_time), + headers: request.headers + }; + + let handler = handler.clone(); + // let messages = handler.transform(handler_input).await; + let udf_tranform_task = tokio::spawn(async move { handler.transform(handler_input).await }); + let messages = tokio::select! { + result = udf_tranform_task => { + match result { + Ok(messages) => messages, + Err(e) => { + tracing::error!("Failed to run transform function: {e:?}"); + // Send a shutdown signal to the server to do a graceful shutdown because there was + // a panic in the handler. + shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server"); + return Err(SourceTransformerError(ErrorKind::UserDefinedError("panic in transform UDF".to_string()))); + } + } + } + }; + tx.send(Ok(SourceTransformResponse{ + results: messages.into_iter().map(|msg| msg.into()).collect(), + id: message_id, + handshake: None, + })).await.expect("sending messages to the client over gRPC channel"); + + } + _ = cancellation_token.cancelled() => { + info!("Cancellation token is cancelled, shutting down"); + break; } } - }, - _ = cancellation_token.cancelled() => { - Err(Status::internal(SourceTransformerError(UserDefinedError("Server is shutting down".to_string())).to_string())) - }, - } + } + Ok(()) + }); + + Ok(Response::new(ReceiverStream::new(rx))) } async fn is_ready(&self, _: Request<()>) -> Result, Status> { @@ -390,11 +441,13 @@ mod tests { use tempfile::TempDir; use tokio::net::UnixStream; - use tokio::sync::oneshot; + use tokio::sync::{mpsc, oneshot}; + use tokio_stream::wrappers::ReceiverStream; use tonic::transport::Uri; use tower::service_fn; use crate::sourcetransform; + use crate::sourcetransform::proto; use crate::sourcetransform::proto::source_transform_client::SourceTransformClient; #[tokio::test] @@ -447,21 +500,59 @@ mod tests { .await?; let mut client = SourceTransformClient::new(channel); - let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - let resp = client.source_transform_fn(request).await?; - let resp = resp.into_inner(); + let (tx, rx) = mpsc::channel(2); + + let handshake_request = proto::SourceTransformRequest { + request: None, + handshake: Some(proto::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let mut stream = tokio::time::timeout( + Duration::from_secs(2), + client.source_transform_fn(ReceiverStream::new(rx)), + ) + .await + .map_err(|_| "timeout while getting stream for source_transform_fn")?? + .into_inner(); + + let handshake_resp = stream.message().await?.unwrap(); + assert!( + handshake_resp.results.is_empty(), + "The handshake response should not contain any messages" + ); + assert!( + handshake_resp.id.is_empty(), + "The message id of the handshake response should be empty" + ); + assert!( + handshake_resp.handshake.is_some(), + "Not a valid response for handshake request" + ); + + let request = sourcetransform::proto::SourceTransformRequest { + request: Some(proto::source_transform_request::Request { + id: "1".to_string(), + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }), + handshake: None, + }; + + tx.send(request).await.unwrap(); + + let resp = stream.message().await?.unwrap(); assert_eq!(resp.results.len(), 1, "Expected single message from server"); let msg = &resp.results[0]; assert_eq!(msg.keys.first(), Some(&"first".to_owned())); assert_eq!(msg.value, "hello".as_bytes()); + drop(tx); + shutdown_tx .send(()) .expect("Sending shutdown signal to gRPC server"); @@ -515,21 +606,40 @@ mod tests { .await?; let mut client = SourceTransformClient::new(channel); - let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest { - keys: vec!["first".into(), "second".into()], - value: "hello".into(), - watermark: Some(prost_types::Timestamp::default()), - event_time: Some(prost_types::Timestamp::default()), - headers: Default::default(), - }); - - let resp = client.source_transform_fn(request).await; - assert!(resp.is_err(), "Expected error from server"); - if let Err(e) = resp { - assert_eq!(e.code(), tonic::Code::Internal); - assert!(e.message().contains("User Defined Error")); - } + let (tx, rx) = mpsc::channel(2); + let handshake_request = proto::SourceTransformRequest { + request: None, + handshake: Some(proto::Handshake { sot: true }), + }; + tx.send(handshake_request).await.unwrap(); + + let mut stream = tokio::time::timeout( + Duration::from_secs(2), + client.source_transform_fn(ReceiverStream::new(rx)), + ) + .await + .map_err(|_| "timeout while getting stream for source_transform_fn")?? + .into_inner(); + + let handshake_resp = stream.message().await?.unwrap(); + assert!( + handshake_resp.handshake.is_some(), + "Not a valid response for handshake request" + ); + + let request = proto::SourceTransformRequest { + request: Some(proto::source_transform_request::Request { + id: "1".to_string(), + keys: vec!["first".into(), "second".into()], + value: "hello".into(), + watermark: Some(prost_types::Timestamp::default()), + event_time: Some(prost_types::Timestamp::default()), + headers: Default::default(), + }), + handshake: None, + }; + tx.send(request).await.unwrap(); // server should shut down gracefully because there was a panic in the handler. for _ in 0..10 {