From 3c0ebc49f2fd491f7430371d580b81e64d4b3f3e Mon Sep 17 00:00:00 2001 From: koe Date: Fri, 20 Dec 2024 02:51:44 -0600 Subject: [PATCH] refactor webtransport server to use wtransport crate --- CHANGELOG.md | 2 +- examples/echo_server_cross/src/main.rs | 12 +- renet2/Cargo.toml | 31 +- .../websocket_socket/client/socket.rs | 4 +- .../websocket_socket/server/socket.rs | 14 +- .../webtransport_socket/client/socket.rs | 4 +- .../webtransport_socket/server/cert_utils.rs | 2 +- .../webtransport_socket/server/socket.rs | 286 ++++++++---------- .../transport/webtransport_socket/utils.rs | 2 +- 9 files changed, 153 insertions(+), 204 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 421ad1aa..1391fdcb 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -6,13 +6,13 @@ - `rustls`: 0.21 -> 0.23.5 - `quinn`: 0.10 -> 0.11.6 - `rcgen`: 0.12 -> 0.13 - - `h3-quinn`/`h3-webtransport`/`h3`: h3-v0.0.4 -> h3-v0.0.6 - Split `TransportSocket` into separate `ServerSocket`/`ClientSocket` traits. - Add `webtransport_is_available()`/`webtransport_is_available_with_cert_hashes()` helpers for WASM clients. - Add support for reliable transport sockets. - Add `TransportSocket::is_reliable`. It's true for in-memory sockets and WebSockets, and false for UDP and WebTransport. - Add `has_reliable_socket` argument to `RenetClient::new` - Add WebSocket server and client. The client is WASM-only. +- Replace `h3` dependency with `wtransport` for WebTransport backend. ## 0.0.7 - 12/02/24 diff --git a/examples/echo_server_cross/src/main.rs b/examples/echo_server_cross/src/main.rs index e61649f4..35976ccf 100644 --- a/examples/echo_server_cross/src/main.rs +++ b/examples/echo_server_cross/src/main.rs @@ -4,7 +4,7 @@ use std::{ }; use warp::Filter; -use log::{debug, info}; +use log::{debug, info, trace}; use renet2::{ transport::{ BoxedSocket, NativeSocket, NetcodeServerTransport, ServerCertHash, ServerSetupConfig, ServerSocket, WebServerDestination, @@ -22,11 +22,7 @@ struct ClientConnectionInfo { } fn main() { - env_logger::builder() - .filter_level(log::LevelFilter::Info) - .filter_module("h3::server", log::LevelFilter::Warn) - .filter_module("h3::server::connection", log::LevelFilter::Warn) - .init(); + env_logger::builder().filter_level(log::LevelFilter::Info).init(); let runtime = tokio::runtime::Runtime::new().unwrap(); @@ -113,13 +109,13 @@ fn run_renet_server(mut transport: NetcodeServerTransport) { let mut last_updated = Instant::now(); loop { - debug!("server tick"); + trace!("server tick"); let now = Instant::now(); let duration = now - last_updated; last_updated = now; transport.update(duration, &mut server).unwrap(); - debug!("server update"); + trace!("server update"); while let Some(event) = server.get_event() { match event { diff --git a/renet2/Cargo.toml b/renet2/Cargo.toml index a460cd0e..6aae40c7 100644 --- a/renet2/Cargo.toml +++ b/renet2/Cargo.toml @@ -41,22 +41,21 @@ wt_server_transport = [ "transport", "dep:crossbeam", "dep:anyhow", + "dep:wtransport", + "dep:rustls", "dep:rustls-pki-types", "dep:rcgen", "dep:quinn", - "dep:h3-quinn", - "dep:h3-webtransport", - "dep:tokio", "dep:http", + + "dep:tokio", "dep:futures", - "dep:h3", "dep:time", #"dep:x509-cert", #"dep:spki", #"dep:base64", - "dep:serde_json", - "dep:form_urlencoded", + "dep:urlencoding", ] # Enable the WebTransport client transport (WASM only) @@ -71,14 +70,16 @@ wt_client_transport = [ "dep:send_wrapper", "dep:getrandom", "dep:web-sys", - "dep:serde_json" + "dep:urlencoding", ] # Enable the WebSocket server transport ws_server_transport = [ "transport", "dep:tungstenite", - "dep:tokio-tungstenite" + "dep:tokio-tungstenite", + "dep:http", + "dep:urlencoding", ] # Enable rustls acceptors for WebSocket server transports. @@ -99,7 +100,8 @@ ws_client_transport = [ "dep:web-sys", "dep:serde_json", "dep:futures-util", - "dep:futures-channel" + "dep:futures-channel", + "dep:urlencoding", ] [dependencies] @@ -119,28 +121,23 @@ crossbeam = { version = "0.8", optional = true } # WebTransport shared futures = { version = "0.3", optional = true } serde_json = { version = "1.0", optional = true } +urlencoding = { version = "2.1", optional = true } # WebTransport server anyhow = { version = "1.0", optional = true } +wtransport = { version = "0.5", optional = true, default-features = false, features = ["quinn", "self-signed"] } rustls = { version = "0.23.5", optional = true } #locked to 0.23.5 until quinn updates rustls-pki-types = { version = "1.7", optional = true } #locked to 1.7 until quinn updates rcgen = { version = "0.13", optional = true } quinn = { version = "0.11.6", optional = true, default-features = false, features = [ - "runtime-tokio", "rustls-ring", ] } -h3-quinn = { tag = "h3-v0.0.6", optional = true, git = "https://github.com/hyperium/h3" } -h3-webtransport = { version = "0.1", tag = "h3-v0.0.6", optional = true, git = "https://github.com/hyperium/h3" } -tokio = { version = "1.32", optional = true, features = ["full"] } http = { version = "1.0", optional = true } -h3 = { tag = "h3-v0.0.6", optional = true, git = "https://github.com/hyperium/h3", features = [ - "i-implement-a-third-party-backend-and-opt-into-breaking-changes" -] } +tokio = { version = "1.32", optional = true, features = ["full"] } time = { version = "0.3", optional = true } #x509-cert = { version = "0.2", optional = true } #spki = { version = "0.7", optional = true, features = ["fingerprint"] } #base64 = { version = "0.22", optional = true } -form_urlencoded = { version = "1.2", optional = true } # WebTransport client async-channel = { version = "2.2", optional = true } diff --git a/renet2/src/transport/websocket_socket/client/socket.rs b/renet2/src/transport/websocket_socket/client/socket.rs index bd60c348..62f1cd47 100644 --- a/renet2/src/transport/websocket_socket/client/socket.rs +++ b/renet2/src/transport/websocket_socket/client/socket.rs @@ -93,8 +93,8 @@ impl WebSocketClient { }; // Build URL with connection request. - let connect_msg_ser = serde_json::to_string(&connection_req).expect("could not serialize connect msg"); - server_url.query_pairs_mut().append_pair(HTTP_CONNECT_REQ, connect_msg_ser.as_str()); + let connect_msg_ser = urlencoding::encode_binary(&connection_req); + server_url.set_query(Some(format!("{}={}", HTTP_CONNECT_REQ, &connect_msg_ser).as_str())); let Ok(ws) = WebSocket::new(server_url.as_str()) else { warn!( diff --git a/renet2/src/transport/websocket_socket/server/socket.rs b/renet2/src/transport/websocket_socket/server/socket.rs index 8d5ded4d..0c5fc840 100644 --- a/renet2/src/transport/websocket_socket/server/socket.rs +++ b/renet2/src/transport/websocket_socket/server/socket.rs @@ -681,21 +681,13 @@ fn extract_client_connection_req(uri: &Uri) -> Result, Error> { log::trace!("invalid uri query, dropping connection request..."); return Err(Error::msg("invalid uri query, dropping connection request...")); }; - let mut query_elements_iterator = form_urlencoded::parse(query.as_bytes()); - let Some((key, connection_req)) = query_elements_iterator.next() else { + let Some(encoded) = query.split_once(HTTP_CONNECT_REQ).and_then(|(_, r)| r.strip_prefix("=")) else { log::trace!("invalid uri query (missing req), dropping connection request..."); return Err(Error::msg("invalid uri query (missing req), dropping connection request...")); }; - if key != HTTP_CONNECT_REQ { - log::trace!("invalid uri query (bad key), dropping connection request..."); - return Err(Error::msg("invalid uri query (bad key), dropping connection request...")); - } - let Ok(connection_req) = serde_json::de::from_str::>(&connection_req) else { - log::trace!("invalid uri query (bad req), dropping connection request..."); - return Err(Error::msg("invalid uri query (bad req), dropping connection request...")); - }; + let connection_req = urlencoding::decode_binary(encoded.as_bytes()); - Ok(connection_req) + Ok(connection_req.into()) } /// Makes a websocket url: `{ws, wss}://[ip:port]/ws`. diff --git a/renet2/src/transport/webtransport_socket/client/socket.rs b/renet2/src/transport/webtransport_socket/client/socket.rs index 50f0342a..f6ab8ace 100644 --- a/renet2/src/transport/webtransport_socket/client/socket.rs +++ b/renet2/src/transport/webtransport_socket/client/socket.rs @@ -171,8 +171,8 @@ impl WebTransportClient { .clone() .try_into() .expect("could not convert server destination to url"); - let connect_msg_ser = serde_json::to_string(&connection_req).expect("could not serialize connect msg"); - url.query_pairs_mut().append_pair(HTTP_CONNECT_REQ, connect_msg_ser.as_str()); + let connect_msg_ser = urlencoding::encode_binary(&connection_req); + url.set_query(Some(format!("{}={}", HTTP_CONNECT_REQ, &connect_msg_ser).as_str())); // Set up WebTransport. let web_transport = match Self::init_web_transport(url.as_str(), options).await { diff --git a/renet2/src/transport/webtransport_socket/server/cert_utils.rs b/renet2/src/transport/webtransport_socket/server/cert_utils.rs index f2f21665..3b1352bd 100644 --- a/renet2/src/transport/webtransport_socket/server/cert_utils.rs +++ b/renet2/src/transport/webtransport_socket/server/cert_utils.rs @@ -32,7 +32,7 @@ pub fn generate_self_signed_certificate_opinionated, ) -> Result<(CertificateDer<'static>, PrivateKeyDer<'static>), rcgen::Error> { let not_before = OffsetDateTime::now_utc().saturating_sub(1.hours()); //adjust for client system time variance - let not_after = not_before.saturating_add(2.weeks().saturating_sub(1.minutes())); //less than 2 weeks + let not_after = not_before.saturating_add(2.weeks()); let mut distinguished_name = DistinguishedName::new(); distinguished_name.push(DnType::CommonName, "renet2 self signed cert"); diff --git a/renet2/src/transport/webtransport_socket/server/socket.rs b/renet2/src/transport/webtransport_socket/server/socket.rs index 241f9a48..22ee2f04 100644 --- a/renet2/src/transport/webtransport_socket/server/socket.rs +++ b/renet2/src/transport/webtransport_socket/server/socket.rs @@ -1,14 +1,11 @@ use anyhow::Error; use bytes::Bytes; -use h3::{error::ErrorLevel, ext::Protocol, server::Connection}; -use h3_quinn::Connection as H3QuinnConnection; -use h3_webtransport::server::WebTransportSession; -use http::{uri::Uri, Method}; use log::{debug, error, trace}; use quinn::crypto::rustls::QuicServerConfig; -use quinn::{EndpointConfig, TokioRuntime}; +use quinn::IdleTimeout; use rustls_pki_types::{CertificateDer, PrivateKeyDer}; use tokio::{sync::mpsc, task::AbortHandle}; +use wtransport::error::SendDatagramError; use std::collections::HashMap; use std::ops::Bound::{Excluded, Included}; @@ -95,6 +92,47 @@ impl WebTransportServerConfig { (config, hash) } + + /// Converts self into a [`wtransport::ServerConfig`]. + /// + /// Used automatically by [`WebTransportServer::new`]. + pub fn create_server_config(self) -> Result { + // TODO: Allow injecting cert resolver via `with_cert_resolver()`, which would allow more than one certificate. + // That would be useful for long-lived servers whose clients are using ServerCertHash, since then you could + // specify many certificates (for the expected lifetime of the server) or even inject fresh ones via atomics + // and channels. + if rustls::crypto::CryptoProvider::get_default().is_none() { + let _ = rustls::crypto::ring::default_provider().install_default(); + } + let mut tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) + .with_no_client_auth() + .with_single_cert(vec![self.cert], self.key)?; + + tls_config.max_early_data_size = u32::MAX; + // We set the ALPN protocols to h3 as first, so that the browser will use the newest HTTP/3 draft and as fallback + // we use older versions of the HTTP/3 draft. + let alpn: Vec> = vec![ + b"h3".to_vec(), + b"h3-32".to_vec(), + b"h3-31".to_vec(), + b"h3-30".to_vec(), + b"h3-29".to_vec(), + ]; + tls_config.alpn_protocols = alpn; + + let mut server_config: quinn::ServerConfig = quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?)); + let mut transport_config = quinn::TransportConfig::default(); + transport_config + .keep_alive_interval(Some(Duration::from_secs(2))) + .max_idle_timeout(Some(IdleTimeout::try_from(Duration::from_secs(15))?)); + server_config.transport = Arc::new(transport_config); + + let wt_config = wtransport::ServerConfig::builder() + .with_bind_address(self.listen) + .build_with_quic_config(server_config); + + Ok(wt_config) + } } impl Clone for WebTransportServerConfig { @@ -112,8 +150,8 @@ impl Clone for WebTransportServerConfig { struct WebTransportServerClient { /// Connection session. /// - /// When this is dropped, the internal `H3QuinnConnection` will send a close message to the client. - session: Arc>, + // TODO: When this is dropped, is a close message send to the client? + session: wtransport::Connection, reader_receiver: crossbeam::channel::Receiver, abort_sender: mpsc::UnboundedSender<()>, /// When this struct is dropped, the reader thread will shut down automatically since the `abort_sender` channel @@ -167,7 +205,7 @@ enum ClientConnectionResult { Success { client_idx: u64, client_id: u64, - session: WebTransportSession, + session: wtransport::Connection, }, Failure { client_idx: u64, @@ -182,7 +220,6 @@ pub struct WebTransportServer { addr: SocketAddr, - endpoint: quinn::Endpoint, connection_req_receiver: mpsc::Receiver, connection_receiver: mpsc::Receiver, connection_abort_handle: AbortHandle, @@ -207,13 +244,9 @@ impl WebTransportServer { /// - Errors if unable to bind to the [`WebTransportServerConfig::listen`] address, which can happen if your /// machine is using all ports on a pre-defined IP address. pub fn new(config: WebTransportServerConfig, handle: tokio::runtime::Handle) -> Result { - let target_addr = config.listen; let max_clients = config.max_clients; - let server_config = Self::create_server_config(config)?; - let socket = std::net::UdpSocket::bind(target_addr)?; - let endpoint = handle.block_on(async move { - quinn::Endpoint::new(EndpointConfig::default(), Some(server_config), socket, Arc::new(TokioRuntime)) - })?; + let server_config = config.create_server_config()?; + let endpoint = handle.block_on(async move { wtransport::Endpoint::server(server_config) })?; let addr = endpoint.local_addr()?; let (sender, receiver) = mpsc::channel::(max_clients); let client_iterator = Arc::new(AtomicU64::new(0)); @@ -222,7 +255,7 @@ impl WebTransportServer { let abort_handle = handle .spawn(Self::accept_connection( sender, - endpoint.clone(), + endpoint, client_iterator.clone(), Arc::clone(¤t_clients), connection_req_sender, @@ -233,7 +266,6 @@ impl WebTransportServer { Ok(Self { handle, addr, - endpoint, connection_req_receiver, connection_receiver: receiver, connection_abort_handle: abort_handle, @@ -249,51 +281,41 @@ impl WebTransportServer { } /// Disconnects the server. + // TODO: verify that aborting the endpoint's thread is enough to shut it down properly pub fn close(&mut self) { - self.endpoint.close(0u32.into(), b"Server shutdown"); self.connection_abort_handle.abort(); self.closed = true; } async fn accept_connection( sender: mpsc::Sender, - endpoint: quinn::Endpoint, + endpoint: wtransport::Endpoint, client_iterator: Arc, current_clients: Arc, connection_req_sender: mpsc::Sender, max_clients: usize, ) { - while let Some(new_conn) = endpoint.accept().await { + loop { + let incoming_connection = endpoint.accept().await; + + // Check for capacity. + let is_full = { + let current_clients = current_clients.load(Ordering::Relaxed); + // We allow 25% extra clients in case clients want to override their old sessions. + (current_clients * 4) >= (max_clients * 5) + }; + if is_full { + incoming_connection.refuse(); + continue; + } + let sender = sender.clone(); - let current_clients = current_clients.clone(); let client_iterator = client_iterator.clone(); let connection_req_sender = connection_req_sender.clone(); tokio::spawn(async move { - match new_conn.await { - Ok(conn) => { - let is_full = { - let current_clients = current_clients.load(Ordering::Relaxed); - // We allow 25% extra clients in case clients want to override their old sessions. - (current_clients * 4) >= (max_clients * 5) - }; - if is_full { - conn.close(0u32.into(), b"Server full"); - return; - } - //todo: need max_field_section_size? - let Ok(h3_conn) = h3::server::builder() - .enable_webtransport(true) - .enable_connect(true) - .enable_datagram(true) - .max_webtransport_sessions(1) - .send_grease(true) - .build(H3QuinnConnection::new(conn)) - .await - else { - return; - }; - - match Self::handle_connection(client_iterator, connection_req_sender, h3_conn).await { + match incoming_connection.await { + Ok(session_request) => { + match Self::handle_session_request(client_iterator, connection_req_sender, session_request).await { Ok(maybe_session) => { if let Some(session) = maybe_session { if let Err(e) = sender.try_send(session) { @@ -314,106 +336,58 @@ impl WebTransportServer { } } - fn create_server_config(config: WebTransportServerConfig) -> Result { - // TODO: Allow injecting cert resolver via `with_cert_resolver()`, which would allow more than one certificate. - // That would be useful for long-lived servers whose clients are using ServerCertHash, since then you could - // specify many certificates (for the expected lifetime of the server) or even inject fresh ones via atomics - // and channels. - if rustls::crypto::CryptoProvider::get_default().is_none() { - let _ = rustls::crypto::ring::default_provider().install_default(); - } - let mut tls_config = rustls::ServerConfig::builder_with_protocol_versions(&[&rustls::version::TLS13]) - .with_no_client_auth() - .with_single_cert(vec![config.cert], config.key)?; - - tls_config.max_early_data_size = u32::MAX; - // We set the ALPN protocols to h3 as first, so that the browser will use the newest HTTP/3 draft and as fallback - // we use older versions of HTTP/3 draft - let alpn: Vec> = vec![ - b"h3".to_vec(), - b"h3-32".to_vec(), - b"h3-31".to_vec(), - b"h3-30".to_vec(), - b"h3-29".to_vec(), - ]; - tls_config.alpn_protocols = alpn; - - let mut server_config: quinn::ServerConfig = quinn::ServerConfig::with_crypto(Arc::new(QuicServerConfig::try_from(tls_config)?)); - let mut transport_config = quinn::TransportConfig::default(); - transport_config.keep_alive_interval(Some(Duration::from_secs(2))); - server_config.transport = Arc::new(transport_config); - Ok(server_config) - } - - async fn handle_connection( + async fn handle_session_request( client_iterator: Arc, connection_req_sender: mpsc::Sender, - mut conn: Connection, - ) -> Result, h3::Error> { - match conn.accept().await { - Ok(Some((req, stream))) => { - let ext = req.extensions(); - if ext.get::() != Some(&Protocol::WEB_TRANSPORT) { - return Ok(None); - } - if *req.method() != Method::CONNECT { - return Ok(None); - } - - // Extract the client's first connection request from the request URL. - // - // SECURITY NOTE: Connection requests are sent *unencrypted*, which matches how they are - // sent when using UDP sockets. - // TODO: Consider authenticating UDP client addresses in connect tokens, and sending WebTransport - // connection requests after sessions are established. - let packet = extract_client_connection_req(req.uri())?; + session_request: wtransport::endpoint::SessionRequest, + ) -> Result, wtransport::error::ConnectionError> { + // Extract the client's first connection request from the request URL. + // + // SECURITY NOTE: Connection requests are sent *unencrypted*, which matches how they are + // sent when using UDP sockets. + // TODO: Consider authenticating UDP client addresses in connect tokens, and sending WebTransport + // connection requests after sessions are established. + let packet = extract_client_connection_req(session_request.path())?; + + // Assign an identifier to this client. + let client_idx = client_iterator.fetch_add(1, Ordering::Relaxed); + + // Send connection request packet to netcode for evaluation. + let (result_sender, mut result_receiver) = mpsc::channel::(1usize); + let Ok(_) = connection_req_sender.try_send(ConnectionRequest { + client_idx, + packet, + result_sender, + }) else { + return Ok(None); + }; - // Assign an identifier to this client. - let client_idx = client_iterator.fetch_add(1, Ordering::Relaxed); + // Wait for the result of evaluating the connection request. + // - The connection must be validated before we accept the session to avoid resources being + // consumed by fake clients. + let Some(ConnectionRequestResult::Success { client_id }) = result_receiver.recv().await else { + return Ok(None); + }; - // Send connection request packet to netcode for evaluation. - let (result_sender, mut result_receiver) = mpsc::channel::(1usize); - let Ok(_) = connection_req_sender.try_send(ConnectionRequest { - client_idx, - packet, - result_sender, - }) else { - return Ok(None); - }; - - // Wait for the result of evaluating the connection request. - // - The connection must be validated before we accept the session to avoid resources being - // consumed by fake clients. - let Some(ConnectionRequestResult::Success { client_id }) = result_receiver.recv().await else { - return Ok(None); - }; - - // Finalize the connection. - match WebTransportSession::accept(req, stream, conn).await { - Ok(session) => Ok(Some(ClientConnectionResult::Success { - client_idx, - client_id, - session, - })), - Err(err) => { - // We must return failure here because `ConnectionRequestResult::Success` means the server - // is tracking this connection. We need the server to clean up its pending client entry. - debug!("Failed to handle connection: {err:?}"); - Ok(Some(ClientConnectionResult::Failure { client_idx })) - } - } + // Finalize the connection. + match session_request.accept().await { + Ok(session) => Ok(Some(ClientConnectionResult::Success { + client_idx, + client_id, + session, + })), + Err(err) => { + // We must return failure here because `ConnectionRequestResult::Success` means the server + // is tracking this connection. We need the server to clean up its pending client entry. + debug!("Failed to handle connection: {err:?}"); + Ok(Some(ClientConnectionResult::Failure { client_idx })) } - - // Indicates no more streams to be received. - Ok(None) => Ok(None), - - Err(err) => Err(err), } } fn reading_thread( handle: &tokio::runtime::Handle, - read_datagram: Arc>, + read_datagram: wtransport::Connection, sender: crossbeam::channel::Sender, mut abort_signal: mpsc::UnboundedReceiver<()>, ) -> tokio::task::JoinHandle<()> { @@ -439,8 +413,8 @@ impl WebTransportServer { _ = abort_signal.recv() => { break; }, - Ok(result) = read_datagram.accept_datagram() => match result { - Some((_, datagram_bytes)) => match sender.try_send(datagram_bytes) { + Ok(datagram) = read_datagram.receive_datagram() => { + match sender.try_send(datagram.payload()) { Ok(_) => {} Err(err) => { if let crossbeam::channel::TrySendError::Disconnected(_) = err { @@ -449,8 +423,7 @@ impl WebTransportServer { trace!("The reading data could not be sent because the channel is currently full and sending \ would require blocking."); } - }, - None => break, + } }, _ = &mut sleep => { trace!("WT client socket reader timed out, disconnecting."); @@ -571,14 +544,13 @@ impl ServerSocket for WebTransportServer { } // Set up datagram reading for the session. - let shared_session = Arc::new(session); let (sender, receiver) = crossbeam::channel::bounded::(256); let (abort_sender, abort_receiver) = mpsc::unbounded_channel::<()>(); - let thread = Self::reading_thread(&self.handle, shared_session.clone(), sender, abort_receiver); + let thread = Self::reading_thread(&self.handle, session.clone(), sender, abort_receiver); self.clients.insert( client_idx, WebTransportServerClient { - session: shared_session, + session, reader_receiver: receiver, abort_sender, reader_thread: thread, @@ -710,12 +682,12 @@ impl ServerSocket for WebTransportServer { let data = Bytes::copy_from_slice(packet); if let Err(err) = client_data.session.send_datagram(data) { // See https://www.rfc-editor.org/rfc/rfc9114.html#errors - match err.get_error_level() { - ErrorLevel::ConnectionError => { + match err { + SendDatagramError::NotConnected => { self.disconnect(addr); return Err(std::io::Error::from(ErrorKind::ConnectionAborted).into()); } - ErrorLevel::StreamError => debug!("Stream error: {err}"), + SendDatagramError::UnsupportedByPeer | SendDatagramError::TooLarge => debug!("Stream error: {err}"), } } @@ -723,24 +695,16 @@ impl ServerSocket for WebTransportServer { } } -fn extract_client_connection_req(uri: &Uri) -> Result, h3::Error> { - let Some(query) = uri.query() else { +fn extract_client_connection_req(path: &str) -> Result, wtransport::error::ConnectionError> { + let Some((_, query)) = path.split_once('?') else { log::trace!("invalid uri query, dropping connection request..."); - return Err(h3::Error::from(h3::error::Code::H3_REQUEST_INCOMPLETE)); + return Err(wtransport::error::ConnectionError::LocallyClosed); }; - let mut query_elements_iterator = form_urlencoded::parse(query.as_bytes()); - let Some((key, connection_req)) = query_elements_iterator.next() else { + let Some(encoded) = query.split_once(HTTP_CONNECT_REQ).and_then(|(_, r)| r.strip_prefix("=")) else { log::trace!("invalid uri query (missing req), dropping connection request..."); - return Err(h3::Error::from(h3::error::Code::H3_REQUEST_INCOMPLETE)); - }; - if key != HTTP_CONNECT_REQ { - log::trace!("invalid uri query (bad key), dropping connection request..."); - return Err(h3::Error::from(h3::error::Code::H3_REQUEST_INCOMPLETE)); - } - let Ok(connection_req) = serde_json::de::from_str::>(&connection_req) else { - log::trace!("invalid uri query (bad req), dropping connection request..."); - return Err(h3::Error::from(h3::error::Code::H3_REQUEST_INCOMPLETE)); + return Err(wtransport::error::ConnectionError::LocallyClosed); }; + let connection_req = urlencoding::decode_binary(encoded.as_bytes()); - Ok(connection_req) + Ok(connection_req.into()) } diff --git a/renet2/src/transport/webtransport_socket/utils.rs b/renet2/src/transport/webtransport_socket/utils.rs index 20f6094e..9618b838 100644 --- a/renet2/src/transport/webtransport_socket/utils.rs +++ b/renet2/src/transport/webtransport_socket/utils.rs @@ -32,7 +32,7 @@ impl TryFrom> for ServerCertHash { /// - `WebTransportClientConfig::server_dest`: Set the destination here. This tells the client where to connect. /// - `ServerSetupConfig::server_addresses`: Store the destination as a `SocketAddr` in the server addresses for the /// WebTransport server. This is used to validate connect tokens. -/// - `ClientAuthentication`: Use the destination in `ClientAuthentication::Unsecure::server_address` and in +/// - `ClientAuthentication`: Use the destination in `ClientAuthentication::Unsecure::server_address` or in /// `ConnectToken::generate()` for secure auth. The server address is used internally by `renet2` to coordinate /// packet sending and receiving. ///