Skip to content

Commit 32d5bc1

Browse files
authored
feat: Accept handler (#116)
2 parents a949899 + 30ce4cf commit 32d5bc1

File tree

8 files changed

+127
-107
lines changed

8 files changed

+127
-107
lines changed

Cargo.toml

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ quinn = { package = "iroh-quinn", version = "0.12", optional = true }
2727
serde = { version = "1.0.183", features = ["derive"] }
2828
tokio = { version = "1", default-features = false, features = ["macros", "sync"] }
2929
tokio-serde = { version = "0.8", features = ["bincode"], optional = true }
30-
tokio-util = { version = "0.7", features = ["codec"], optional = true }
30+
tokio-util = { version = "0.7", features = ["rt"] }
3131
tracing = "0.1"
3232
hex = "0.4.3"
3333
futures = { version = "0.3.30", optional = true }
@@ -52,12 +52,13 @@ proc-macro2 = "1.0.66"
5252
futures-buffered = "0.2.4"
5353
testresult = "0.4.1"
5454
nested_enum_utils = "0.1.0"
55+
tokio-util = { version = "0.7", features = ["rt"] }
5556

5657
[features]
57-
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "dep:tokio-util"]
58-
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
58+
hyper-transport = ["dep:flume", "dep:hyper", "dep:bincode", "dep:bytes", "dep:tokio-serde", "tokio-util/codec"]
59+
quinn-transport = ["dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"]
5960
flume-transport = ["dep:flume"]
60-
iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "dep:tokio-util"]
61+
iroh-net-transport = ["dep:iroh-net", "dep:flume", "dep:quinn", "dep:bincode", "dep:tokio-serde", "tokio-util/codec"]
6162
macros = []
6263
default = ["flume-transport"]
6364

examples/modularize.rs

Lines changed: 4 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@ use app::AppService;
1212
use futures_lite::StreamExt;
1313
use futures_util::SinkExt;
1414
use quic_rpc::{client::BoxedConnector, transport::flume, Listener, RpcClient, RpcServer};
15-
use tracing::warn;
1615

