Skip to content

Commit cda2603

Browse files
authored
Merge pull request #778 from hyperware-ai/hf/dont-crash-on-web-socket-close
dont crash on web socket close
2 parents e082831 + 7dfcfb7 commit cda2603

File tree

5 files changed

+102
-56
lines changed

5 files changed

+102
-56
lines changed

Cargo.lock

Lines changed: 2 additions & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hyperdrive/src/http/server.rs

Lines changed: 23 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
use crate::http::server_types::{
2-
HttpResponse, HttpServerAction, HttpServerError, HttpServerRequest, IncomingHttpRequest,
3-
MessageType, RpcResponseBody, WsMessageType,
2+
HttpResponse, HttpResponseSenders, HttpSender, HttpServerAction, HttpServerError,
3+
HttpServerRequest, IncomingHttpRequest, MessageType, RpcResponseBody, WebSocketSender,
4+
WebSocketSenders, WsMessageType,
45
};
5-
use crate::http::utils;
6+
use crate::http::utils::{self, send_action_response};
67
use crate::keygen;
78
use base64::{engine::general_purpose::STANDARD as base64_standard, Engine};
89
use dashmap::DashMap;
@@ -39,18 +40,6 @@ const WS_SELF_IMPOSED_MAX_CONNECTIONS: u32 = 128;
3940

4041
const LOGIN_HTML: &str = include_str!("login.html");
4142

42-
/// mapping from a given HTTP request (assigned an ID) to the oneshot
43-
/// channel that will get a response from the app that handles the request,
44-
/// and a string which contains the path that the request was made to.
45-
type HttpResponseSenders = Arc<DashMap<u64, (String, HttpSender)>>;
46-
type HttpSender = tokio::sync::oneshot::Sender<(HttpResponse, Vec<u8>)>;
47-
48-
/// mapping from an open websocket connection to a channel that will ingest
49-
/// WebSocketPush messages from the app that handles the connection, and
50-
/// send them to the connection.
51-
type WebSocketSenders = Arc<DashMap<u32, (ProcessId, WebSocketSender)>>;
52-
type WebSocketSender = tokio::sync::mpsc::Sender<warp::ws::Message>;
53-
5443
type PathBindings = Arc<RwLock<Router<BoundPath>>>;
5544
type WsPathBindings = Arc<RwLock<Router<BoundWsPath>>>;
5645
type SecureSubdomains = Arc<RwLock<HashSet<String>>>;
@@ -121,7 +110,14 @@ async fn send_push(
121110
}
122111
}
123112
WsMessageType::Close => {
124-
unreachable!();
113+
return utils::handle_close_websocket(
114+
id,
115+
&source,
116+
send_to_loop,
117+
ws_senders,
118+
channel_id,
119+
)
120+
.await;
125121
}
126122
};
127123
// Send to the websocket if registered
@@ -1546,19 +1542,17 @@ async fn handle_app_message(
15461542
return;
15471543
}
15481544
HttpServerAction::WebSocketClose(channel_id) => {
1549-
if let Some(got) = ws_senders.get(&channel_id) {
1550-
if got.value().0 != km.source.process {
1551-
send_action_response(
1552-
km.id,
1553-
km.source,
1554-
&send_to_loop,
1555-
Err(HttpServerError::WsChannelNotFound),
1556-
)
1557-
.await;
1558-
return;
1559-
}
1560-
let _ = got.value().1.send(warp::ws::Message::close()).await;
1561-
ws_senders.remove(&channel_id);
1545+
let is_return = utils::handle_close_websocket(
1546+
km.id,
1547+
&km.source,
1548+
&send_to_loop,
1549+
ws_senders,
1550+
channel_id,
1551+
)
1552+
.await;
1553+
1554+
if is_return {
1555+
return;
15621556
}
15631557
}
15641558
}
@@ -1569,28 +1563,3 @@ async fn handle_app_message(
15691563
}
15701564
}
15711565
}
1572-
1573-
pub async fn send_action_response(
1574-
id: u64,
1575-
target: Address,
1576-
send_to_loop: &MessageSender,
1577-
result: Result<(), HttpServerError>,
1578-
) {
1579-
KernelMessage::builder()
1580-
.id(id)
1581-
.source(("our", HTTP_SERVER_PROCESS_ID.clone()))
1582-
.target(target)
1583-
.message(Message::Response((
1584-
Response {
1585-
inherit: false,
1586-
body: serde_json::to_vec(&result).unwrap(),
1587-
metadata: None,
1588-
capabilities: vec![],
1589-
},
1590-
None,
1591-
)))
1592-
.build()
1593-
.unwrap()
1594-
.send(send_to_loop)
1595-
.await;
1596-
}

