@@ -11,7 +11,6 @@ use crate::errors::{error_message, error_type, RequestErrorType};
11
11
use crate :: response;
12
12
use crate :: response:: Response ;
13
13
use bytes:: Bytes ;
14
- use directories:: BaseDirs ;
15
14
use logger_core:: { log_debug, log_error, log_info, log_trace, log_warn} ;
16
15
use once_cell:: sync:: Lazy ;
17
16
use protobuf:: { Chars , Message } ;
@@ -22,6 +21,7 @@ use redis::cluster_routing::{ResponsePolicy, Routable};
22
21
use redis:: { ClusterScanArgs , Cmd , PushInfo , RedisError , ScanStateRC , Value } ;
23
22
use std:: cell:: Cell ;
24
23
use std:: collections:: HashSet ;
24
+ use std:: path:: PathBuf ;
25
25
use std:: ptr:: from_mut;
26
26
use std:: rc:: Rc ;
27
27
use std:: sync:: RwLock ;
@@ -39,7 +39,9 @@ use ClosingReason::*;
39
39
use PipeListeningResult :: * ;
40
40
41
41
/// The socket file name
42
- const SOCKET_FILE_NAME : & str = "glide-socket" ;
42
+ const SOCKET_FILE_NAME : & str = "glide-socket.soc" ;
43
+ /// The socket folder
44
+ const SOCKET_FOLDER : & str = "glide" ;
43
45
44
46
/// The maximum length of a request's arguments to be passed as a vector of
45
47
/// strings instead of a pointer
@@ -569,7 +571,7 @@ async fn handle_requests(
569
571
570
572
pub fn close_socket ( socket_path : & String ) {
571
573
log_info ( "close_socket" , format ! ( "closing socket at {socket_path}" ) ) ;
572
- let _ = std :: fs :: remove_file ( socket_path ) ;
574
+ remove_socket_dir ( ) ;
573
575
}
574
576
575
577
async fn create_client (
@@ -783,29 +785,39 @@ struct ClosingError {
783
785
/// Get the socket full path.
784
786
/// The socket file name will contain the process ID and will try to be saved into the user's runtime directory
785
787
/// (e.g. /run/user/1000) in Unix systems. If the runtime dir isn't found, the socket file will be saved to the temp dir.
786
- /// For Windows, the socket file will be saved to %AppData%\Local.
787
788
pub fn get_socket_path_from_name ( socket_name : String ) -> String {
788
- let base_dirs = BaseDirs :: new ( ) . expect ( "Failed to create BaseDirs" ) ;
789
- let tmp_dir;
790
- let folder = if cfg ! ( windows) {
791
- base_dirs. data_local_dir ( )
792
- } else {
793
- base_dirs. runtime_dir ( ) . unwrap_or ( {
794
- tmp_dir = env:: temp_dir ( ) ;
795
- tmp_dir. as_path ( )
796
- } )
797
- } ;
798
- folder
789
+ get_socket_dir ( )
799
790
. join ( socket_name)
800
791
. into_os_string ( )
801
792
. into_string ( )
802
- . expect ( "Couldn't create socket path" )
793
+ . expect ( "Failed to create socket path from name" )
794
+ }
795
+
796
+ /// Get the socket directory path.
797
+ fn get_socket_dir ( ) -> PathBuf {
798
+ // Use XDG_RUNTIME_DIR if available, else fallback to temp directory
799
+ if let Ok ( runtime_dir) = env:: var ( "XDG_RUNTIME_DIR" ) {
800
+ PathBuf :: from ( runtime_dir)
801
+ . join ( SOCKET_FOLDER )
802
+ . join ( std:: process:: id ( ) . to_string ( ) )
803
+ } else {
804
+ env:: temp_dir ( )
805
+ . join ( SOCKET_FOLDER )
806
+ . join ( std:: process:: id ( ) . to_string ( ) )
807
+ }
808
+ }
809
+
810
+ /// Remove socket dir of the process
811
+ pub fn remove_socket_dir ( ) {
812
+ let socket_dir = get_socket_dir ( ) ;
813
+ if socket_dir. exists ( ) {
814
+ let _ = std:: fs:: remove_dir_all ( socket_dir) ;
815
+ }
803
816
}
804
817
805
818
/// Get the socket path as a string
806
819
pub fn get_socket_path ( ) -> String {
807
- let socket_name = format ! ( "{}-{}" , SOCKET_FILE_NAME , std:: process:: id( ) ) ;
808
- get_socket_path_from_name ( socket_name)
820
+ get_socket_path_from_name ( SOCKET_FILE_NAME . to_string ( ) )
809
821
}
810
822
811
823
/// This function is exposed only for the sake of testing with a nonstandard `socket_path`.
@@ -820,7 +832,10 @@ pub fn start_socket_listener_internal<InitCallback>(
820
832
static INITIALIZED_SOCKETS : Lazy < RwLock < HashSet < String > > > =
821
833
Lazy :: new ( || RwLock :: new ( HashSet :: new ( ) ) ) ;
822
834
823
- let socket_path = socket_path. unwrap_or_else ( get_socket_path) ;
835
+ let socket_path = match socket_path {
836
+ Some ( path) => path,
837
+ None => get_socket_path ( ) ,
838
+ } ;
824
839
825
840
{
826
841
// Optimize for already initialized
@@ -841,7 +856,6 @@ pub fn start_socket_listener_internal<InitCallback>(
841
856
init_callback ( Ok ( socket_path. clone ( ) ) ) ;
842
857
return ;
843
858
}
844
-
845
859
let ( tx, rx) = std:: sync:: mpsc:: channel ( ) ;
846
860
let socket_path_cloned = socket_path. clone ( ) ;
847
861
let init_callback_cloned = init_callback. clone ( ) ;
@@ -860,8 +874,16 @@ pub fn start_socket_listener_internal<InitCallback>(
860
874
}
861
875
Ok ( runtime) => runtime,
862
876
} ;
863
-
864
877
runtime. block_on ( async move {
878
+ // Clean up any leftover from previous runs and create socket dir
879
+ remove_socket_dir ( ) ;
880
+ if let Err ( err) = std:: fs:: create_dir_all ( get_socket_dir ( ) ) {
881
+ log_error (
882
+ "listen_on_socket" ,
883
+ format ! ( "Failed to create socket directory: {err}" ) ,
884
+ ) ;
885
+ }
886
+
865
887
let listener_socket = match UnixListener :: bind ( socket_path_cloned. clone ( ) ) {
866
888
Err ( err) => {
867
889
log_error (
@@ -874,8 +896,6 @@ pub fn start_socket_listener_internal<InitCallback>(
874
896
} ;
875
897
876
898
// Signal initialization is successful.
877
- // IMPORTANT:
878
- // tx.send() must be called before init_callback_cloned() to ensure runtimes, such as Python, can properly complete the main function
879
899
let _ = tx. send ( true ) ;
880
900
init_callback_cloned ( Ok ( socket_path_cloned. clone ( ) ) ) ;
881
901
@@ -900,7 +920,9 @@ pub fn start_socket_listener_internal<InitCallback>(
900
920
drop ( listener_socket) ;
901
921
let _ = std:: fs:: remove_file ( socket_path_cloned. clone ( ) ) ;
902
922
903
- // no more listening on socket - update the sockets db
923
+ // Clean the entire process-id socket directory on close
924
+ remove_socket_dir ( ) ;
925
+
904
926
let mut sockets_write_guard = INITIALIZED_SOCKETS
905
927
. write ( )
906
928
. expect ( "Failed to acquire sockets db write guard" ) ;
@@ -917,7 +939,6 @@ pub fn start_socket_listener_internal<InitCallback>(
917
939
} )
918
940
. expect ( "Thread spawn failed. Cannot report error because callback was moved." ) ;
919
941
920
- // wait for thread initialization signaling, callback invocation is done in the thread
921
942
let _ = rx. recv ( ) . map ( |res| {
922
943
if res {
923
944
sockets_write_guard. insert ( socket_path) ;
0 commit comments