Skip to content

Commit 542dc78

Browse files
committed
Refactor: Optimize client request handling to use two async tasks per client
Reduced the number of async tasks per client from task-per-request to two dedicated tasks: 1. `client_reader_loop`: Reads and parses requests from the socket, forwarding them for processing. 2. `request_processor_loop`: Manages a `FuturesUnordered` queue, ensuring efficient request execution. Signed-off-by: barshaul <[email protected]>
1 parent 7d79d46 commit 542dc78

File tree

1 file changed

+196
-110
lines changed

1 file changed

+196
-110
lines changed

glide-core/src/socket_listener.rs

+196-110
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ use crate::response;
1212
use crate::response::Response;
1313
use bytes::Bytes;
1414
use directories::BaseDirs;
15+
use futures::{future::poll_fn, stream::FuturesUnordered, StreamExt};
1516
use logger_core::{log_debug, log_error, log_info, log_trace, log_warn};
1617
use once_cell::sync::Lazy;
1718
use protobuf::{Chars, Message};
@@ -25,13 +26,13 @@ use std::collections::HashSet;
2526
use std::ptr::from_mut;
2627
use std::rc::Rc;
2728
use std::sync::RwLock;
28-
use std::{env, str};
29+
use std::{env, fmt, str};
2930
use std::{io, thread};
3031
use thiserror::Error;
3132
use tokio::net::{UnixListener, UnixStream};
3233
use tokio::runtime::Builder;
3334
use tokio::sync::mpsc;
34-
use tokio::sync::mpsc::{channel, Sender};
35+
use tokio::sync::mpsc::{channel, Receiver, Sender};
3536
use tokio::sync::Mutex;
3637
use tokio::task;
3738
use tokio_util::task::LocalPoolHandle;
@@ -459,112 +460,96 @@ fn get_route(
459460
}
460461
}
461462

