Skip to content

Commit 075e337

Browse files
Add idempotency lock for axum server
1 parent 8b0fc4a commit 075e337

File tree

2 files changed

+104
-5
lines changed

2 files changed

+104
-5
lines changed

src/lib.rs

Lines changed: 48 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,26 @@ use worker::{event, Context, Date, Env, Request};
44

55
// All imports for server specific things
66
#[cfg(feature = "server")]
7+
use axum::{
8+
async_trait,
9+
extract::FromRequest,
10+
extract::{Path, WebSocketUpgrade},
11+
http::{Request, StatusCode},
12+
TypedHeader,
13+
};
14+
#[cfg(feature = "server")]
715
use axum::{routing::get, Router};
816
#[cfg(feature = "server")]
17+
use std::collections::HashMap;
18+
#[cfg(feature = "server")]
919
use std::env;
1020
#[cfg(feature = "server")]
1121
use std::net::SocketAddr;
1222
#[cfg(feature = "server")]
23+
use std::sync::Arc;
24+
#[cfg(feature = "server")]
25+
use tokio::sync::Mutex;
26+
#[cfg(feature = "server")]
1327
use tower_http::trace::{DefaultMakeSpan, TraceLayer};
1428

1529
mod logger;
@@ -21,15 +35,48 @@ mod cloudflare;
2135
#[cfg(feature = "server")]
2236
mod server;
2337

38+
#[cfg(feature = "server")]
39+
struct IdempotencyKey(Option<String>);
40+
41+
#[cfg(feature = "server")]
42+
#[async_trait]
43+
impl<'a, B> FromRequest<(), B> for IdempotencyKey
44+
where
45+
B: Send + 'static,
46+
{
47+
type Rejection = StatusCode;
48+
49+
async fn from_request(req: Request<B>, _: &()) -> Result<Self, Self::Rejection> {
50+
let headers = req.headers().clone();
51+
let key = headers
52+
.get("Idempotency-Key")
53+
.and_then(|v| v.to_str().ok())
54+
.map(|s| s.to_string());
55+
Ok(IdempotencyKey(key))
56+
}
57+
}
58+
2459
/// Main function for running the program as a server
2560
#[tokio::main]
2661
#[cfg(feature = "server")]
2762
async fn main() {
2863
println!("Running ln-websocket-proxy");
2964
tracing_subscriber::fmt::init();
3065

66+
let locks = Arc::new(Mutex::new(HashMap::new()));
67+
3168
let app = Router::new()
32-
.route("/v1/:ip/:port", get(crate::server::ws_handler))
69+
.route(
70+
"/v1/:ip/:port",
71+
get(
72+
|path: Path<(String, String)>,
73+
ws: WebSocketUpgrade,
74+
user_agent: Option<TypedHeader<headers::UserAgent>>,
75+
idempotency_key: IdempotencyKey| async move {
76+
server::ws_handler(path, ws, user_agent, idempotency_key.0, locks.clone())
77+
},
78+
),
79+
)
3380
.layer(
3481
TraceLayer::new_for_http()
3582
.make_span_with(DefaultMakeSpan::default().include_headers(true)),

src/server.rs

Lines changed: 56 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,27 +6,79 @@ use axum::{
66
},
77
response::IntoResponse,
88
};
9-
use std::time::Duration;
9+
use std::{collections::HashMap, sync::Arc, time::Duration};
1010
use tokio::{
1111
io::{AsyncReadExt, AsyncWriteExt},
1212
net::TcpStream,
13+
sync::Mutex,
1314
};
1415

15-
pub(crate) async fn ws_handler(
16+
struct IdempotencyLockGuard {
17+
key: Option<String>,
18+
locks: Arc<Mutex<HashMap<String, bool>>>,
19+
}
20+
21+
impl IdempotencyLockGuard {
22+
async fn new(
23+
key: Option<String>,
24+
locks: Arc<Mutex<HashMap<String, bool>>>,
25+
) -> Result<Self, String> {
26+
if let Some(ref key) = key {
27+
let mut lock_map = locks.lock().await;
28+
if lock_map.get(key).copied().unwrap_or(false) {
29+
return Err("Idempotency key already in use".to_string());
30+
}
31+
lock_map.insert(key.clone(), true);
32+
}
33+
Ok(Self { key, locks })
34+
}
35+
}
36+
37+
impl Drop for IdempotencyLockGuard {
38+
fn drop(&mut self) {
39+
if let Some(ref key) = self.key {
40+
if let Ok(mut lock_map) = self.locks.try_lock() {
41+
lock_map.insert(key.clone(), false);
42+
// Log that the lock was successfully released if needed
43+
} else {
44+
// Log an error or handle the lock acquisition failure
45+
logger::error("Failed to acquire lock to release idempotency key");
46+
}
47+
}
48+
}
49+
}
50+
51+
pub(crate) fn ws_handler(
1652
Path((ip, port)): Path<(String, String)>,
1753
ws: WebSocketUpgrade,
1854
user_agent: Option<TypedHeader<headers::UserAgent>>,
55+
idempotency_key: Option<String>,
56+
locks: Arc<Mutex<HashMap<String, bool>>>,
1957
) -> impl IntoResponse {
2058
logger::info(&format!("ip: {ip}, port: {port}"));
2159
if let Some(TypedHeader(user_agent)) = user_agent {
2260
logger::info(&format!("`{user_agent}` connected"));
2361
}
2462

2563
ws.protocols(["binary"])
26-
.on_upgrade(move |socket| handle_socket(socket, ip, port))
64+
.on_upgrade(move |socket| handle_socket(socket, ip, port, idempotency_key, locks))
2765
}
2866

29-
pub(crate) async fn handle_socket(mut socket: WebSocket, host: String, port: String) {
67+
pub(crate) async fn handle_socket(
68+
socket: WebSocket,
69+
host: String,
70+
port: String,
71+
idempotency_key: Option<String>,
72+
locks: Arc<Mutex<HashMap<String, bool>>>,
73+
) {
74+
let lock_guard = match IdempotencyLockGuard::new(idempotency_key, locks).await {
75+
Ok(guard) => guard,
76+
Err(e) => {
77+
logger::error(&e);
78+
return;
79+
}
80+
};
81+
3082
let server_stream = match connect_to_addr(host, port, server_tcp_connect) {
3183
Ok(s) => s,
3284
Err(e) => {

0 commit comments

Comments
 (0)