1716
#[tokio::main]
1817
async fn main() -> Result<()> {
@@ -32,23 +31,11 @@ async fn main() -> Result<()> {
3231

3332
async fn run_server<C: Listener<AppService>>(server_conn: C, handler: app::Handler) {
3433
let server = RpcServer::<AppService, _>::new(server_conn);
35-
loop {
36-
let Ok(accepting) = server.accept().await else {
37-
continue;
38-
};
39-
match accepting.read_first().await {
40-
Err(err) => warn!(?err, "server accept failed"),
41-
Ok((req, chan)) => {
42-
let handler = handler.clone();
43-
tokio::task::spawn(async move {
44-
if let Err(err) = handler.handle_rpc_request(req, chan).await {
45-
warn!(?err, "internal rpc error");
46-
}
47-
});
48-
}
49-
}
50-
}
34+
server
35+
.accept_loop(move |req, chan| handler.clone().handle_rpc_request(req, chan))
36+
.await
5137
}
38+
5239
pub async fn client_demo(conn: BoxedConnector<AppService>) -> Result<()> {
5340
let rpc_client = RpcClient::<AppService>::new(conn);
5441
let client = app::Client::new(rpc_client.clone());

src/server.rs

Lines changed: 66 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,13 +7,16 @@ use std::{
77
marker::PhantomData,
88
pin::Pin,
99
result,
10+
sync::Arc,
1011
task::{self, Poll},
1112
};
1213

1314
use futures_lite::{Future, Stream, StreamExt};
1415
use futures_util::{SinkExt, TryStreamExt};
1516
use pin_project::pin_project;
16-
use tokio::sync::oneshot;
17+
use tokio::{sync::oneshot, task::JoinSet};
18+
use tokio_util::task::AbortOnDropHandle;
19+
use tracing::{error, warn};
1720

1821
use crate::{
1922
transport::{
@@ -211,6 +214,68 @@ impl<S: Service, C: Listener<S>> RpcServer<S, C> {
211214
pub fn into_inner(self) -> C {
212215
self.source
213216
}
217+
218+
/// Run an accept loop for this server.
219+
///
220+
/// Each request will be handled in a separate task.
221+
///
222+
/// It is the caller's responsibility to poll the returned future to drive the server.
223+
pub async fn accept_loop<Fun, Fut, E>(self, handler: Fun)
224+
where
225+
S: Service,
226+
C: Listener<S>,
227+
Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
228+
Fut: Future<Output = Result<(), E>> + Send + 'static,
229+
E: Into<anyhow::Error> + 'static,
230+
{
231+
let handler = Arc::new(handler);
232+
let mut tasks = JoinSet::new();
233+
loop {
234+
tokio::select! {
235+
Some(res) = tasks.join_next(), if !tasks.is_empty() => {
236+
if let Err(e) = res {
237+
if e.is_panic() {
238+
error!("Panic handling RPC request: {e}");
239+
}
240+
}
241+
}
242+
req = self.accept() => {
243+
let req = match req {
244+
Ok(req) => req,
245+
Err(e) => {
246+
warn!("Error accepting RPC request: {e}");
247+
continue;
248+
}
249+
};
250+
let handler = handler.clone();
251+
tasks.spawn(async move {
252+
let (req, chan) = match req.read_first().await {
253+
Ok((req, chan)) => (req, chan),
254+
Err(e) => {
255+
warn!("Error reading first message: {e}");
256+
return;
257+
}
258+
};
259+
if let Err(cause) = handler(req, chan).await {
260+
warn!("Error handling RPC request: {}", cause.into());
261+
}
262+
});
263+
}
264+
}
265+
}
266+
}
267+
268+
/// Spawn an accept loop and return a handle to the task.
269+
pub fn spawn_accept_loop<Fun, Fut, E>(self, handler: Fun) -> AbortOnDropHandle<()>
270+
where
271+
S: Service,
272+
C: Listener<S>,
273+
Fun: Fn(S::Req, RpcChannel<S, C>) -> Fut + Send + Sync + 'static,
274+
Fut: Future<Output = Result<(), E>> + Send + 'static,
275+
E: Into<anyhow::Error> + 'static,
276+
{
277+
AbortOnDropHandle::new(tokio::spawn(self.accept_loop(handler)))
278+
}
214279
}
215280

216281
impl<S: Service, C: Listener<S>> AsRef<C> for RpcServer<S, C> {

tests/flume.rs

Lines changed: 3 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,17 @@ use quic_rpc::{
77
transport::flume,
88
RpcClient, RpcServer, Service,
99
};
10+
use tokio_util::task::AbortOnDropHandle;
1011

1112
#[tokio::test]
1213
async fn flume_channel_bench() -> anyhow::Result<()> {
1314
tracing_subscriber::fmt::try_init().ok();
1415
let (server, client) = flume::channel(1);
1516

1617
let server = RpcServer::<ComputeService, _>::new(server);
17-
let server_handle = tokio::task::spawn(ComputeService::server(server));
18+
let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server)));
1819
let client = RpcClient::<ComputeService, _>::new(client);
1920
bench(client, 1000000).await?;
20-
// dropping the client will cause the server to terminate
21-
match server_handle.await? {
22-
Err(RpcServerError::Accept(_)) => {}
23-
e => panic!("unexpected termination result {e:?}"),
24-
}
2521
Ok(())
2622
}
2723

@@ -101,13 +97,7 @@ async fn flume_channel_smoke() -> anyhow::Result<()> {
10197
let (server, client) = flume::channel(1);
10298

10399
let server = RpcServer::<ComputeService, _>::new(server);
104-
let server_handle = tokio::task::spawn(ComputeService::server(server));
100+
let _server_handle = AbortOnDropHandle::new(tokio::spawn(ComputeService::server(server)));
105101
smoke_test(client).await?;
106-
107-
// dropping the client will cause the server to terminate
108-
match server_handle.await? {
109-
Err(RpcServerError::Accept(_)) => {}
110-
e => panic!("unexpected termination result {e:?}"),
111-
}
112102
Ok(())
113103
}

tests/hyper.rs

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -15,19 +15,13 @@ use tokio::task::JoinHandle;
1515

1616
mod math;
1717
use math::*;
18+
use tokio_util::task::AbortOnDropHandle;
1819
mod util;
1920

20-
fn run_server(addr: &SocketAddr) -> JoinHandle<anyhow::Result<()>> {
21+
fn run_server(addr: &SocketAddr) -> AbortOnDropHandle<()> {
2122
let channel = HyperListener::serve(addr).unwrap();
2223
let server = RpcServer::new(channel);
23-
tokio::spawn(async move {
24-
loop {
25-
let server = server.clone();
26-
ComputeService::server(server).await?;
27-
}
28-
#[allow(unreachable_code)]
29-
anyhow::Ok(())
30-
})
24+
ComputeService::server(server)
3125
}
3226

3327
#[derive(Debug, Serialize, Deserialize, From, TryInto)]
@@ -133,25 +127,21 @@ impl TestService {
133127
async fn hyper_channel_bench() -> anyhow::Result<()> {
134128
let addr: SocketAddr = "127.0.0.1:3000".parse()?;
135129
let uri: Uri = "http://127.0.0.1:3000".parse()?;
136-
let server_handle = run_server(&addr);
130+
let _server_handle = run_server(&addr);
137131
let client = HyperConnector::new(uri);
138132
let client = RpcClient::new(client);
139133
bench(client, 50000).await?;
140134
println!("terminating server");
141-
server_handle.abort();
142-
let _ = server_handle.await;
143135
Ok(())
144136
}
145137

146138
#[tokio::test]
147139
async fn hyper_channel_smoke() -> anyhow::Result<()> {
148140
let addr: SocketAddr = "127.0.0.1:3001".parse()?;
149141
let uri: Uri = "http://127.0.0.1:3001".parse()?;
150-
let server_handle = run_server(&addr);
142+
let _server_handle = run_server(&addr);
151143
let client = HyperConnector::new(uri);
152144
smoke_test(client).await?;
153-
server_handle.abort();
154-
let _ = server_handle.await;
155145
Ok(())
156146
}
157147

@@ -302,6 +292,5 @@ async fn hyper_channel_errors() -> anyhow::Result<()> {
302292

303293
println!("terminating server");
304294
server_handle.abort();
305-
let _ = server_handle.await;
306295
Ok(())
307296
}

tests/iroh-net.rs

Lines changed: 15 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,13 @@
22

33
use iroh_net::{key::SecretKey, NodeAddr};
44
use quic_rpc::{transport, RpcClient, RpcServer};
5-
use tokio::task::JoinHandle;
5+
use testresult::TestResult;
6+
7+
use crate::transport::iroh_net::{IrohNetConnector, IrohNetListener};
68

79
mod math;
810
use math::*;
11+
use tokio_util::task::AbortOnDropHandle;
912
mod util;
1013

1114
const ALPN: &[u8] = b"quic-rpc/iroh-net/test";
@@ -44,13 +47,10 @@ impl Endpoints {
4447
}
4548
}
4649

47-
fn run_server(server: iroh_net::Endpoint) -> JoinHandle<anyhow::Result<()>> {
48-
tokio::task::spawn(async move {
49-
let connection = transport::iroh_net::IrohNetListener::new(server)?;
50-
let server = RpcServer::new(connection);
51-
ComputeService::server(server).await?;
52-
anyhow::Ok(())
53-
})
50+
fn run_server(server: iroh_net::Endpoint) -> AbortOnDropHandle<()> {
51+
let connection = IrohNetListener::new(server).unwrap();
52+
let server = RpcServer::new(connection);
53+
ComputeService::server(server)
5454
}
5555

5656
// #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
@@ -64,17 +64,12 @@ async fn iroh_net_channel_bench() -> anyhow::Result<()> {
6464
server_node_addr,
6565
} = Endpoints::new().await?;
6666
tracing::debug!("Starting server");
67-
let server_handle = run_server(server);
67+
let _server_handle = run_server(server);
6868
tracing::debug!("Starting client");
6969

70-
let client = RpcClient::new(transport::iroh_net::IrohNetConnector::new(
71-
client,
72-
server_node_addr,
73-
ALPN.into(),
74-
));
70+
let client = RpcClient::new(IrohNetConnector::new(client, server_node_addr, ALPN.into()));
7571
tracing::debug!("Starting benchmark");
7672
bench(client, 50000).await?;
77-
server_handle.abort();
7873
Ok(())
7974
}
8075

@@ -86,11 +81,9 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> {
8681
server,
8782
server_node_addr,
8883
} = Endpoints::new().await?;
89-
let server_handle = run_server(server);
90-
let client_connection =
91-
transport::iroh_net::IrohNetConnector::new(client, server_node_addr, ALPN.into());
84+
let _server_handle = run_server(server);
85+
let client_connection = IrohNetConnector::new(client, server_node_addr, ALPN.into());
9286
smoke_test(client_connection).await?;
93-
server_handle.abort();
9487
Ok(())
9588
}
9689

@@ -99,7 +92,7 @@ async fn iroh_net_channel_smoke() -> anyhow::Result<()> {
9992
///
10093
/// This is a regression test.
10194
#[tokio::test]
102-
async fn server_away_and_back() -> anyhow::Result<()> {
95+
async fn server_away_and_back() -> TestResult<()> {
10396
tracing_subscriber::fmt::try_init().ok();
10497
tracing::info!("Creating endpoints");
10598

@@ -128,7 +121,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {
128121
// create the RPC Server
129122
let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?;
130123
let server = RpcServer::new(connection);
131-
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 1));
124+
let server_handle = tokio::spawn(ComputeService::server_bounded(server, 1));
132125

133126
// wait a bit for connection due to Windows test failing on CI
134127
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
@@ -151,7 +144,7 @@ async fn server_away_and_back() -> anyhow::Result<()> {
151144
// make the server run again
152145
let connection = transport::iroh_net::IrohNetListener::new(server_endpoint.clone())?;
153146
let server = RpcServer::new(connection);
154-
let server_handle = tokio::task::spawn(ComputeService::server_bounded(server, 5));
147+
let server_handle = tokio::spawn(ComputeService::server_bounded(server, 5));
155148

156149
// wait a bit for connection due to Windows test failing on CI
157150
tokio::time::sleep(tokio::time::Duration::from_millis(300)).await;
@@ -163,7 +156,6 @@ async fn server_away_and_back() -> anyhow::Result<()> {
163156
// server is running, this should work
164157
let SqrResponse(response) = client.rpc(Sqr(3)).await?;
165158
assert_eq!(response, 9);
166-
167159
server_handle.abort();
168160
Ok(())
169161
}

tests/math.rs

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use quic_rpc::{
2626
};
2727
use serde::{Deserialize, Serialize};
2828
use thousands::Separable;
29+
use tokio_util::task::AbortOnDropHandle;
2930

3031
/// compute the square of a number
3132
#[derive(Debug, Serialize, Deserialize)]
@@ -163,20 +164,14 @@ impl ComputeService {
163164
}
164165
}
165166

166-
pub async fn server<C: Listener<ComputeService>>(
167+
pub fn server<C: Listener<ComputeService>>(
167168
server: RpcServer<ComputeService, C>,
168-
) -> result::Result<(), RpcServerError<C>> {
169-
let s = server;
170-
let service = ComputeService;
171-
loop {
172-
let (req, chan) = s.accept().await?.read_first().await?;
173-
let service = service.clone();
174-
tokio::spawn(async move { Self::handle_rpc_request(service, req, chan).await });
175-
}
169+
) -> AbortOnDropHandle<()> {
170+
server.spawn_accept_loop(|req, chan| Self::handle_rpc_request(ComputeService, req, chan))
176171
}
177172

178173
pub async fn handle_rpc_request<E>(
179-
service: ComputeService,
174+
self,
180175
req: ComputeRequest,
181176
chan: RpcChannel<ComputeService, E>,
182177
) -> Result<(), RpcServerError<E>>
@@ -186,10 +181,10 @@ impl ComputeService {
186181
use ComputeRequest::*;
187182
#[rustfmt::skip]
188183
match req {
189-
Sqr(msg) => chan.rpc(msg, service, ComputeService::sqr).await,
190-
Sum(msg) => chan.client_streaming(msg, service, ComputeService::sum).await,
191-
Fibonacci(msg) => chan.server_streaming(msg, service, ComputeService::fibonacci).await,
192-
Multiply(msg) => chan.bidi_streaming(msg, service, ComputeService::multiply).await,
184+
Sqr(msg) => chan.rpc(msg, self, Self::sqr).await,
185+
Sum(msg) => chan.client_streaming(msg, self, Self::sum).await,
186+
Fibonacci(msg) => chan.server_streaming(msg, self, Self::fibonacci).await,
187+
Multiply(msg) => chan.bidi_streaming(msg, self, Self::multiply).await,
193188
MultiplyUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
194189
SumUpdate(_) => Err(RpcServerError::UnexpectedStartMessage)?,
195190
}?;

0 commit comments

Comments
 (0)