Skip to content

Commit

Permalink
Bidirectional streaming for source transformer
Browse files Browse the repository at this point in the history
Signed-off-by: Sreekanth <[email protected]>
  • Loading branch information
BulkBeing committed Sep 25, 2024
1 parent 362f2b0 commit d98dd6d
Show file tree
Hide file tree
Showing 2 changed files with 198 additions and 69 deletions.
31 changes: 25 additions & 6 deletions proto/sourcetransform.proto
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,38 @@ service SourceTransform {
// SourceTransformFn applies a function to each request element.
// In addition to the map function, SourceTransformFn also supports assigning a new event time to the response.
// SourceTransformFn can be used only at source vertex by source data transformer.
rpc SourceTransformFn(SourceTransformRequest) returns (SourceTransformResponse);
rpc SourceTransformFn(stream SourceTransformRequest) returns (stream SourceTransformResponse);

// IsReady is the heartbeat endpoint for gRPC.
rpc IsReady(google.protobuf.Empty) returns (ReadyResponse);
}

/**
 * Handshake is exchanged between client and server to indicate the start of transmission.
 * It is carried in the optional `handshake` field of both SourceTransformRequest and
 * SourceTransformResponse; the server expects it in the first message of the
 * SourceTransformFn stream and replies with a Handshake before processing requests.
 */
message Handshake {
// Required field indicating the start of transmission ("sot").
bool sot = 1;
}

/**
* SourceTransformRequest represents a request element.
*/
message SourceTransformRequest {
repeated string keys = 1;
bytes value = 2;
google.protobuf.Timestamp event_time = 3;
google.protobuf.Timestamp watermark = 4;
map<string, string> headers = 5;
message Request {
repeated string keys = 1;
bytes value = 2;
google.protobuf.Timestamp event_time = 3;
google.protobuf.Timestamp watermark = 4;
map<string, string> headers = 5;
// This ID is used to uniquely identify a transform request
string id = 6;
}
Request request = 1;
optional Handshake handshake = 2;
}


/**
* SourceTransformResponse represents a response element.
*/
Expand All @@ -37,6 +52,10 @@ message SourceTransformResponse {
repeated string tags = 4;
}
repeated Result results = 1;
// This ID is used to match a response to the request it corresponds to.
string id = 2;
// Handshake message between client and server to indicate the start of transmission.
optional Handshake handshake = 3;
}

