Skip to content

Commit de38f24

Browse files
Add idempotency lock for axum server
1 parent 8b0fc4a commit de38f24

File tree

2 files changed

+107
-5
lines changed

2 files changed

+107
-5
lines changed

src/lib.rs

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,24 @@ use worker::{event, Context, Date, Env, Request};
44

55
// All imports for server specific things
66
#[cfg(feature = "server")]
7+
use axum::{
8+
async_trait,
9+
extract::FromRequest,
10+
http::{Request, StatusCode},
11+
};
12+
#[cfg(feature = "server")]
713
use axum::{routing::get, Router};
814
#[cfg(feature = "server")]
15+
use std::collections::HashMap;
16+
#[cfg(feature = "server")]
917
use std::env;
1018
#[cfg(feature = "server")]
1119
use std::net::SocketAddr;
1220
#[cfg(feature = "server")]
21+
use std::sync::Arc;
22+
#[cfg(feature = "server")]
23+
use tokio::sync::Mutex;
24+
#[cfg(feature = "server")]
1325
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
1426

1527
mod logger;
@@ -21,15 +33,53 @@ mod cloudflare;
2133
#[cfg(feature = "server")]
2234
mod server;
2335

36+
#[cfg(feature = "server")]
37+
struct IdempotencyKey(Option<String>);
38+
39+
#[cfg(feature = "server")]
40+
#[async_trait]
41+
impl<'a, B> FromRequest<(), B> for IdempotencyKey
42+
where
43+
B: Send + 'static,
44+
{
45+
type Rejection = StatusCode;
46+
47+
async fn from_request(req: Request<B>, _: &()) -> Result<Self, Self::Rejection> {
48+
let headers = req.headers().clone();
49+
let key = headers
50+
.get("Idempotency-Key")
51+
.and_then(|v| v.to_str().ok())
52+
.map(|s| s.to_string());
53+
Ok(IdempotencyKey(key))
54+
}
55+
}
56+
2457
/// Main function for running the program as a server
2558
#[tokio::main]
2659
#[cfg(feature = "server")]
2760
async fn main() {
61+
use axum::{
62+
extract::{Path, WebSocketUpgrade},
63+
TypedHeader,
64+
};
65+
2866
println!("Running ln-websocket-proxy");
2967
tracing_subscriber::fmt::init();
3068

69+
let locks = Arc::new(Mutex::new(HashMap::new()));
70+
3171
let app = Router::new()
32-
.route("/v1/:ip/:port", get(crate::server::ws_handler))
72+
.route(
73+
"/v1/:ip/:port",
74+
get(
75+
|path: Path<(String, String)>,
76+
ws: WebSocketUpgrade,
77+
user_agent: Option<TypedHeader<headers::UserAgent>>,
78+
idempotency_key: IdempotencyKey| async move {
79+
server::ws_handler(path, ws, user_agent, idempotency_key.0, locks.clone())
80+
},
81+
),
82+
)
3383
.layer(
3484
TraceLayer::new_for_http()
3585
.make_span_with(DefaultMakeSpan::default().include_headers(true)),

src/server.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,79 @@ use axum::{
66
},
77
response::IntoResponse,
88
};
9-
use std::time::Duration;
9+
use std::{collections::HashMap, sync::Arc, time::Duration};
1010
use tokio::{
1111
io::{AsyncReadExt, AsyncWriteExt},
1212
net::TcpStream,
13+
sync::Mutex,
1314
};
1415

15-
pub(crate) async fn ws_handler(
16+
struct IdempotencyLockGuard {
17+
key: Option<String>,
18+
locks: Arc<Mutex<HashMap<String, bool>>>,
19+
}
20+
21+
impl IdempotencyLockGuard {
22+
async fn new(
23+
key: Option<String>,
24+
locks: Arc<Mutex<HashMap<String, bool>>>,
25+
) -> Result<Self, String> {
26+
if let Some(ref key) = key {
27+
let mut lock_map = locks.lock().await;
28+
if lock_map.get(key).copied().unwrap_or(false) {
29+
return Err("Idempotency key already in use".to_string());
30+
}
31+
lock_map.insert(key.clone(), true);
32+
}
33+
Ok(Self { key, locks })
34+
}
35+
}
36+
37+
impl Drop for IdempotencyLockGuard {
38+
fn drop(&mut self) {
39+
if let Some(ref key) = self.key {
40+
if let Ok(mut lock_map) = self.locks.try_lock() {
41+
lock_map.insert(key.clone(), false);
42+
// Log that the lock was successfully released if needed
43+
} else {
44+
// Log an error or handle the lock acquisition failure
45+
logger::error("Failed to acquire lock to release idempotency key");
46+
}
47+
}
48+
}
49+
}
50+
51+
pub(crate) fn ws_handler(
1652
Path((ip, port)): Path<(String, String)>,
1753
ws: WebSocketUpgrade,
1854
user_agent: Option<TypedHeader<headers::UserAgent>>,
55+
idempotency_key: Option<String>,
56+
locks: Arc<Mutex<HashMap<String, bool>>>,
1957
) -> impl IntoResponse {
2058
logger::info(&format!("ip: {ip}, port: {port}"));
2159
if let Some(TypedHeader(user_agent)) = user_agent {
2260
logger::info(&format!("`{user_agent}` connected"));
2361
}
2462

2563
ws.protocols(["binary"])
26-
.on_upgrade(move |socket| handle_socket(socket, ip, port))
64+
.on_upgrade(move |socket| handle_socket(socket, ip, port, idempotency_key, locks))
2765
}
2866

29-
pub(crate) async fn handle_socket(mut socket: WebSocket, host: String, port: String) {
67+
pub(crate) async fn handle_socket(
68+
socket: WebSocket,
69+
host: String,
70+
port: String,
71+
idempotency_key: Option<String>,
72+
locks: Arc<Mutex<HashMap<String, bool>>>,
73+
) {
74+
let lock_guard = match IdempotencyLockGuard::new(idempotency_key, locks).await {
75+
Ok(guard) => guard,
76+
Err(e) => {
77+
logger::error(&e);
78+
return;
79+
}
80+
};
81+
3082
let server_stream = match connect_to_addr(host, port, server_tcp_connect) {
3183
Ok(s) => s,
3284
Err(e) => {

0 commit comments

Comments
 (0)