Skip to content

Commit 235aef8

Browse files
committed
Add an online test for an http connect proxy_url
This verifies an http connect proxy works, and since we're making changes around there, it's worth having a basic test to validate it functions.
1 parent 70fca68 commit 235aef8

File tree

2 files changed

+184
-1
lines changed

2 files changed

+184
-1
lines changed

ngrok/Cargo.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ futures-util = "0.3.30"
5252
windows-sys = { version = "0.45.0", features = ["Win32_Foundation"] }
5353

5454
[dev-dependencies]
55-
hyper = "1.1.0"
55+
hyper = { version = "1.1.0", features = [ "client" ] }
5656
hyper-util = { version = "0.1.3", features = [
5757
"tokio",
5858
"server",
@@ -73,6 +73,7 @@ tokio-tungstenite = { version = "0.21.0", features = [
7373
] }
7474
tower = { version = "0.5", features = ["util"] }
7575
axum = { version = "0.7.4", features = ["tokio"] }
76+
http-body-util = "0.1.3"
7677

7778
[[example]]
7879
name = "tls"

ngrok/src/online_tests.rs

Lines changed: 182 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -871,6 +871,188 @@ fn tls_client_config() -> Result<Arc<ClientConfig>, &'static io::Error> {
871871
Ok(CONFIG.as_ref()?.clone())
872872
}
873873

874+
#[traced_test]
875+
#[test]
876+
async fn connect_proxy_http() -> Result<(), BoxError> {
877+
let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await?;
878+
let addr = listener.local_addr()?;
879+
let (tx, mut rx) = mpsc::channel::<u64>(1);
880+
let shutdown = tokio_util::sync::CancellationToken::new();
881+
882+
let ln_shutdown = shutdown.clone();
883+
tokio::spawn(async move {
884+
let res = connect_proxy::run_proxy(listener, ln_shutdown).await;
885+
tx.send(res).await.unwrap();
886+
});
887+
888+
let sess = Session::builder()
889+
.authtoken_from_env()
890+
.proxy_url(format!("http://{addr}").parse().unwrap())
891+
.unwrap()
892+
.connect()
893+
.await?;
894+
895+
tracing::debug!("{}", sess.id());
896+
897+
shutdown.cancel();
898+
// verify we got a request
899+
let conns = rx.recv().await;
900+
901+
assert_eq!(Some(1), conns);
902+
Ok(())
903+
}
904+
905+
// connect_proxy contains code for connect_proxy tests
906+
// This code is adapted from https://github.com/hyperium/hyper/blob/c449528a33d266a8ca1210baca11e5d649ca6c27/examples/http_proxy.rs#L37
907+
// Used under the terms of the MIT license, Copyright (c) 2014-2025 Sean McArthur
908+
mod connect_proxy {
909+
use bytes::Bytes;
910+
use http_body_util::{
911+
combinators::BoxBody,
912+
BodyExt,
913+
Empty,
914+
Full,
915+
};
916+
use hyper::{
917+
client::conn::http1::Builder,
918+
http,
919+
server::conn::http1,
920+
service::service_fn,
921+
upgrade::Upgraded,
922+
Method,
923+
Request,
924+
Response,
925+
};
926+
use hyper_util::rt::TokioIo;
927+
use tokio::net::TcpStream;
928+
use tokio_util::sync::CancellationToken;
929+
930+
pub async fn run_proxy(listener: tokio::net::TcpListener, shutdown: CancellationToken) -> u64 {
931+
// count requests so our caller can test that we received a request
932+
let mut req_count = 0;
933+
loop {
934+
let (stream, _) = match shutdown.run_until_cancelled(listener.accept()).await {
935+
None => {
936+
return req_count;
937+
}
938+
Some(r) => r.unwrap(),
939+
};
940+
let io = TokioIo::new(stream);
941+
req_count += 1;
942+
943+
tokio::task::spawn(async move {
944+
if let Err(err) = http1::Builder::new()
945+
.preserve_header_case(true)
946+
.title_case_headers(true)
947+
.serve_connection(io, service_fn(proxy))
948+
.with_upgrades()
949+
.await
950+
{
951+
println!("Failed to serve connection: {:?}", err);
952+
}
953+
});
954+
}
955+
}
956+
957+
async fn proxy(
958+
req: Request<hyper::body::Incoming>,
959+
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
960+
println!("req: {:?}", req);
961+
962+
if Method::CONNECT == req.method() {
963+
// Received an HTTP request like:
964+
// ```
965+
// CONNECT www.domain.com:443 HTTP/1.1
966+
// Host: www.domain.com:443
967+
// Proxy-Connection: Keep-Alive
968+
// ```
969+
//
970+
// When HTTP method is CONNECT we should return an empty body
971+
// then we can eventually upgrade the connection and talk a new protocol.
972+
//
973+
// Note: only after client received an empty body with STATUS_OK can the
974+
// connection be upgraded, so we can't return a response inside
975+
// `on_upgrade` future.
976+
if let Some(addr) = host_addr(req.uri()) {
977+
tokio::task::spawn(async move {
978+
match hyper::upgrade::on(req).await {
979+
Ok(upgraded) => {
980+
if let Err(e) = tunnel(upgraded, addr).await {
981+
eprintln!("server io error: {}", e);
982+
};
983+
}
984+
Err(e) => eprintln!("upgrade error: {}", e),
985+
}
986+
});
987+
988+
Ok(Response::new(empty()))
989+
} else {
990+
eprintln!("CONNECT host is not socket addr: {:?}", req.uri());
991+
let mut resp = Response::new(full("CONNECT must be to a socket address"));
992+
*resp.status_mut() = http::StatusCode::BAD_REQUEST;
993+
994+
Ok(resp)
995+
}
996+
} else {
997+
let host = req.uri().host().expect("uri has no host");
998+
let port = req.uri().port_u16().unwrap_or(80);
999+
1000+
let stream = TcpStream::connect((host, port)).await.unwrap();
1001+
let io = TokioIo::new(stream);
1002+
1003+
let (mut sender, conn) = Builder::new()
1004+
.preserve_header_case(true)
1005+
.title_case_headers(true)
1006+
.handshake(io)
1007+
.await?;
1008+
tokio::task::spawn(async move {
1009+
if let Err(err) = conn.await {
1010+
println!("Connection failed: {:?}", err);
1011+
}
1012+
});
1013+
1014+
let resp = sender.send_request(req).await?;
1015+
Ok(resp.map(|b| b.boxed()))
1016+
}
1017+
}
1018+
1019+
fn host_addr(uri: &http::Uri) -> Option<String> {
1020+
uri.authority().map(|auth| auth.to_string())
1021+
}
1022+
1023+
fn empty() -> BoxBody<Bytes, hyper::Error> {
1024+
Empty::<Bytes>::new()
1025+
.map_err(|never| match never {})
1026+
.boxed()
1027+
}
1028+
1029+
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
1030+
Full::new(chunk.into())
1031+
.map_err(|never| match never {})
1032+
.boxed()
1033+
}
1034+
1035+
// Create a TCP connection to host:port, build a tunnel between the connection and
1036+
// the upgraded connection
1037+
async fn tunnel(upgraded: Upgraded, addr: String) -> std::io::Result<()> {
1038+
// Connect to remote server
1039+
let mut server = TcpStream::connect(addr).await?;
1040+
let mut upgraded = TokioIo::new(upgraded);
1041+
1042+
// Proxying data
1043+
let (from_client, from_server) =
1044+
tokio::io::copy_bidirectional(&mut upgraded, &mut server).await?;
1045+
1046+
// Print message when done
1047+
println!(
1048+
"client wrote {} bytes and received {} bytes",
1049+
from_client, from_server
1050+
);
1051+
1052+
Ok(())
1053+
}
1054+
}
1055+
8741056
#[traced_test]
8751057
#[cfg_attr(not(feature = "paid-tests"), ignore)]
8761058
#[test]

0 commit comments

Comments
 (0)