Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1093,7 +1093,7 @@ pub mod rpc {
use std::{fmt::Debug, future::Future, io, marker::PhantomData, pin::Pin, sync::Arc};

use n0_future::{future::Boxed as BoxFuture, task::JoinSet};
use quinn::ConnectionError;
use quinn::{ConnectionError, Endpoint};
use serde::{de::DeserializeOwned, Serialize};
use smallvec::SmallVec;
use tracing::{trace, trace_span, warn, Instrument};
Expand Down Expand Up @@ -1470,6 +1470,89 @@ pub mod rpc {
request_id += 1;
}
}

type MultiHandler = Arc<
dyn Fn(
&[u8],
quinn::RecvStream,
quinn::SendStream,
) -> std::result::Result<
BoxFuture<std::result::Result<(), SendError>>,
(quinn::RecvStream, quinn::SendStream),
> + Send
+ Sync
+ 'static,
>;

pub struct Listener {
handlers: Vec<MultiHandler>,
}

impl Listener {
pub fn add_handler<R: DeserializeOwned + 'static>(mut self, handler: Handler<R>) -> Self {
self.handlers.push(Arc::new(
move |buf, recv, send| match postcard::from_bytes::<R>(buf) {
Err(_) => Err((recv, send)),
Ok(msg) => Ok(handler(msg, recv, send)),
},
));
self
}

pub async fn listen(self, endpoint: Endpoint) {
let mut request_id = 0u64;
let mut tasks = JoinSet::new();
while let Some(incoming) = endpoint.accept().await {
let handlers = self.handlers.clone();
let fut = async move {
let connection = match incoming.await {
Ok(connection) => connection,
Err(cause) => {
warn!("failed to accept connection {cause:?}");
return io::Result::Ok(());
}
};
loop {
let (mut send, mut recv) = match connection.accept_bi().await {
Ok((s, r)) => (s, r),
Err(ConnectionError::ApplicationClosed(cause))
if cause.error_code.into_inner() == 0 =>
{
trace!("remote side closed connection {cause:?}");
return Ok(());
}
Err(cause) => {
warn!("failed to accept bi stream {cause:?}");
return Err(cause.into());
}
};
let size = recv.read_varint_u64().await?.ok_or_else(|| {
io::Error::new(io::ErrorKind::UnexpectedEof, "failed to read size")
})?;
let mut buf = vec![0; size as usize];
recv.read_exact(&mut buf)
.await
.map_err(|e| io::Error::new(io::ErrorKind::UnexpectedEof, e))?;
for handler in &handlers {
match handler(&buf, recv, send) {
Ok(fut) => {
fut.await?;
break;
}
Err((recv_ret, send_ret)) => {
recv = recv_ret;
send = send_ret;
}
}
}
}
};
let span = trace_span!("rpc", id = request_id);
tasks.spawn(fut.instrument(span));
request_id += 1;
}
}
}
}

/// A request to a service. This can be either local or remote.
Expand Down