@@ -871,6 +871,188 @@ fn tls_client_config() -> Result<Arc<ClientConfig>, &'static io::Error> {
871
871
Ok ( CONFIG . as_ref ( ) ?. clone ( ) )
872
872
}
873
873
874
+ #[ traced_test]
875
+ #[ test]
876
+ async fn connect_proxy_http ( ) -> Result < ( ) , BoxError > {
877
+ let listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:0" ) . await ?;
878
+ let addr = listener. local_addr ( ) ?;
879
+ let ( tx, mut rx) = mpsc:: channel :: < u64 > ( 1 ) ;
880
+ let shutdown = tokio_util:: sync:: CancellationToken :: new ( ) ;
881
+
882
+ let ln_shutdown = shutdown. clone ( ) ;
883
+ tokio:: spawn ( async move {
884
+ let res = connect_proxy:: run_proxy ( listener, ln_shutdown) . await ;
885
+ tx. send ( res) . await . unwrap ( ) ;
886
+ } ) ;
887
+
888
+ let sess = Session :: builder ( )
889
+ . authtoken_from_env ( )
890
+ . proxy_url ( format ! ( "http://{addr}" ) . parse ( ) . unwrap ( ) )
891
+ . unwrap ( )
892
+ . connect ( )
893
+ . await ?;
894
+
895
+ tracing:: debug!( "{}" , sess. id( ) ) ;
896
+
897
+ shutdown. cancel ( ) ;
898
+ // verify we got a request
899
+ let conns = rx. recv ( ) . await ;
900
+
901
+ assert_eq ! ( Some ( 1 ) , conns) ;
902
+ Ok ( ( ) )
903
+ }
904
+
905
+ // connect_proxy contains code for connect_proxy tests
906
+ // This code is adapted from https://github.com/hyperium/hyper/blob/c449528a33d266a8ca1210baca11e5d649ca6c27/examples/http_proxy.rs#L37
907
+ // Used under the terms of the MIT license, Copyright (c) 2014-2025 Sean McArthur
908
+ mod connect_proxy {
909
+ use bytes:: Bytes ;
910
+ use http_body_util:: {
911
+ combinators:: BoxBody ,
912
+ BodyExt ,
913
+ Empty ,
914
+ Full ,
915
+ } ;
916
+ use hyper:: {
917
+ client:: conn:: http1:: Builder ,
918
+ http,
919
+ server:: conn:: http1,
920
+ service:: service_fn,
921
+ upgrade:: Upgraded ,
922
+ Method ,
923
+ Request ,
924
+ Response ,
925
+ } ;
926
+ use hyper_util:: rt:: TokioIo ;
927
+ use tokio:: net:: TcpStream ;
928
+ use tokio_util:: sync:: CancellationToken ;
929
+
930
+ pub async fn run_proxy ( listener : tokio:: net:: TcpListener , shutdown : CancellationToken ) -> u64 {
931
+ // count requests so our caller can test that we received a request
932
+ let mut req_count = 0 ;
933
+ loop {
934
+ let ( stream, _) = match shutdown. run_until_cancelled ( listener. accept ( ) ) . await {
935
+ None => {
936
+ return req_count;
937
+ }
938
+ Some ( r) => r. unwrap ( ) ,
939
+ } ;
940
+ let io = TokioIo :: new ( stream) ;
941
+ req_count += 1 ;
942
+
943
+ tokio:: task:: spawn ( async move {
944
+ if let Err ( err) = http1:: Builder :: new ( )
945
+ . preserve_header_case ( true )
946
+ . title_case_headers ( true )
947
+ . serve_connection ( io, service_fn ( proxy) )
948
+ . with_upgrades ( )
949
+ . await
950
+ {
951
+ println ! ( "Failed to serve connection: {:?}" , err) ;
952
+ }
953
+ } ) ;
954
+ }
955
+ }
956
+
957
+ async fn proxy (
958
+ req : Request < hyper:: body:: Incoming > ,
959
+ ) -> Result < Response < BoxBody < Bytes , hyper:: Error > > , hyper:: Error > {
960
+ println ! ( "req: {:?}" , req) ;
961
+
962
+ if Method :: CONNECT == req. method ( ) {
963
+ // Received an HTTP request like:
964
+ // ```
965
+ // CONNECT www.domain.com:443 HTTP/1.1
966
+ // Host: www.domain.com:443
967
+ // Proxy-Connection: Keep-Alive
968
+ // ```
969
+ //
970
+ // When HTTP method is CONNECT we should return an empty body
971
+ // then we can eventually upgrade the connection and talk a new protocol.
972
+ //
973
+ // Note: only after client received an empty body with STATUS_OK can the
974
+ // connection be upgraded, so we can't return a response inside
975
+ // `on_upgrade` future.
976
+ if let Some ( addr) = host_addr ( req. uri ( ) ) {
977
+ tokio:: task:: spawn ( async move {
978
+ match hyper:: upgrade:: on ( req) . await {
979
+ Ok ( upgraded) => {
980
+ if let Err ( e) = tunnel ( upgraded, addr) . await {
981
+ eprintln ! ( "server io error: {}" , e) ;
982
+ } ;
983
+ }
984
+ Err ( e) => eprintln ! ( "upgrade error: {}" , e) ,
985
+ }
986
+ } ) ;
987
+
988
+ Ok ( Response :: new ( empty ( ) ) )
989
+ } else {
990
+ eprintln ! ( "CONNECT host is not socket addr: {:?}" , req. uri( ) ) ;
991
+ let mut resp = Response :: new ( full ( "CONNECT must be to a socket address" ) ) ;
992
+ * resp. status_mut ( ) = http:: StatusCode :: BAD_REQUEST ;
993
+
994
+ Ok ( resp)
995
+ }
996
+ } else {
997
+ let host = req. uri ( ) . host ( ) . expect ( "uri has no host" ) ;
998
+ let port = req. uri ( ) . port_u16 ( ) . unwrap_or ( 80 ) ;
999
+
1000
+ let stream = TcpStream :: connect ( ( host, port) ) . await . unwrap ( ) ;
1001
+ let io = TokioIo :: new ( stream) ;
1002
+
1003
+ let ( mut sender, conn) = Builder :: new ( )
1004
+ . preserve_header_case ( true )
1005
+ . title_case_headers ( true )
1006
+ . handshake ( io)
1007
+ . await ?;
1008
+ tokio:: task:: spawn ( async move {
1009
+ if let Err ( err) = conn. await {
1010
+ println ! ( "Connection failed: {:?}" , err) ;
1011
+ }
1012
+ } ) ;
1013
+
1014
+ let resp = sender. send_request ( req) . await ?;
1015
+ Ok ( resp. map ( |b| b. boxed ( ) ) )
1016
+ }
1017
+ }
1018
+
1019
+ fn host_addr ( uri : & http:: Uri ) -> Option < String > {
1020
+ uri. authority ( ) . map ( |auth| auth. to_string ( ) )
1021
+ }
1022
+
1023
+ fn empty ( ) -> BoxBody < Bytes , hyper:: Error > {
1024
+ Empty :: < Bytes > :: new ( )
1025
+ . map_err ( |never| match never { } )
1026
+ . boxed ( )
1027
+ }
1028
+
1029
+ fn full < T : Into < Bytes > > ( chunk : T ) -> BoxBody < Bytes , hyper:: Error > {
1030
+ Full :: new ( chunk. into ( ) )
1031
+ . map_err ( |never| match never { } )
1032
+ . boxed ( )
1033
+ }
1034
+
1035
+ // Create a TCP connection to host:port, build a tunnel between the connection and
1036
+ // the upgraded connection
1037
+ async fn tunnel ( upgraded : Upgraded , addr : String ) -> std:: io:: Result < ( ) > {
1038
+ // Connect to remote server
1039
+ let mut server = TcpStream :: connect ( addr) . await ?;
1040
+ let mut upgraded = TokioIo :: new ( upgraded) ;
1041
+
1042
+ // Proxying data
1043
+ let ( from_client, from_server) =
1044
+ tokio:: io:: copy_bidirectional ( & mut upgraded, & mut server) . await ?;
1045
+
1046
+ // Print message when done
1047
+ println ! (
1048
+ "client wrote {} bytes and received {} bytes" ,
1049
+ from_client, from_server
1050
+ ) ;
1051
+
1052
+ Ok ( ( ) )
1053
+ }
1054
+ }
1055
+
874
1056
#[ traced_test]
875
1057
#[ cfg_attr( not( feature = "paid-tests" ) , ignore) ]
876
1058
#[ test]
0 commit comments