/**
Expand Down
236 changes: 173 additions & 63 deletions src/sourcetransform.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,23 @@
use crate::error::Error::SourceTransformerError;
use crate::error::ErrorKind::UserDefinedError;
use crate::shared::{self, prost_timestamp_from_utc};
use crate::error::Error::{self, SourceTransformerError};
use crate::error::ErrorKind;
use crate::shared::{self, prost_timestamp_from_utc, utc_from_timestamp};
use chrono::{DateTime, Utc};
use proto::SourceTransformResponse;
use std::collections::HashMap;
use std::fs;
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tokio_stream::wrappers::ReceiverStream;
use tokio_util::sync::CancellationToken;
use tonic::{async_trait, Request, Response, Status};
use tonic::{async_trait, Request, Response, Status, Streaming};
use tracing::info;

const DEFAULT_MAX_MESSAGE_SIZE: usize = 64 * 1024 * 1024;
const DEFAULT_SOCK_ADDR: &str = "/var/run/numaflow/sourcetransform.sock";
const DEFAULT_SERVER_INFO_FILE: &str = "/var/run/numaflow/sourcetransformer-server-info";
const DEFAULT_CHANNEL_SIZE: usize = 1000;

const DROP: &str = "U+005C__DROP__";

Expand Down Expand Up @@ -218,54 +223,100 @@ impl From<Message> for proto::source_transform_response::Result {
}
}

impl From<proto::SourceTransformRequest> for SourceTransformRequest {
fn from(value: proto::SourceTransformRequest) -> Self {
Self {
keys: value.keys,
value: value.value,
watermark: shared::utc_from_timestamp(value.watermark),
eventtime: shared::utc_from_timestamp(value.event_time),
headers: value.headers,
}
}
}

#[async_trait]
impl<T> proto::source_transform_server::SourceTransform for SourceTransformerService<T>
where
T: SourceTransformer + Send + Sync + 'static,
{
type SourceTransformFnStream = ReceiverStream<Result<SourceTransformResponse, Status>>;

async fn source_transform_fn(
&self,
request: Request<proto::SourceTransformRequest>,
) -> Result<Response<proto::SourceTransformResponse>, Status> {
let request = request.into_inner();
request: Request<Streaming<proto::SourceTransformRequest>>,
) -> Result<Response<Self::SourceTransformFnStream>, Status> {
let mut stream = request.into_inner();
let handler = Arc::clone(&self.handler);
let handle = tokio::spawn(async move { handler.transform(request.into()).await });

// tx/rx pair for the gRPC response stream: responses are sent on tx, and rx is returned to the client as the response stream.
let (tx, rx) =
mpsc::channel::<Result<SourceTransformResponse, Status>>(DEFAULT_CHANNEL_SIZE);

// Do the handshake first to let the client know that we are ready to receive transform requests.
let handshake_request = stream
.message()
.await
.map_err(|e| Status::internal(format!("handshake failed {}", e)))?
.ok_or_else(|| Status::internal("stream closed before handshake"))?;

if let Some(handshake) = handshake_request.handshake {
tx.send(Ok(SourceTransformResponse {
results: vec![],
id: "".to_string(),
handshake: Some(handshake),
}))
.await
.map_err(|e| Status::internal(format!("failed to send handshake response {}", e)))?;
} else {
return Err(Status::invalid_argument("Handshake not present"));
}

let shutdown_tx = self.shutdown_tx.clone();
let cancellation_token = self.cancellation_token.clone();

// Wait for the handler to finish processing the request. If the server is shutting down(token will be cancelled),
// then return an error.
tokio::select! {
result = handle => {
match result {
Ok(messages) => Ok(Response::new(proto::SourceTransformResponse {
results: messages.into_iter().map(|msg| msg.into()).collect(),
})),
Err(e) => {
tracing::error!("Error in source transform handler: {:?}", e);
// Send a shutdown signal to the server to do a graceful shutdown because there was
// a panic in the handler.
shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server");
Err(Status::internal(SourceTransformerError(UserDefinedError(e.to_string())).to_string()))
let handle: JoinHandle<Result<(), Error>> = tokio::spawn(async move {
loop {
tokio::select! {
transform_request = stream.message() => {
let transform_request = transform_request.map_err(|e| SourceTransformerError(ErrorKind::InternalError(e.to_string())))?
.ok_or_else(||SourceTransformerError(ErrorKind::InternalError("Stream closed".to_string())))?;

let Some(request) = transform_request.request else {
return Err(SourceTransformerError(ErrorKind::InternalError("Transform request can not be none".to_string())));
};

let message_id = request.id.clone();
let handler_input = SourceTransformRequest{
keys: request.keys,
value: request.value,
watermark: utc_from_timestamp(request.watermark),
eventtime: utc_from_timestamp(request.event_time),
headers: request.headers
};

let handler = handler.clone();
// let messages = handler.transform(handler_input).await;
let udf_tranform_task = tokio::spawn(async move { handler.transform(handler_input).await });
let messages = tokio::select! {
result = udf_tranform_task => {
match result {
Ok(messages) => messages,
Err(e) => {
tracing::error!("Failed to run transform function: {e:?}");
// Send a shutdown signal to the server to do a graceful shutdown because there was
// a panic in the handler.
shutdown_tx.send(()).await.expect("Sending shutdown signal to gRPC server");
return Err(SourceTransformerError(ErrorKind::UserDefinedError("panic in transform UDF".to_string())));
}
}
}
};
tx.send(Ok(SourceTransformResponse{
results: messages.into_iter().map(|msg| msg.into()).collect(),
id: message_id,
handshake: None,
})).await.expect("sending messages to the client over gRPC channel");

}
_ = cancellation_token.cancelled() => {
info!("Cancellation token is cancelled, shutting down");
break;
}
}
},
_ = cancellation_token.cancelled() => {
Err(Status::internal(SourceTransformerError(UserDefinedError("Server is shutting down".to_string())).to_string()))
},
}
}
Ok(())
});

Ok(Response::new(ReceiverStream::new(rx)))
}

async fn is_ready(&self, _: Request<()>) -> Result<Response<proto::ReadyResponse>, Status> {
Expand Down Expand Up @@ -390,11 +441,13 @@ mod tests {

use tempfile::TempDir;
use tokio::net::UnixStream;
use tokio::sync::oneshot;
use tokio::sync::{mpsc, oneshot};
use tokio_stream::wrappers::ReceiverStream;
use tonic::transport::Uri;
use tower::service_fn;

use crate::sourcetransform;
use crate::sourcetransform::proto;
use crate::sourcetransform::proto::source_transform_client::SourceTransformClient;

#[tokio::test]
Expand Down Expand Up @@ -447,21 +500,59 @@ mod tests {
.await?;

let mut client = SourceTransformClient::new(channel);
let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest {
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
});

let resp = client.source_transform_fn(request).await?;
let resp = resp.into_inner();
let (tx, rx) = mpsc::channel(2);

let handshake_request = proto::SourceTransformRequest {
request: None,
handshake: Some(proto::Handshake { sot: true }),
};
tx.send(handshake_request).await.unwrap();

let mut stream = tokio::time::timeout(
Duration::from_secs(2),
client.source_transform_fn(ReceiverStream::new(rx)),
)
.await
.map_err(|_| "timeout while getting stream for source_transform_fn")??
.into_inner();

let handshake_resp = stream.message().await?.unwrap();
assert!(
handshake_resp.results.is_empty(),
"The handshake response should not contain any messages"
);
assert!(
handshake_resp.id.is_empty(),
"The message id of the handshake response should be empty"
);
assert!(
handshake_resp.handshake.is_some(),
"Not a valid response for handshake request"
);

let request = sourcetransform::proto::SourceTransformRequest {
request: Some(proto::source_transform_request::Request {
id: "1".to_string(),
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
}),
handshake: None,
};

tx.send(request).await.unwrap();

let resp = stream.message().await?.unwrap();
assert_eq!(resp.results.len(), 1, "Expected single message from server");
let msg = &resp.results[0];
assert_eq!(msg.keys.first(), Some(&"first".to_owned()));
assert_eq!(msg.value, "hello".as_bytes());

drop(tx);

shutdown_tx
.send(())
.expect("Sending shutdown signal to gRPC server");
Expand Down Expand Up @@ -515,21 +606,40 @@ mod tests {
.await?;

let mut client = SourceTransformClient::new(channel);
let request = tonic::Request::new(sourcetransform::proto::SourceTransformRequest {
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
});

let resp = client.source_transform_fn(request).await;
assert!(resp.is_err(), "Expected error from server");

if let Err(e) = resp {
assert_eq!(e.code(), tonic::Code::Internal);
assert!(e.message().contains("User Defined Error"));
}
let (tx, rx) = mpsc::channel(2);
let handshake_request = proto::SourceTransformRequest {
request: None,
handshake: Some(proto::Handshake { sot: true }),
};
tx.send(handshake_request).await.unwrap();

let mut stream = tokio::time::timeout(
Duration::from_secs(2),
client.source_transform_fn(ReceiverStream::new(rx)),
)
.await
.map_err(|_| "timeout while getting stream for source_transform_fn")??
.into_inner();

let handshake_resp = stream.message().await?.unwrap();
assert!(
handshake_resp.handshake.is_some(),
"Not a valid response for handshake request"
);

let request = proto::SourceTransformRequest {
request: Some(proto::source_transform_request::Request {
id: "1".to_string(),
keys: vec!["first".into(), "second".into()],
value: "hello".into(),
watermark: Some(prost_types::Timestamp::default()),
event_time: Some(prost_types::Timestamp::default()),
headers: Default::default(),
}),
handshake: None,
};
tx.send(request).await.unwrap();

// server should shut down gracefully because there was a panic in the handler.
for _ in 0..10 {
Expand Down

0 comments on commit d98dd6d

Please sign in to comment.