Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove Tokio related channel stuff #484

Open
wants to merge 5 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion plugins/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -736,7 +736,12 @@ impl<'a> ServiceGenerator<'a> {
#vis fn new<T>(config: ::tarpc::client::Config, transport: T)
-> ::tarpc::client::NewClient<
Self,
::tarpc::client::RequestDispatch<#request_ident, #response_ident, T>
::tarpc::client::RequestDispatch<
#request_ident,
#response_ident,
T,
::tarpc::util::delay_queue::DelayQueue<u64>
>
>
where
T: ::tarpc::Transport<::tarpc::ClientMessage<#request_ident>, ::tarpc::Response<#response_ident>>
Expand Down
9 changes: 4 additions & 5 deletions tarpc/src/cancellations.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
use futures::{prelude::*, task::*};
use futures::{channel::mpsc, prelude::*, task::*};
use std::pin::Pin;
use tokio::sync::mpsc;

/// Sends request cancellation signals.
#[derive(Debug, Clone)]
Expand All @@ -14,7 +13,7 @@ pub struct CanceledRequests(mpsc::UnboundedReceiver<u64>);
pub fn cancellations() -> (RequestCancellation, CanceledRequests) {
// Unbounded because messages are sent in the drop fn. This is fine, because it's still
// bounded by the number of in-flight requests.
let (tx, rx) = mpsc::unbounded_channel();
let (tx, rx) = mpsc::unbounded();
(RequestCancellation(tx), CanceledRequests(rx))
}

Expand All @@ -29,14 +28,14 @@ impl RequestCancellation {
/// useful primarily when request processing ends prematurely for requests with long deadlines
/// which would otherwise continue to be tracked by the backing channel—a kind of leak.
pub fn cancel(&self, request_id: u64) {
let _ = self.0.send(request_id);
let _ = self.0.unbounded_send(request_id);
}
}

impl CanceledRequests {
/// Polls for a cancelled request.
pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<u64>> {
self.0.poll_recv(cx)
self.0.poll_next_unpin(cx)
}
}

Expand Down
127 changes: 59 additions & 68 deletions tarpc/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,19 @@ pub mod stub;
use crate::{
cancellations::{cancellations, CanceledRequests, RequestCancellation},
context, trace,
util::TimeUntil,
util::{
delay_queue::{DelayQueue, DelayQueueLike},
TimeUntil,
},
ChannelError, ClientMessage, Request, RequestName, Response, ServerError, Transport,
};
use futures::{prelude::*, ready, stream::Fuse, task::*};
use futures::{
channel::{mpsc, oneshot},
prelude::*,
ready,
stream::Fuse,
task::*,
};
use in_flight_requests::InFlightRequests;
use pin_project::pin_project;
use std::{
Expand All @@ -29,7 +38,6 @@ use std::{
},
time::SystemTime,
};
use tokio::sync::{mpsc, oneshot};
use tracing::Span;

/// Settings that control the behavior of the client.
Expand Down Expand Up @@ -152,6 +160,7 @@ where
cancel: true,
};
self.to_dispatch
.clone()
.send(DispatchRequest {
ctx,
span,
Expand All @@ -160,7 +169,7 @@ where
response_completion,
})
.await
.map_err(|mpsc::error::SendError(_)| RpcError::Shutdown)?;
.map_err(|mpsc::SendError { .. }| RpcError::Shutdown)?;
response_guard.response().await
}
}
Expand Down Expand Up @@ -202,7 +211,7 @@ impl<Resp> ResponseGuard<'_, Resp> {
self.cancel = false;
match response {
Ok(response) => response,
Err(oneshot::error::RecvError { .. }) => {
Err(oneshot::Canceled { .. }) => {
// The oneshot is Canceled when the dispatch task ends. In that case,
// there's nothing listening on the other side, so there's no point in
// propagating cancellation.
Expand Down Expand Up @@ -237,9 +246,27 @@ impl<Resp> Drop for ResponseGuard<'_, Resp> {
pub fn new<Req, Resp, C>(
config: Config,
transport: C,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C>>
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C, DelayQueue<u64>>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
{
with_in_flight_requests(
config,
transport,
InFlightRequests::<_, DelayQueue<u64>>::default(),
)
}

/// Returns a channel and dispatcher that manages the lifecycle of requests initiated by the
/// channel.
pub fn with_in_flight_requests<Req, Resp, C, Deadline>(
config: Config,
transport: C,
in_flight_requests: InFlightRequests<Result<Resp, RpcError>, Deadline>,
) -> NewClient<Channel<Req, Resp>, RequestDispatch<Req, Resp, C, Deadline>>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
Deadline: DelayQueueLike<u64>,
{
let (to_dispatch, pending_requests) = mpsc::channel(config.pending_request_buffer);
let (cancellation, canceled_requests) = cancellations();
Expand All @@ -254,7 +281,7 @@ where
config,
canceled_requests,
transport: transport.fuse(),
in_flight_requests: InFlightRequests::default(),
in_flight_requests,
pending_requests,
terminal_error: None,
},
Expand All @@ -266,7 +293,10 @@ where
#[must_use]
#[pin_project()]
#[derive(Debug)]
pub struct RequestDispatch<Req, Resp, C> {
pub struct RequestDispatch<Req, Resp, C, Deadline>
where
Deadline: DelayQueueLike<u64>,
{
/// Writes requests to the wire and reads responses off the wire.
#[pin]
transport: Fuse<C>,
Expand All @@ -275,7 +305,7 @@ pub struct RequestDispatch<Req, Resp, C> {
/// Requests that were dropped.
canceled_requests: CanceledRequests,
/// Requests already written to the wire that haven't yet received responses.
in_flight_requests: InFlightRequests<Result<Resp, RpcError>>,
in_flight_requests: InFlightRequests<Result<Resp, RpcError>, Deadline>,
/// Configures limits to prevent unlimited resource usage.
config: Config,
/// Produces errors that can be sent in response to any unprocessed requests at the time
Expand All @@ -285,13 +315,14 @@ pub struct RequestDispatch<Req, Resp, C> {
terminal_error: Option<ChannelError<dyn Any + Send + Sync + 'static>>,
}

impl<Req, Resp, C> RequestDispatch<Req, Resp, C>
impl<Req, Resp, C, Deadline> RequestDispatch<Req, Resp, C, Deadline>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
Deadline: DelayQueueLike<u64>,
{
fn in_flight_requests<'a>(
self: &'a mut Pin<&mut Self>,
) -> &'a mut InFlightRequests<Result<Resp, RpcError>> {
) -> &'a mut InFlightRequests<Result<Resp, RpcError>, Deadline> {
self.as_mut().project().in_flight_requests
}

Expand Down Expand Up @@ -433,9 +464,9 @@ where
ready!(self.ensure_writeable(cx)?);

loop {
match ready!(self.pending_requests_mut().poll_recv(cx)) {
match ready!(self.pending_requests_mut().poll_next_unpin(cx)) {
Some(request) => {
if request.response_completion.is_closed() {
if request.response_completion.is_canceled() {
let _entered = request.span.enter();
tracing::info!("AbortRequest");
continue;
Expand Down Expand Up @@ -586,14 +617,14 @@ where
tracing::warn!("RpcError::Channel");
}
loop {
match ready!(self.pending_requests_mut().poll_recv(cx)) {
match ready!(self.pending_requests_mut().poll_next_unpin(cx)) {
Some(DispatchRequest {
span,
response_completion,
..
}) => {
let _entered = span.enter();
if response_completion.is_closed() {
if response_completion.is_canceled() {
tracing::info!("AbortRequest");
} else {
tracing::warn!("RpcError::Channel");
Expand Down Expand Up @@ -636,9 +667,10 @@ where
}
}

impl<Req, Resp, C> Future for RequestDispatch<Req, Resp, C>
impl<Req, Resp, C, Deadline> Future for RequestDispatch<Req, Resp, C, Deadline>
where
C: Transport<ClientMessage<Req>, Response<Resp>>,
Deadline: DelayQueueLike<u64>,
{
type Output = Result<(), ChannelError<C::Error>>;

Expand Down Expand Up @@ -685,10 +717,15 @@ mod tests {
client::{in_flight_requests::InFlightRequests, Config},
context::{self, current},
transport::{self, channel::UnboundedChannel},
util::delay_queue::DelayQueue,
ChannelError, ClientMessage, Response,
};
use assert_matches::assert_matches;
use futures::{prelude::*, task::*};
use futures::{
channel::{mpsc, oneshot},
prelude::*,
task::*,
};
use std::{
convert::TryFrom,
fmt::Display,
Expand All @@ -700,10 +737,6 @@ mod tests {
},
};
use thiserror::Error;
use tokio::sync::{
mpsc::{self},
oneshot,
};
use tracing::Span;

#[tokio::test]
Expand All @@ -724,7 +757,7 @@ mod tests {
.await
.unwrap();
assert_matches!(dispatch.as_mut().poll(cx), Poll::Pending);
assert_matches!(rx.try_recv(), Ok(Ok(resp)) if resp == "Resp");
assert_matches!(rx.try_recv(), Ok(Some(Ok(resp))) if resp == "Resp");
}

#[tokio::test]
Expand Down Expand Up @@ -858,23 +891,6 @@ mod tests {
assert!(dispatch.as_mut().poll_next_request(cx).is_pending());
}

#[tokio::test]
async fn test_permit_before_transport_error() {
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
let (mut dispatch, mut channel, mut cx) = set_up_always_err(TransportError::Flush);
let (tx, mut rx) = oneshot::channel();
// reserve succeeds
let permit = reserve_for_send(&mut channel, tx, &mut rx).await;
// Since there's an outstanding permit, dispatch should not complete yet.
assert_matches!(dispatch.as_mut().poll(&mut cx), Poll::Pending);

let resp = permit("hi");

// errors from both the dispatch future and the request
assert_matches!(dispatch.as_mut().poll(&mut cx), Poll::Ready(Err(ChannelError::Flush(e))) if matches!(*e, TransportError::Flush));
assert_matches!(resp.response().await, Err(RpcError::Channel(_)));
}

#[tokio::test]
async fn test_shutdown() {
let _ = tracing_subscriber::fmt().with_test_writer().try_init();
Expand Down Expand Up @@ -960,14 +976,14 @@ mod tests {
fn set_up_always_err(
cause: TransportError,
) -> (
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>>>>,
Pin<Box<RequestDispatch<String, String, AlwaysErrorTransport<String>, DelayQueue<u64>>>>,
Channel<String, String>,
Context<'static>,
) {
let (to_dispatch, pending_requests) = mpsc::channel(1);
let (cancellation, canceled_requests) = cancellations();
let transport: AlwaysErrorTransport<String> = AlwaysErrorTransport(cause, PhantomData);
let dispatch = Box::pin(RequestDispatch::<String, String, _> {
let dispatch = Box::pin(RequestDispatch::<String, String, _, _> {
transport: transport.fuse(),
pending_requests,
canceled_requests,
Expand Down Expand Up @@ -1051,6 +1067,7 @@ mod tests {
String,
String,
UnboundedChannel<Response<String>, ClientMessage<String>>,
DelayQueue<u64>,
>,
>,
>,
Expand All @@ -1063,7 +1080,7 @@ mod tests {
let (cancellation, canceled_requests) = cancellations();
let (client_channel, server_channel) = transport::channel::unbounded();

let dispatch = RequestDispatch::<String, String, _> {
let dispatch = RequestDispatch::<String, String, _, _> {
transport: client_channel.fuse(),
pending_requests,
canceled_requests,
Expand All @@ -1081,32 +1098,6 @@ mod tests {
(Box::pin(dispatch), channel, server_channel)
}

async fn reserve_for_send<'a>(
channel: &'a mut Channel<String, String>,
response_completion: oneshot::Sender<Result<String, RpcError>>,
response: &'a mut oneshot::Receiver<Result<String, RpcError>>,
) -> impl FnOnce(&str) -> ResponseGuard<'a, String> {
let permit = channel.to_dispatch.reserve().await.unwrap();
|request| {
let request_id =
u64::try_from(channel.next_request_id.fetch_add(1, Ordering::Relaxed)).unwrap();
let request = DispatchRequest {
ctx: context::current(),
span: Span::current(),
request_id,
request: request.to_string(),
response_completion,
};
permit.send(request);
ResponseGuard {
response,
cancellation: &channel.cancellation,
request_id,
cancel: true,
}
}
}

async fn send_request<'a>(
channel: &'a mut Channel<String, String>,
request: &str,
Expand Down
Loading