hyperdrive/src/http/utils.rs

Lines changed: 60 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,12 @@
1+
use crate::http::server_types::{HttpServerError, WebSocketSenders};
12
use hmac::{Hmac, Mac};
23
use jwt::VerifyWithKey;
3-
use lib::{core::ProcessId, types::http_server};
4+
use lib::types::{
5+
core::{
6+
Address, KernelMessage, Message, MessageSender, ProcessId, Response, HTTP_SERVER_PROCESS_ID,
7+
},
8+
http_server,
9+
};
410
use serde::{Deserialize, Serialize};
511
use sha2::Sha256;
612
use std::collections::HashMap;
@@ -158,3 +164,56 @@ pub fn is_behind_reverse_proxy(headers: &warp::http::HeaderMap) -> bool {
158164
}
159165
return false;
160166
}
167+
168+
pub async fn handle_close_websocket(
169+
id: u64,
170+
source: &Address,
171+
send_to_loop: &MessageSender,
172+
ws_senders: WebSocketSenders,
173+
channel_id: u32,
174+
) -> bool {
175+
let Some(got) = ws_senders.get(&channel_id) else {
176+
return false;
177+
};
178+
179+
if got.value().0 != source.process {
180+
send_action_response(
181+
id,
182+
source.clone(),
183+
send_to_loop,
184+
Err(HttpServerError::WsChannelNotFound),
185+
)
186+
.await;
187+
return true;
188+
}
189+
190+
let _ = got.value().1.send(warp::ws::Message::close()).await;
191+
ws_senders.remove(&channel_id);
192+
193+
return false;
194+
}
195+
196+
pub async fn send_action_response(
197+
id: u64,
198+
target: Address,
199+
send_to_loop: &MessageSender,
200+
result: Result<(), HttpServerError>,
201+
) {
202+
KernelMessage::builder()
203+
.id(id)
204+
.source(("our", HTTP_SERVER_PROCESS_ID.clone()))
205+
.target(target)
206+
.message(Message::Response((
207+
Response {
208+
inherit: false,
209+
body: serde_json::to_vec(&result).unwrap(),
210+
metadata: None,
211+
capabilities: vec![],
212+
},
213+
None,
214+
)))
215+
.build()
216+
.unwrap()
217+
.send(send_to_loop)
218+
.await;
219+
}

lib/Cargo.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ alloy = { version = "0.8.1", features = [
2323
"rpc-types",
2424
"rpc-types-eth",
2525
] }
26+
dashmap = "5.5.3"
2627
lazy_static = "1.4.0"
2728
rand = "0.8.4"
2829
regex = "1.11.0"
@@ -32,4 +33,5 @@ serde = { version = "1.0", features = ["derive"] }
3233
serde_json = "1.0"
3334
thiserror = "1.0"
3435
tokio = { version = "1.28", features = ["sync"] }
36+
warp = "0.3.5"
3537
wasmtime = { version = "33.0.0", features = ["component-model"] }

lib/src/http/server_types.rs

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
1-
use crate::core::LazyLoadBlob;
1+
use crate::core::{LazyLoadBlob, ProcessId};
2+
use dashmap::DashMap;
23
use serde::{Deserialize, Serialize};
34
use std::collections::HashMap;
5+
use std::sync::Arc;
46
use thiserror::Error;
57

68
/// HTTP Request received from the `http-server:distro:sys` service as a
@@ -201,3 +203,15 @@ pub struct JwtClaims {
201203
pub subdomain: Option<String>,
202204
pub expiration: u64,
203205
}
206+
207+
/// mapping from a given HTTP request (assigned an ID) to the oneshot
208+
/// channel that will get a response from the app that handles the request,
209+
/// and a string which contains the path that the request was made to.
210+
pub type HttpResponseSenders = Arc<DashMap<u64, (String, HttpSender)>>;
211+
pub type HttpSender = tokio::sync::oneshot::Sender<(HttpResponse, Vec<u8>)>;
212+
213+
/// mapping from an open websocket connection to a channel that will ingest
214+
/// WebSocketPush messages from the app that handles the connection, and
215+
/// send them to the connection.
216+
pub type WebSocketSenders = Arc<DashMap<u32, (ProcessId, WebSocketSender)>>;
217+
pub type WebSocketSender = tokio::sync::mpsc::Sender<warp::ws::Message>;

0 commit comments

Comments
 (0)