@@ -12,6 +12,7 @@ use crate::response;
12
12
use crate :: response:: Response ;
13
13
use bytes:: Bytes ;
14
14
use directories:: BaseDirs ;
15
+ use futures:: { future:: poll_fn, stream:: FuturesUnordered , StreamExt } ;
15
16
use logger_core:: { log_debug, log_error, log_info, log_trace, log_warn} ;
16
17
use once_cell:: sync:: Lazy ;
17
18
use protobuf:: { Chars , Message } ;
@@ -25,13 +26,13 @@ use std::collections::HashSet;
25
26
use std:: ptr:: from_mut;
26
27
use std:: rc:: Rc ;
27
28
use std:: sync:: RwLock ;
28
- use std:: { env, str} ;
29
+ use std:: { env, fmt , str} ;
29
30
use std:: { io, thread} ;
30
31
use thiserror:: Error ;
31
32
use tokio:: net:: { UnixListener , UnixStream } ;
32
33
use tokio:: runtime:: Builder ;
33
34
use tokio:: sync:: mpsc;
34
- use tokio:: sync:: mpsc:: { channel, Sender } ;
35
+ use tokio:: sync:: mpsc:: { channel, Receiver , Sender } ;
35
36
use tokio:: sync:: Mutex ;
36
37
use tokio:: task;
37
38
use tokio_util:: task:: LocalPoolHandle ;
@@ -459,112 +460,96 @@ fn get_route(
459
460
}
460
461
}
461
462
462
- fn handle_request ( request : CommandRequest , mut client : Client , writer : Rc < Writer > ) {
463
- task:: spawn_local ( async move {
464
- let mut updated_inflight_counter = true ;
465
- let client_clone = client. clone ( ) ;
463
+ async fn handle_request ( request : CommandRequest , mut client : Client , writer : Rc < Writer > ) {
464
+ let mut updated_inflight_counter = true ;
465
+ let client_clone = client. clone ( ) ;
466
466
467
- let result = match client. reserve_inflight_request ( ) {
468
- false => {
469
- updated_inflight_counter = false ;
470
- Err ( ClientUsageError :: User (
471
- "Reached maximum inflight requests" . to_string ( ) ,
472
- ) )
473
- }
474
- true => match request. command {
475
- Some ( action) => match action {
476
- command_request:: Command :: ClusterScan ( cluster_scan_command) => {
477
- cluster_scan ( cluster_scan_command, client) . await
478
- }
479
- command_request:: Command :: SingleCommand ( command) => {
480
- match get_redis_command ( & command) {
481
- Ok ( cmd) => match get_route ( request. route . 0 , Some ( & cmd) ) {
482
- Ok ( routes) => send_command ( cmd, client, routes) . await ,
483
- Err ( e) => Err ( e) ,
484
- } ,
467
+ let result = match client. reserve_inflight_request ( ) {
468
+ false => {
469
+ updated_inflight_counter = false ;
470
+ Err ( ClientUsageError :: User (
471
+ "Reached maximum inflight requests" . to_string ( ) ,
472
+ ) )
473
+ }
474
+ true => match request. command {
475
+ Some ( action) => match action {
476
+ command_request:: Command :: ClusterScan ( cluster_scan_command) => {
477
+ cluster_scan ( cluster_scan_command, client) . await
478
+ }
479
+ command_request:: Command :: SingleCommand ( command) => {
480
+ match get_redis_command ( & command) {
481
+ Ok ( cmd) => match get_route ( request. route . 0 , Some ( & cmd) ) {
482
+ Ok ( routes) => send_command ( cmd, client, routes) . await ,
485
483
Err ( e) => Err ( e) ,
486
- }
484
+ } ,
485
+ Err ( e) => Err ( e) ,
487
486
}
488
- command_request :: Command :: Transaction ( transaction ) => {
489
- match get_route ( request . route . 0 , None ) {
490
- Ok ( routes ) => send_transaction ( transaction , & mut client , routes ) . await ,
491
- Err ( e ) => Err ( e ) ,
492
- }
487
+ }
488
+ command_request :: Command :: Transaction ( transaction ) => {
489
+ match get_route ( request . route . 0 , None ) {
490
+ Ok ( routes ) => send_transaction ( transaction , & mut client , routes ) . await ,
491
+ Err ( e ) => Err ( e ) ,
493
492
}
494
- command_request:: Command :: ScriptInvocation ( script) => {
495
- match get_route ( request. route . 0 , None ) {
496
- Ok ( routes) => {
497
- invoke_script (
498
- script. hash ,
499
- Some ( script. keys ) ,
500
- Some ( script. args ) ,
501
- client,
502
- routes,
503
- )
504
- . await
505
- }
506
- Err ( e) => Err ( e) ,
493
+ }
494
+ command_request:: Command :: ScriptInvocation ( script) => {
495
+ match get_route ( request. route . 0 , None ) {
496
+ Ok ( routes) => {
497
+ invoke_script (
498
+ script. hash ,
499
+ Some ( script. keys ) ,
500
+ Some ( script. args ) ,
501
+ client,
502
+ routes,
503
+ )
504
+ . await
507
505
}
506
+ Err ( e) => Err ( e) ,
508
507
}
509
- command_request:: Command :: ScriptInvocationPointers ( script) => {
510
- let keys = script
511
- . keys_pointer
512
- . map ( |pointer| * unsafe { Box :: from_raw ( pointer as * mut Vec < Bytes > ) } ) ;
513
- let args = script
514
- . args_pointer
515
- . map ( |pointer| * unsafe { Box :: from_raw ( pointer as * mut Vec < Bytes > ) } ) ;
516
- match get_route ( request. route . 0 , None ) {
517
- Ok ( routes) => {
518
- invoke_script ( script. hash , keys, args, client, routes) . await
519
- }
520
- Err ( e) => Err ( e) ,
521
- }
508
+ }
509
+ command_request:: Command :: ScriptInvocationPointers ( script) => {
510
+ let keys = script
511
+ . keys_pointer
512
+ . map ( |pointer| * unsafe { Box :: from_raw ( pointer as * mut Vec < Bytes > ) } ) ;
513
+ let args = script
514
+ . args_pointer
515
+ . map ( |pointer| * unsafe { Box :: from_raw ( pointer as * mut Vec < Bytes > ) } ) ;
516
+ match get_route ( request. route . 0 , None ) {
517
+ Ok ( routes) => invoke_script ( script. hash , keys, args, client, routes) . await ,
518
+ Err ( e) => Err ( e) ,
522
519
}
523
- command_request:: Command :: UpdateConnectionPassword (
524
- update_connection_password_command,
525
- ) => client
526
- . update_connection_password (
527
- update_connection_password_command
528
- . password
529
- . map ( |chars| chars. to_string ( ) ) ,
530
- update_connection_password_command. immediate_auth ,
531
- )
532
- . await
533
- . map_err ( |err| err. into ( ) ) ,
534
- } ,
535
- None => {
536
- log_debug (
537
- "received error" ,
538
- format ! (
539
- "Received empty request for callback {}" ,
540
- request. callback_idx
541
- ) ,
542
- ) ;
543
- Err ( ClientUsageError :: Internal (
544
- "Received empty request" . to_string ( ) ,
545
- ) )
546
520
}
521
+ command_request:: Command :: UpdateConnectionPassword (
522
+ update_connection_password_command,
523
+ ) => client
524
+ . update_connection_password (
525
+ update_connection_password_command
526
+ . password
527
+ . map ( |chars| chars. to_string ( ) ) ,
528
+ update_connection_password_command. immediate_auth ,
529
+ )
530
+ . await
531
+ . map_err ( |err| err. into ( ) ) ,
547
532
} ,
548
- } ;
549
-
550
- if updated_inflight_counter {
551
- client_clone. release_inflight_request ( ) ;
552
- }
553
-
554
- let _res = write_result ( result, request. callback_idx , & writer) . await ;
555
- } ) ;
556
- }
533
+ None => {
534
+ log_debug (
535
+ "received error" ,
536
+ format ! (
537
+ "Received empty request for callback {}" ,
538
+ request. callback_idx
539
+ ) ,
540
+ ) ;
541
+ Err ( ClientUsageError :: Internal (
542
+ "Received empty request" . to_string ( ) ,
543
+ ) )
544
+ }
545
+ } ,
546
+ } ;
557
547
558
- async fn handle_requests (
559
- received_requests : Vec < CommandRequest > ,
560
- client : & Client ,
561
- writer : & Rc < Writer > ,
562
- ) {
563
- for request in received_requests {
564
- handle_request ( request, client. clone ( ) , writer. clone ( ) ) ;
548
+ if updated_inflight_counter {
549
+ client_clone. release_inflight_request ( ) ;
565
550
}
566
- // Yield to ensure that the subtasks aren't starved.
567
- task :: yield_now ( ) . await ;
551
+
552
+ let _res = write_result ( result , request . callback_idx , & writer ) . await ;
568
553
}
569
554
570
555
pub fn close_socket ( socket_path : & String ) {
@@ -605,18 +590,25 @@ async fn wait_for_connection_configuration_and_create_client(
605
590
}
606
591
}
607
592
608
- async fn read_values_loop (
593
+ /// Listens for new requests on the socket, parses them, and forwards them to the request processor.
594
+ ///
595
+ /// # Arguments:
596
+ /// - `client_listener`: The client's socket listener responsible for receiving incoming requests.
597
+ /// - `processor_channel`: A sender channel used to forward the parsed requests for processing.
598
+ async fn client_reader_loop (
609
599
mut client_listener : UnixStreamListener ,
610
- client : & Client ,
611
- writer : Rc < Writer > ,
600
+ processor_channel : Sender < Vec < CommandRequest > > ,
612
601
) -> ClosingReason {
613
602
loop {
614
603
match client_listener. next_values ( ) . await {
615
604
Closed ( reason) => {
616
605
return reason;
617
606
}
618
607
ReceivedValues ( received_requests) => {
619
- handle_requests ( received_requests, client, & writer) . await ;
608
+ if let Err ( _err) = processor_channel. send ( received_requests) . await {
609
+ // Failed to send requests because the processor task was unexpectedly closed
610
+ return ClosingReason :: ClientRequestProcessorClosed ;
611
+ }
620
612
}
621
613
}
622
614
}
@@ -651,6 +643,43 @@ async fn push_manager_loop(mut push_rx: mpsc::UnboundedReceiver<PushInfo>, write
651
643
}
652
644
}
653
645
646
+ // Process all incoming requests received from the socket for this client. This task would be responsible for two things:
647
+ // 1. Listening on the channel for new requests and pushing them into the futures queue
648
+ // 2. Processing the futures queue by polling the queue to let the futures progress and removing completed futures
649
+ // This task will be closed when the channel to send requests through will be closed or when aborted by the caller.
650
+ async fn request_processor_loop (
651
+ client : Client ,
652
+ writer : Rc < Writer > ,
653
+ mut requests_channel : Receiver < Vec < CommandRequest > > ,
654
+ ) {
655
+ let mut futures_queue = FuturesUnordered :: new ( ) ;
656
+ loop {
657
+ tokio:: select! {
658
+ // Handle new incoming requests from the channel
659
+ Some ( requests) = requests_channel. recv( ) => {
660
+ requests
661
+ . into_iter( )
662
+ . map( |request| handle_request( request, client. clone( ) , writer. clone( ) ) )
663
+ . for_each( |future| futures_queue. push( future) ) ;
664
+ }
665
+
666
+ // Poll the futures queue and process the next completed future.
667
+ Some ( _) = poll_fn( |cx| futures_queue. poll_next_unpin( cx) ) => { } ,
668
+
669
+ // Exit the loop if both the channel and queue are empty
670
+ else => {
671
+ if futures_queue. is_empty( ) {
672
+ log_debug(
673
+ "request processor" ,
674
+ "Client channel is closed, and no more tasks are pending. Shutting down the request processor."
675
+ ) ;
676
+ break ;
677
+ }
678
+ }
679
+ }
680
+ }
681
+ }
682
+
654
683
async fn listen_on_client_stream ( socket : UnixStream ) {
655
684
let socket = Rc :: new ( socket) ;
656
685
// Spawn a new task to listen on this client's stream
@@ -707,24 +736,48 @@ async fn listen_on_client_stream(socket: UnixStream) {
707
736
}
708
737
} ;
709
738
log_info ( "connection" , "new connection started" ) ;
739
+ // Each client has two dedicated tasks:
740
+ // 1. `client_reader_loop`: Listens on the socket, receives new requests, parses them,
741
+ // and forwards them to the second task.
742
+ // 2. `request_processor_loop`: Manages a futures queue by receiving requests from the channel,
743
+ // adding them to the queue, and continuously polling them until completion.
744
+ let ( requests_sender, requests_receiver) = mpsc:: channel :: < Vec < CommandRequest > > ( 100 ) ;
745
+ let cloned_writer = writer. clone ( ) ;
746
+ let request_processor = task:: spawn_local ( request_processor_loop (
747
+ client. clone ( ) ,
748
+ cloned_writer,
749
+ requests_receiver,
750
+ ) ) ;
710
751
tokio:: select! {
711
- reader_closing = read_values_loop( client_listener, & client, writer. clone( ) ) => {
712
- if let ClosingReason :: UnhandledError ( err) = reader_closing {
713
- let _res = write_closing_error( ClosingError { err_message: err. to_string( ) } , u32 :: MAX , & writer, "client closing" ) . await ;
752
+ reader_closing = client_reader_loop( client_listener, requests_sender) => {
753
+ // Write the closing reason back to the wrapper
754
+ if reader_closing. is_error( ) {
755
+ if let Err ( err) = write_closing_error(
756
+ ClosingError { err_message: reader_closing. to_string( ) } ,
757
+ u32 :: MAX ,
758
+ & writer,
759
+ "client closing"
760
+ ) . await {
761
+ log_warn(
762
+ "client closing" ,
763
+ format!( "Failed to write the closing error: {err:?}" )
764
+ ) ;
765
+ }
714
766
} ;
715
767
log_trace( "client closing" , "reader closed" ) ;
716
768
} ,
717
769
writer_closing = receiver. recv( ) => {
718
- if let Some ( ClosingReason :: UnhandledError ( err ) ) = writer_closing {
719
- log_error ( "client closing" , format! ( "Writer closed with error: {err}" ) ) ;
720
- } else {
721
- log_trace ( "client closing" , "writer closed" ) ;
722
- }
723
- } ,
770
+ match writer_closing {
771
+ Some ( closing_reason ) if closing_reason . is_error ( ) => {
772
+ log_error ( "client closing" , format! ( "Writer closed with error: {closing_reason}" ) ) ;
773
+ } ,
774
+ _ => log_trace ( "client closing" , "writer closed" )
775
+ } } ,
724
776
_ = push_manager_loop( push_rx, writer. clone( ) ) => {
725
777
log_trace( "client closing" , "push manager closed" ) ;
726
778
}
727
779
}
780
+ request_processor. abort ( ) ;
728
781
log_trace ( "client closing" , "closing connection" ) ;
729
782
}
730
783
@@ -733,10 +786,43 @@ async fn listen_on_client_stream(socket: UnixStream) {
733
786
pub enum ClosingReason {
734
787
/// The socket was closed. This is the expected way that the listener should be closed.
735
788
ReadSocketClosed ,
789
+ /// The client's request processor task was unexpectedly closed.
790
+ ClientRequestProcessorClosed ,
736
791
/// The listener encounter an error it couldn't handle.
737
792
UnhandledError ( RedisError ) ,
738
793
}
739
794
795
+ impl ClosingReason {
796
+ /// Returns `true` if the closing reason was due to an error, otherwise `false`.
797
+ pub ( crate ) fn is_error ( & self ) -> bool {
798
+ match self {
799
+ ClosingReason :: ReadSocketClosed => false , // Expected closure, not an error
800
+ ClosingReason :: ClientRequestProcessorClosed => true , // Unexpected closure, treated as an error
801
+ ClosingReason :: UnhandledError ( _) => true , // Error encountered
802
+ }
803
+ }
804
+ }
805
+
806
+ // Implement Display for ClosingReason
807
+ impl fmt:: Display for ClosingReason {
808
+ fn fmt ( & self , f : & mut fmt:: Formatter < ' _ > ) -> fmt:: Result {
809
+ match self {
810
+ ClosingReason :: ReadSocketClosed => {
811
+ write ! ( f, "The socket was closed" )
812
+ }
813
+ ClosingReason :: ClientRequestProcessorClosed => {
814
+ write ! (
815
+ f,
816
+ "The client's request processor has been unexpectedly closed."
817
+ )
818
+ }
819
+ ClosingReason :: UnhandledError ( err) => {
820
+ write ! ( f, "Unhandled error encountered: {}" , err)
821
+ }
822
+ }
823
+ }
824
+ }
825
+
740
826
impl From < io:: Error > for ClosingReason {
741
827
fn from ( error : io:: Error ) -> Self {
742
828
UnhandledError ( error. into ( ) )
0 commit comments