@@ -893,6 +893,188 @@ fn tls_client_config() -> Result<Arc<ClientConfig>, &'static io::Error> {
893
893
Ok ( CONFIG . as_ref ( ) ?. clone ( ) )
894
894
}
895
895
896
+ #[ traced_test]
897
+ #[ test]
898
+ async fn connect_proxy_http ( ) -> Result < ( ) , BoxError > {
899
+ let listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" ) . await ?;
900
+ let addr = listener. local_addr ( ) ?;
901
+ let ( tx, mut rx) = mpsc:: channel :: < u64 > ( 1 ) ;
902
+ let shutdown = tokio_util:: sync:: CancellationToken :: new ( ) ;
903
+
904
+ let ln_shutdown = shutdown. clone ( ) ;
905
+ tokio:: spawn ( async move {
906
+ let res = connect_proxy:: run_proxy ( listener, ln_shutdown) . await ;
907
+ tx. send ( res) . await . unwrap ( ) ;
908
+ } ) ;
909
+
910
+ let sess = Session :: builder ( )
911
+ . authtoken_from_env ( )
912
+ . proxy_url ( format ! ( "http://{addr}" ) . parse ( ) . unwrap ( ) )
913
+ . unwrap ( )
914
+ . connect ( )
915
+ . await ?;
916
+
917
+ tracing:: debug!( "{}" , sess. id( ) ) ;
918
+
919
+ shutdown. cancel ( ) ;
920
+ // verify we got a request
921
+ let conns = rx. recv ( ) . await ;
922
+
923
+ assert_eq ! ( Some ( 1 ) , conns) ;
924
+ Ok ( ( ) )
925
+ }
926
+
927
+ // connect_proxy contains code for connect_proxy tests
928
+ // This code is adapted from https://github.com/hyperium/hyper/blob/c449528a33d266a8ca1210baca11e5d649ca6c27/examples/http_proxy.rs#L37
929
+ // Used under the terms of the MIT license, Copyright (c) 2014-2025 Sean McArthur
930
+ mod connect_proxy {
931
+ use bytes:: Bytes ;
932
+ use http_body_util:: {
933
+ combinators:: BoxBody ,
934
+ BodyExt ,
935
+ Empty ,
936
+ Full ,
937
+ } ;
938
+ use hyper:: {
939
+ client:: conn:: http1:: Builder ,
940
+ http,
941
+ server:: conn:: http1,
942
+ service:: service_fn,
943
+ upgrade:: Upgraded ,
944
+ Method ,
945
+ Request ,
946
+ Response ,
947
+ } ;
948
+ use hyper_util:: rt:: TokioIo ;
949
+ use tokio:: net:: TcpStream ;
950
+ use tokio_util:: sync:: CancellationToken ;
951
+
952
+ pub async fn run_proxy ( listener : tokio:: net:: TcpListener , shutdown : CancellationToken ) -> u64 {
953
+ // count requests so our caller can test that we received a request
954
+ let mut req_count = 0 ;
955
+ loop {
956
+ let ( stream, _) = match shutdown. run_until_cancelled ( listener. accept ( ) ) . await {
957
+ None => {
958
+ return req_count;
959
+ }
960
+ Some ( r) => r. unwrap ( ) ,
961
+ } ;
962
+ let io = TokioIo :: new ( stream) ;
963
+ req_count += 1 ;
964
+
965
+ tokio:: task:: spawn ( async move {
966
+ if let Err ( err) = http1:: Builder :: new ( )
967
+ . preserve_header_case ( true )
968
+ . title_case_headers ( true )
969
+ . serve_connection ( io, service_fn ( proxy) )
970
+ . with_upgrades ( )
971
+ . await
972
+ {
973
+ println ! ( "Failed to serve connection: {:?}" , err) ;
974
+ }
975
+ } ) ;
976
+ }
977
+ }
978
+
979
+ async fn proxy (
980
+ req : Request < hyper:: body:: Incoming > ,
981
+ ) -> Result < Response < BoxBody < Bytes , hyper:: Error > > , hyper:: Error > {
982
+ println ! ( "req: {:?}" , req) ;
983
+
984
+ if Method :: CONNECT == req. method ( ) {
985
+ // Received an HTTP request like:
986
+ // ```
987
+ // CONNECT www.domain.com:443 HTTP/1.1
988
+ // Host: www.domain.com:443
989
+ // Proxy-Connection: Keep-Alive
990
+ // ```
991
+ //
992
+ // When HTTP method is CONNECT we should return an empty body
993
+ // then we can eventually upgrade the connection and talk a new protocol.
994
+ //
995
+ // Note: only after client received an empty body with STATUS_OK can the
996
+ // connection be upgraded, so we can't return a response inside
997
+ // `on_upgrade` future.
998
+ if let Some ( addr) = host_addr ( req. uri ( ) ) {
999
+ tokio:: task:: spawn ( async move {
1000
+ match hyper:: upgrade:: on ( req) . await {
1001
+ Ok ( upgraded) => {
1002
+ if let Err ( e) = tunnel ( upgraded, addr) . await {
1003
+ eprintln ! ( "server io error: {}" , e) ;
1004
+ } ;
1005
+ }
1006
+ Err ( e) => eprintln ! ( "upgrade error: {}" , e) ,
1007
+ }
1008
+ } ) ;
1009
+
1010
+ Ok ( Response :: new ( empty ( ) ) )
1011
+ } else {
1012
+ eprintln ! ( "CONNECT host is not socket addr: {:?}" , req. uri( ) ) ;
1013
+ let mut resp = Response :: new ( full ( "CONNECT must be to a socket address" ) ) ;
1014
+ * resp. status_mut ( ) = http:: StatusCode :: BAD_REQUEST ;
1015
+
1016
+ Ok ( resp)
1017
+ }
1018
+ } else {
1019
+ let host = req. uri ( ) . host ( ) . expect ( "uri has no host" ) ;
1020
+ let port = req. uri ( ) . port_u16 ( ) . unwrap_or ( 80 ) ;
1021
+
1022
+ let stream = TcpStream :: connect ( ( host, port) ) . await . unwrap ( ) ;
1023
+ let io = TokioIo :: new ( stream) ;
1024
+
1025
+ let ( mut sender, conn) = Builder :: new ( )
1026
+ . preserve_header_case ( true )
1027
+ . title_case_headers ( true )
1028
+ . handshake ( io)
1029
+ . await ?;
1030
+ tokio:: task:: spawn ( async move {
1031
+ if let Err ( err) = conn. await {
1032
+ println ! ( "Connection failed: {:?}" , err) ;
1033
+ }
1034
+ } ) ;
1035
+
1036
+ let resp = sender. send_request ( req) . await ?;
1037
+ Ok ( resp. map ( |b| b. boxed ( ) ) )
1038
+ }
1039
+ }
1040
+
1041
+ fn host_addr ( uri : & http:: Uri ) -> Option < String > {
1042
+ uri. authority ( ) . map ( |auth| auth. to_string ( ) )
1043
+ }
1044
+
1045
+ fn empty ( ) -> BoxBody < Bytes , hyper:: Error > {
1046
+ Empty :: < Bytes > :: new ( )
1047
+ . map_err ( |never| match never { } )
1048
+ . boxed ( )
1049
+ }
1050
+
1051
+ fn full < T : Into < Bytes > > ( chunk : T ) -> BoxBody < Bytes , hyper:: Error > {
1052
+ Full :: new ( chunk. into ( ) )
1053
+ . map_err ( |never| match never { } )
1054
+ . boxed ( )
1055
+ }
1056
+
1057
+ // Create a TCP connection to host:port, build a tunnel between the connection and
1058
+ // the upgraded connection
1059
+ async fn tunnel ( upgraded : Upgraded , addr : String ) -> std:: io:: Result < ( ) > {
1060
+ // Connect to remote server
1061
+ let mut server = TcpStream :: connect ( addr) . await ?;
1062
+ let mut upgraded = TokioIo :: new ( upgraded) ;
1063
+
1064
+ // Proxying data
1065
+ let ( from_client, from_server) =
1066
+ tokio:: io:: copy_bidirectional ( & mut upgraded, & mut server) . await ?;
1067
+
1068
+ // Print message when done
1069
+ println ! (
1070
+ "client wrote {} bytes and received {} bytes" ,
1071
+ from_client, from_server
1072
+ ) ;
1073
+
1074
+ Ok ( ( ) )
1075
+ }
1076
+ }
1077
+
896
1078
#[ traced_test]
897
1079
#[ cfg_attr( not( feature = "paid-tests" ) , ignore) ]
898
1080
#[ test]
0 commit comments