Skip to content

Commit

Permalink
fix unit tests
Browse files Browse the repository at this point in the history
Signed-off-by: Yashash H L <[email protected]>
  • Loading branch information
yhl25 committed Feb 19, 2025
1 parent 95d85c6 commit 1606bec
Show file tree
Hide file tree
Showing 9 changed files with 117 additions and 110 deletions.
17 changes: 1 addition & 16 deletions rust/numaflow-core/src/config/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -733,22 +733,7 @@ mod tests {
transformer_config: None,
}),
metrics_config: Default::default(),
watermark_config: Some(WatermarkConfig::Source(SourceWatermarkConfig {
max_delay: Default::default(),
source_bucket_config: BucketConfig {
vertex: "in",
partitions: 1,
ot_bucket: "default-simple-pipeline-in_SOURCE_OT",
hb_bucket: "default-simple-pipeline-in_SOURCE_PROCESSORS",
},
to_vertex_bucket_config: vec![BucketConfig {
vertex: "out",
partitions: 1,
ot_bucket: "default-simple-pipeline-in-out_OT",
hb_bucket: "default-simple-pipeline-in-out_PROCESSORS",
}],
idle_config: None,
})),
watermark_config: None,
..Default::default()
};

Expand Down
99 changes: 53 additions & 46 deletions rust/numaflow-core/src/mapper/map.rs
Original file line number Diff line number Diff line change
Expand Up @@ -247,7 +247,7 @@ impl MapHandle {
Some(error) = error_rx.recv() => {
// when we get an error we cancel the token to signal the upstream to stop
// sending new messages, and we empty the input stream and return the error.
if !self.final_result.is_err() {
if self.final_result.is_ok() {
error!(?error, "error received while performing unary map operation");
cln_token.cancel();
self.final_result = Err(error);
Expand Down Expand Up @@ -325,7 +325,7 @@ impl MapHandle {
Some(error) = error_rx.recv() => {
// when we get an error we cancel the token to signal the upstream to stop
// sending new messages, and we empty the input stream and return the error.
if !self.final_result.is_err() {
if self.final_result.is_ok() {
error!(?error, "error received while performing stream map operation");
cln_token.cancel();
self.final_result = Err(error);
Expand Down Expand Up @@ -577,7 +577,6 @@ impl MapHandle {
output_tx.send(mapped_message).await.expect("failed to send response");
}
Some(Err(e)) => {
error!(?e, "failed to map message");
tracker_handle
.discard(read_msg.offset)
.await
Expand Down Expand Up @@ -614,17 +613,17 @@ impl MapHandle {
mod tests {
use std::time::Duration;

use numaflow::{batchmap, map, mapstream};
use numaflow_pb::clients::map::map_client::MapClient;
use tempfile::TempDir;
use tokio::sync::{mpsc::Sender, oneshot};

use super::*;
use crate::{
message::{MessageID, Offset, StringOffset},
shared::grpc::create_rpc_channel,
Result,
};
use numaflow::{batchmap, map, mapstream};
use numaflow_pb::clients::map::map_client::MapClient;
use tempfile::TempDir;
use tokio::sync::{mpsc::Sender, oneshot};
use tokio::time::sleep;

struct SimpleMapper;

Expand Down Expand Up @@ -857,29 +856,33 @@ mod tests {
let (input_tx, input_rx) = mpsc::channel(10);
let input_stream = ReceiverStream::new(input_rx);

let message = Message {
typ: Default::default(),
keys: Arc::from(vec!["first".into()]),
tags: None,
value: "hello".into(),
offset: Offset::String(StringOffset::new("0".to_string(), 0)),
event_time: chrono::Utc::now(),
watermark: None,
id: MessageID {
vertex_name: "vertex_name".to_string().into(),
offset: "0".to_string().into(),
index: 0,
},
headers: Default::default(),
metadata: None,
};

input_tx.send(message).await.unwrap();

let (_output_stream, map_handle) = mapper
.streaming_map(input_stream, CancellationToken::new())
.await?;

// send 10 requests to the mapper
for i in 0..10 {
let message = Message {
typ: Default::default(),
keys: Arc::from(vec![format!("key_{}", i)]),
tags: None,
value: format!("value_{}", i).into(),
offset: Offset::String(StringOffset::new(i.to_string(), 0)),
event_time: chrono::Utc::now(),
watermark: None,
id: MessageID {
vertex_name: "vertex_name".to_string().into(),
offset: i.to_string().into(),
index: i,
},
headers: Default::default(),
metadata: None,
};
input_tx.send(message).await.unwrap();
sleep(Duration::from_millis(10)).await;
}

drop(input_tx);
// Await the join handle and expect an error due to the panic
let result = map_handle.await.unwrap();
assert!(result.is_err(), "Expected an error due to panic");
Expand Down Expand Up @@ -1277,32 +1280,36 @@ mod tests {
)
.await?;

let message = Message {
typ: Default::default(),
keys: Arc::from(vec!["first".into()]),
tags: None,
value: "panic".into(),
offset: Offset::String(StringOffset::new("0".to_string(), 0)),
event_time: chrono::Utc::now(),
watermark: None,
id: MessageID {
vertex_name: "vertex_name".to_string().into(),
offset: "0".to_string().into(),
index: 0,
},
headers: Default::default(),
metadata: None,
};

let (input_tx, input_rx) = mpsc::channel(10);
let input_stream = ReceiverStream::new(input_rx);

input_tx.send(message).await.unwrap();

let (_output_stream, map_handle) = mapper
.streaming_map(input_stream, CancellationToken::new())
.await?;

// send 10 requests to the mapper
for i in 0..10 {
let message = Message {
typ: Default::default(),
keys: Arc::from(vec![format!("key_{}", i)]),
tags: None,
value: format!("value_{}", i).into(),
offset: Offset::String(StringOffset::new(i.to_string(), 0)),
event_time: chrono::Utc::now(),
watermark: None,
id: MessageID {
vertex_name: "vertex_name".to_string().into(),
offset: i.to_string().into(),
index: i,
},
headers: Default::default(),
metadata: None,
};
input_tx.send(message).await.unwrap();
sleep(Duration::from_millis(10)).await;
}

drop(input_tx);
// Await the join handle and expect an error due to the panic
let result = map_handle.await.unwrap();
assert!(result.is_err(), "Expected an error due to panic");
Expand Down
17 changes: 9 additions & 8 deletions rust/numaflow-core/src/mapper/map/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -274,16 +274,17 @@ async fn create_response_stream(
let mut resp_stream = client
.map_fn(Request::new(ReceiverStream::new(read_rx)))
.await
.map_err(|e| Error::Grpc(e))?
.map_err(Error::Grpc)?
.into_inner();

let handshake_response = resp_stream
.message()
.await
.map_err(|e| Error::Grpc(e))?
.ok_or(Error::Mapper(
"failed to receive handshake response".to_string(),
))?;
let handshake_response =
resp_stream
.message()
.await
.map_err(Error::Grpc)?
.ok_or(Error::Mapper(
"failed to receive handshake response".to_string(),
))?;

if handshake_response.handshake.map_or(true, |h| !h.sot) {
return Err(Error::Mapper("invalid handshake response".to_string()));
Expand Down
11 changes: 9 additions & 2 deletions rust/numaflow-core/src/pipeline/isb/jetstream/reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -457,12 +457,13 @@ mod tests {
streams: vec![],
wip_ack_interval: Duration::from_millis(5),
};
let tracker = TrackerHandle::new(None, None);
let js_reader = JetStreamReader::new(
"Map".to_string(),
stream.clone(),
context.clone(),
buf_reader_config,
TrackerHandle::new(None, None),
tracker.clone(),
500,
None,
)
Expand All @@ -475,13 +476,16 @@ mod tests {
.await
.unwrap();

let mut offsets = vec![];
for i in 0..10 {
let offset = Offset::Int(IntOffset::new(i + 1, 0));
offsets.push(offset.clone());
let message = Message {
typ: Default::default(),
keys: Arc::from(vec![format!("key_{}", i)]),
tags: None,
value: format!("message {}", i).as_bytes().to_vec().into(),
offset: Offset::Int(IntOffset::new(i, 0)),
offset,
event_time: Utc::now(),
watermark: None,
id: MessageID {
Expand Down Expand Up @@ -513,6 +517,9 @@ mod tests {
"Expected 10 messages from the jetstream reader"
);

for offset in offsets {
tracker.discard(offset).await.unwrap();
}
reader_cancel_token.cancel();
js_reader_task.await.unwrap().unwrap();

Expand Down
19 changes: 10 additions & 9 deletions rust/numaflow-core/src/sink/user_defined.rs
Original file line number Diff line number Diff line change
Expand Up @@ -57,18 +57,19 @@ impl UserDefinedSink {
let mut resp_stream = client
.sink_fn(Request::new(sink_stream))
.await
.map_err(|e| Error::Grpc(e))?
.map_err(Error::Grpc)?
.into_inner();

// First response from the server will be the handshake response. We need to check if the
// server has accepted the handshake.
let handshake_response = resp_stream
.message()
.await
.map_err(|e| Error::Grpc(e))?
.ok_or(Error::Sink(
"failed to receive handshake response".to_string(),
))?;
let handshake_response =
resp_stream
.message()
.await
.map_err(Error::Grpc)?
.ok_or(Error::Sink(
"failed to receive handshake response".to_string(),
))?;

// Handshake cannot be None during the initial phase, and it has to set `sot` to true.
if handshake_response.handshake.map_or(true, |h| !h.sot) {
Expand Down Expand Up @@ -118,7 +119,7 @@ impl Sink for UserDefinedSink {
.resp_stream
.message()
.await
.map_err(|e| Error::Grpc(e))?
.map_err(Error::Grpc)?
.ok_or(Error::Sink("failed to receive response".to_string()))?;

if response.status.is_some_and(|s| s.eot) {
Expand Down
13 changes: 11 additions & 2 deletions rust/numaflow-core/src/source.rs
Original file line number Diff line number Diff line change
Expand Up @@ -371,6 +371,7 @@ impl Source {
for message in messages.iter() {
let (resp_ack_tx, resp_ack_rx) = oneshot::channel();
let offset = message.offset.clone();
println!("offset: {:?}", offset);

// insert the offset and the ack one shot in the tracker.
self.tracker_handle.insert(message, resp_ack_tx).await?;
Expand Down Expand Up @@ -660,10 +661,11 @@ mod tests {
.map_err(|e| panic!("failed to create source reader: {:?}", e))
.unwrap();

let tracker = TrackerHandle::new(None, None);
let source = Source::new(
5,
SourceType::UserDefinedSource(src_read, src_ack, lag_reader),
TrackerHandle::new(None, None),
tracker.clone(),
true,
None,
None,
Expand All @@ -683,7 +685,11 @@ mod tests {
}

// ack all the messages
Source::ack(sender.clone(), offsets).await.unwrap();
Source::ack(sender.clone(), offsets.clone()).await.unwrap();

for offset in offsets {
tracker.discard(offset).await.unwrap();
}

// since we acked all the messages, pending should be 0
let pending = source.pending().await.unwrap();
Expand All @@ -692,6 +698,9 @@ mod tests {
let partitions = Source::partitions(sender.clone()).await.unwrap();
assert_eq!(partitions, vec![1, 2]);

drop(source);
drop(sender);

cln_token.cancel();
let _ = handle.await.unwrap();
let _ = shutdown_tx.send(());
Expand Down
Loading

0 comments on commit 1606bec

Please sign in to comment.