462-
fn handle_request(request: CommandRequest, mut client: Client, writer: Rc<Writer>) {
463-
task::spawn_local(async move {
464-
let mut updated_inflight_counter = true;
465-
let client_clone = client.clone();
463+
async fn handle_request(request: CommandRequest, mut client: Client, writer: Rc<Writer>) {
464+
let mut updated_inflight_counter = true;
465+
let client_clone = client.clone();
466466

467-
let result = match client.reserve_inflight_request() {
468-
false => {
469-
updated_inflight_counter = false;
470-
Err(ClientUsageError::User(
471-
"Reached maximum inflight requests".to_string(),
472-
))
473-
}
474-
true => match request.command {
475-
Some(action) => match action {
476-
command_request::Command::ClusterScan(cluster_scan_command) => {
477-
cluster_scan(cluster_scan_command, client).await
478-
}
479-
command_request::Command::SingleCommand(command) => {
480-
match get_redis_command(&command) {
481-
Ok(cmd) => match get_route(request.route.0, Some(&cmd)) {
482-
Ok(routes) => send_command(cmd, client, routes).await,
483-
Err(e) => Err(e),
484-
},
467+
let result = match client.reserve_inflight_request() {
468+
false => {
469+
updated_inflight_counter = false;
470+
Err(ClientUsageError::User(
471+
"Reached maximum inflight requests".to_string(),
472+
))
473+
}
474+
true => match request.command {
475+
Some(action) => match action {
476+
command_request::Command::ClusterScan(cluster_scan_command) => {
477+
cluster_scan(cluster_scan_command, client).await
478+
}
479+
command_request::Command::SingleCommand(command) => {
480+
match get_redis_command(&command) {
481+
Ok(cmd) => match get_route(request.route.0, Some(&cmd)) {
482+
Ok(routes) => send_command(cmd, client, routes).await,
485483
Err(e) => Err(e),
486-
}
484+
},
485+
Err(e) => Err(e),
487486
}
488-
command_request::Command::Transaction(transaction) => {
489-
match get_route(request.route.0, None) {
490-
Ok(routes) => send_transaction(transaction, &mut client, routes).await,
491-
Err(e) => Err(e),
492-
}
487+
}
488+
command_request::Command::Transaction(transaction) => {
489+
match get_route(request.route.0, None) {
490+
Ok(routes) => send_transaction(transaction, &mut client, routes).await,
491+
Err(e) => Err(e),
493492
}
494-
command_request::Command::ScriptInvocation(script) => {
495-
match get_route(request.route.0, None) {
496-
Ok(routes) => {
497-
invoke_script(
498-
script.hash,
499-
Some(script.keys),
500-
Some(script.args),
501-
client,
502-
routes,
503-
)
504-
.await
505-
}
506-
Err(e) => Err(e),
493+
}
494+
command_request::Command::ScriptInvocation(script) => {
495+
match get_route(request.route.0, None) {
496+
Ok(routes) => {
497+
invoke_script(
498+
script.hash,
499+
Some(script.keys),
500+
Some(script.args),
501+
client,
502+
routes,
503+
)
504+
.await
507505
}
506+
Err(e) => Err(e),
508507
}
509-
command_request::Command::ScriptInvocationPointers(script) => {
510-
let keys = script
511-
.keys_pointer
512-
.map(|pointer| *unsafe { Box::from_raw(pointer as *mut Vec<Bytes>) });
513-
let args = script
514-
.args_pointer
515-
.map(|pointer| *unsafe { Box::from_raw(pointer as *mut Vec<Bytes>) });
516-
match get_route(request.route.0, None) {
517-
Ok(routes) => {
518-
invoke_script(script.hash, keys, args, client, routes).await
519-
}
520-
Err(e) => Err(e),
521-
}
508+
}
509+
command_request::Command::ScriptInvocationPointers(script) => {
510+
let keys = script
511+
.keys_pointer
512+
.map(|pointer| *unsafe { Box::from_raw(pointer as *mut Vec<Bytes>) });
513+
let args = script
514+
.args_pointer
515+
.map(|pointer| *unsafe { Box::from_raw(pointer as *mut Vec<Bytes>) });
516+
match get_route(request.route.0, None) {
517+
Ok(routes) => invoke_script(script.hash, keys, args, client, routes).await,
518+
Err(e) => Err(e),
522519
}
523-
command_request::Command::UpdateConnectionPassword(
524-
update_connection_password_command,
525-
) => client
526-
.update_connection_password(
527-
update_connection_password_command
528-
.password
529-
.map(|chars| chars.to_string()),
530-
update_connection_password_command.immediate_auth,
531-
)
532-
.await
533-
.map_err(|err| err.into()),
534-
},
535-
None => {
536-
log_debug(
537-
"received error",
538-
format!(
539-
"Received empty request for callback {}",
540-
request.callback_idx
541-
),
542-
);
543-
Err(ClientUsageError::Internal(
544-
"Received empty request".to_string(),
545-
))
546520
}
521+
command_request::Command::UpdateConnectionPassword(
522+
update_connection_password_command,
523+
) => client
524+
.update_connection_password(
525+
update_connection_password_command
526+
.password
527+
.map(|chars| chars.to_string()),
528+
update_connection_password_command.immediate_auth,
529+
)
530+
.await
531+
.map_err(|err| err.into()),
547532
},
548-
};
549-
550-
if updated_inflight_counter {
551-
client_clone.release_inflight_request();
552-
}
553-
554-
let _res = write_result(result, request.callback_idx, &writer).await;
555-
});
556-
}
533+
None => {
534+
log_debug(
535+
"received error",
536+
format!(
537+
"Received empty request for callback {}",
538+
request.callback_idx
539+
),
540+
);
541+
Err(ClientUsageError::Internal(
542+
"Received empty request".to_string(),
543+
))
544+
}
545+
},
546+
};
557547

558-
async fn handle_requests(
559-
received_requests: Vec<CommandRequest>,
560-
client: &Client,
561-
writer: &Rc<Writer>,
562-
) {
563-
for request in received_requests {
564-
handle_request(request, client.clone(), writer.clone());
548+
if updated_inflight_counter {
549+
client_clone.release_inflight_request();
565550
}
566-
// Yield to ensure that the subtasks aren't starved.
567-
task::yield_now().await;
551+
552+
let _res = write_result(result, request.callback_idx, &writer).await;
568553
}
569554

570555
pub fn close_socket(socket_path: &String) {
@@ -605,18 +590,25 @@ async fn wait_for_connection_configuration_and_create_client(
605590
}
606591
}
607592

608-
async fn read_values_loop(
593+
/// Listens for new requests on the socket, parses them, and forwards them to the request processor.
594+
///
595+
/// # Arguments:
596+
/// - `client_listener`: The client's socket listener responsible for receiving incoming requests.
597+
/// - `processor_channel`: A sender channel used to forward the parsed requests for processing.
598+
async fn client_reader_loop(
609599
mut client_listener: UnixStreamListener,
610-
client: &Client,
611-
writer: Rc<Writer>,
600+
processor_channel: Sender<Vec<CommandRequest>>,
612601
) -> ClosingReason {
613602
loop {
614603
match client_listener.next_values().await {
615604
Closed(reason) => {
616605
return reason;
617606
}
618607
ReceivedValues(received_requests) => {
619-
handle_requests(received_requests, client, &writer).await;
608+
if let Err(_err) = processor_channel.send(received_requests).await {
609+
// Failed to send requests because the processor task was unexpectedly closed
610+
return ClosingReason::ClientRequestProcessorClosed;
611+
}
620612
}
621613
}
622614
}
@@ -651,6 +643,43 @@ async fn push_manager_loop(mut push_rx: mpsc::UnboundedReceiver<PushInfo>, write
651643
}
652644
}
653645

646+
// Process all incoming requests received from the socket for this client. This task would be responsible for two things:
647+
// 1. Listening on the channel for new requests and pushing them into the futures queue
648+
// 2. Processing the futures queue by polling the queue to let the futures progress and removing completed futures
649+
// This task will be closed when the channel to send requests through will be closed or when aborted by the caller.
650+
async fn request_processor_loop(
651+
client: Client,
652+
writer: Rc<Writer>,
653+
mut requests_channel: Receiver<Vec<CommandRequest>>,
654+
) {
655+
let mut futures_queue = FuturesUnordered::new();
656+
loop {
657+
tokio::select! {
658+
// Handle new incoming requests from the channel
659+
Some(requests) = requests_channel.recv() => {
660+
requests
661+
.into_iter()
662+
.map(|request| handle_request(request, client.clone(), writer.clone()))
663+
.for_each(|future| futures_queue.push(future));
664+
}
665+
666+
// Poll the futures queue and process the next completed future.
667+
Some(_) = poll_fn(|cx| futures_queue.poll_next_unpin(cx)) => {},
668+
669+
// Exit the loop if both the channel and queue are empty
670+
else => {
671+
if futures_queue.is_empty() {
672+
log_debug(
673+
"request processor",
674+
"Client channel is closed, and no more tasks are pending. Shutting down the request processor."
675+
);
676+
break;
677+
}
678+
}
679+
}
680+
}
681+
}
682+
654683
async fn listen_on_client_stream(socket: UnixStream) {
655684
let socket = Rc::new(socket);
656685
// Spawn a new task to listen on this client's stream
@@ -707,24 +736,48 @@ async fn listen_on_client_stream(socket: UnixStream) {
707736
}
708737
};
709738
log_info("connection", "new connection started");
739+
// Each client has two dedicated tasks:
740+
// 1. `client_reader_loop`: Listens on the socket, receives new requests, parses them,
741+
// and forwards them to the second task.
742+
// 2. `request_processor_loop`: Manages a futures queue by receiving requests from the channel,
743+
// adding them to the queue, and continuously polling them until completion.
744+
let (requests_sender, requests_receiver) = mpsc::channel::<Vec<CommandRequest>>(100);
745+
let cloned_writer = writer.clone();
746+
let request_processor = task::spawn_local(request_processor_loop(
747+
client.clone(),
748+
cloned_writer,
749+
requests_receiver,
750+
));
710751
tokio::select! {
711-
reader_closing = read_values_loop(client_listener, &client, writer.clone()) => {
712-
if let ClosingReason::UnhandledError(err) = reader_closing {
713-
let _res = write_closing_error(ClosingError{err_message: err.to_string()}, u32::MAX, &writer, "client closing").await;
752+
reader_closing = client_reader_loop(client_listener, requests_sender) => {
753+
// Write the closing reason back to the wrapper
754+
if reader_closing.is_error() {
755+
if let Err(err) = write_closing_error(
756+
ClosingError{err_message: reader_closing.to_string()},
757+
u32::MAX,
758+
&writer,
759+
"client closing"
760+
).await {
761+
log_warn(
762+
"client closing",
763+
format!("Failed to write the closing error: {err:?}")
764+
);
765+
}
714766
};
715767
log_trace("client closing", "reader closed");
716768
},
717769
writer_closing = receiver.recv() => {
718-
if let Some(ClosingReason::UnhandledError(err)) = writer_closing {
719-
log_error("client closing", format!("Writer closed with error: {err}"));
720-
} else {
721-
log_trace("client closing", "writer closed");
722-
}
723-
},
770+
match writer_closing {
771+
Some(closing_reason) if closing_reason.is_error() => {
772+
log_error("client closing", format!("Writer closed with error: {closing_reason}"));
773+
},
774+
_ => log_trace("client closing", "writer closed")
775+
}},
724776
_ = push_manager_loop(push_rx, writer.clone()) => {
725777
log_trace("client closing", "push manager closed");
726778
}
727779
}
780+
request_processor.abort();
728781
log_trace("client closing", "closing connection");
729782
}
730783

@@ -733,10 +786,43 @@ async fn listen_on_client_stream(socket: UnixStream) {
733786
pub enum ClosingReason {
734787
/// The socket was closed. This is the expected way that the listener should be closed.
735788
ReadSocketClosed,
789+
/// The client's request processor task was unexpectedly closed.
790+
ClientRequestProcessorClosed,
736791
/// The listener encounter an error it couldn't handle.
737792
UnhandledError(RedisError),
738793
}
739794

795+
impl ClosingReason {
796+
/// Returns `true` if the closing reason was due to an error, otherwise `false`.
797+
pub(crate) fn is_error(&self) -> bool {
798+
match self {
799+
ClosingReason::ReadSocketClosed => false, // Expected closure, not an error
800+
ClosingReason::ClientRequestProcessorClosed => true, // Unexpected closure, treated as an error
801+
ClosingReason::UnhandledError(_) => true, // Error encountered
802+
}
803+
}
804+
}
805+
806+
// Implement Display for ClosingReason
807+
impl fmt::Display for ClosingReason {
808+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
809+
match self {
810+
ClosingReason::ReadSocketClosed => {
811+
write!(f, "The socket was closed")
812+
}
813+
ClosingReason::ClientRequestProcessorClosed => {
814+
write!(
815+
f,
816+
"The client's request processor has been unexpectedly closed."
817+
)
818+
}
819+
ClosingReason::UnhandledError(err) => {
820+
write!(f, "Unhandled error encountered: {}", err)
821+
}
822+
}
823+
}
824+
}
825+
740826
impl From<io::Error> for ClosingReason {
741827
fn from(error: io::Error) -> Self {
742828
UnhandledError(error.into())

0 commit comments

Comments
 (0)