diff --git a/rust/numaflow-core/src/shared/create_components.rs b/rust/numaflow-core/src/shared/create_components.rs
index ecffe43fef..94a4357609 100644
--- a/rust/numaflow-core/src/shared/create_components.rs
+++ b/rust/numaflow-core/src/shared/create_components.rs
@@ -360,12 +360,14 @@ pub async fn create_source(
                 *get_vertex_replica(),
             )
             .await?;
+            // for serving we use batch size as one as we are not batching the messages
+            // and read ahead is enabled as it supports it.
             Ok((
                 Source::new(
-                    batch_size,
+                    1,
                     source::SourceType::Serving(serving),
                     tracker_handle,
-                    source_config.read_ahead,
+                    true,
                     transformer,
                     watermark_handle,
                 ),
diff --git a/rust/numaflow-core/src/source.rs b/rust/numaflow-core/src/source.rs
index fe1c421046..f5423333bc 100644
--- a/rust/numaflow-core/src/source.rs
+++ b/rust/numaflow-core/src/source.rs
@@ -4,6 +4,7 @@
 //! [Source]: https://numaflow.numaproj.io/user-guide/sources/overview/
 //! [Watermark]: https://numaflow.numaproj.io/core-concepts/watermarks/
 
+use std::cmp::max;
 use std::sync::Arc;
 
 use tracing::warn;
@@ -55,6 +56,8 @@ use serving::ServingSource;
 use crate::transformer::Transformer;
 use crate::watermark::source::SourceWatermarkHandle;
 
+const MAX_ACK_PENDING: usize = 10000;
+
 /// Set of Read related items that has to be implemented to become a Source.
 pub(crate) trait SourceReader {
     #[allow(dead_code)]
@@ -321,8 +324,15 @@ impl Source {
         info!(?self.read_batch_size, "Started streaming source with batch size");
         let handle = tokio::spawn(async move {
-            // this semaphore is used only if read-ahead is disabled. we hold this semaphore to
-            // make sure we can read only if the current inflight ones are ack'ed.
-            let semaphore = Arc::new(Semaphore::new(1));
+            // this semaphore caps the number of in-flight ack tasks. without read-ahead there is
+            // one permit, so we read only after the current inflight ones are ack'ed. with read
+            // ahead there are up to (MAX_ACK_PENDING / read_batch_size) permits, since acking is
+            // done per batch; max() keeps at least one permit when the batch size is very large.
+            let max_ack_tasks = match &self.read_ahead {
+                true => max(1, MAX_ACK_PENDING / self.read_batch_size),
+                false => 1,
+            };
+            let semaphore = Arc::new(Semaphore::new(max_ack_tasks));
+
             let mut result = Ok(());
             loop {
                 if cln_token.is_cancelled() {
@@ -330,14 +340,12 @@ impl Source {
                     break;
                 }
 
-                if !self.read_ahead {
-                    // Acquire the semaphore permit before reading the next batch to make
-                    // sure we are not reading ahead and all the inflight messages are acked.
-                    let _permit = Arc::clone(&semaphore)
-                        .acquire_owned()
-                        .await
-                        .expect("acquiring permit should not fail");
-                }
+                // Acquire a semaphore permit before reading the next batch so that the number
+                // of batches awaiting ack never exceeds max_ack_tasks (1 without read-ahead).
+                let _permit = Arc::clone(&semaphore)
+                    .acquire_owned()
+                    .await
+                    .expect("acquiring permit should not fail");
 
                 let read_start_time = Instant::now();
                 let messages = match Self::read(self.sender.clone()).await {
@@ -371,7 +379,6 @@ impl Source {
                 for message in messages.iter() {
                     let (resp_ack_tx, resp_ack_rx) = oneshot::channel();
                     let offset = message.offset.clone();
-                    println!("offset: {:?}", offset);
 
                     // insert the offset and the ack one shot in the tracker.
                     self.tracker_handle.insert(message, resp_ack_tx).await?;
@@ -436,8 +443,11 @@ impl Source {
                 }
             }
             info!(status=?result, "Source stopped, waiting for inflight messages to be acked");
+            // wait for all the ack tasks to be completed before stopping the source, since we give
+            // a permit for each ack task all the permits should be released when the ack tasks are
+            // done, we can verify this by trying to acquire the permit for max_ack_tasks.
             let _permit = Arc::clone(&semaphore)
-                .acquire_owned()
+                .acquire_many_owned(max_ack_tasks as u32)
                 .await
                 .expect("acquiring permit should not fail");
             info!("All inflight messages are acked. Source stopped.");
diff --git a/rust/serving/src/config.rs b/rust/serving/src/config.rs
index 828270540a..64484feef9 100644
--- a/rust/serving/src/config.rs
+++ b/rust/serving/src/config.rs
@@ -23,7 +23,7 @@ pub const DEFAULT_ID_HEADER: &str = "X-Numaflow-Id";
 pub const DEFAULT_CALLBACK_URL_HEADER_KEY: &str = "X-Numaflow-Callback-Url";
 pub const DEFAULT_REDIS_TTL_IN_SECS: u32 = 86400;
 
-pub fn generate_certs() -> std::result::Result<(Certificate, KeyPair), String> {
+pub fn generate_certs() -> Result<(Certificate, KeyPair), String> {
     let CertifiedKey { cert, key_pair } = generate_simple_self_signed(vec!["localhost".into()])
         .map_err(|e| format!("Failed to generate cert {:?}", e))?;
     Ok((cert, key_pair))
@@ -74,7 +74,7 @@ impl Default for Settings {
             app_listen_port: 3000,
             metrics_server_listen_port: 3001,
             upstream_addr: "localhost:8888".to_owned(),
-            drain_timeout_secs: 10,
+            drain_timeout_secs: 600,
             redis: RedisConfig::default(),
             host_ip: "127.0.0.1".to_owned(),
             api_auth_token: None,