Skip to content

Commit eebdd45

Browse files
Implement an easier way to get Send Serve and Stub
1 parent 33b0b21 commit eebdd45

File tree

6 files changed

+159
-17
lines changed

6 files changed

+159
-17
lines changed

plugins/src/lib.rs

+6-6
Original file line numberDiff line numberDiff line change
@@ -559,15 +559,15 @@ impl<'a> ServiceGenerator<'a> {
559559
)| {
560560
quote! {
561561
#( #attrs )*
562-
async fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> #output;
562+
fn #ident(self, context: ::tarpc::context::Context, #( #args ),*) -> impl ::core::future::Future<Output = #output> + ::core::marker::Send;
563563
}
564564
},
565565
);
566566

567567
let stub_doc = format!("The stub trait for service [`{service_ident}`].");
568568
quote! {
569569
#( #attrs )*
570-
#vis trait #service_ident: ::core::marker::Sized {
570+
#vis trait #service_ident: ::core::marker::Sized + ::core::marker::Send {
571571
#( #rpc_fns )*
572572

573573
/// Returns a serving function to use with
@@ -578,11 +578,11 @@ impl<'a> ServiceGenerator<'a> {
578578
}
579579

580580
#[doc = #stub_doc]
581-
#vis trait #client_stub_ident: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident> {
581+
#vis trait #client_stub_ident: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident> {
582582
}
583583

584584
impl<S> #client_stub_ident for S
585-
where S: ::tarpc::client::stub::Stub<Req = #request_ident, Resp = #response_ident>
585+
where S: ::tarpc::client::stub::SendStub<Req = #request_ident, Resp = #response_ident>
586586
{
587587
}
588588
}
@@ -616,7 +616,7 @@ impl<'a> ServiceGenerator<'a> {
616616
} = self;
617617

618618
quote! {
619-
impl<S> ::tarpc::server::Serve for #server_ident<S>
619+
impl<S> ::tarpc::server::SendServe for #server_ident<S>
620620
where S: #service_ident
621621
{
622622
type Req = #request_ident;
@@ -780,7 +780,7 @@ impl<'a> ServiceGenerator<'a> {
780780

781781
quote! {
782782
impl<Stub> #client_ident<Stub>
783-
where Stub: ::tarpc::client::stub::Stub<
783+
where Stub: ::tarpc::client::stub::SendStub<
784784
Req = #request_ident,
785785
Resp = #response_ident>
786786
{

tarpc/src/client/stub.rs

+51-4
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
//! Provides a Stub trait, implemented by types that can call remote services.
22
3+
use std::future::Future;
4+
35
use crate::{
46
client::{Channel, RpcError},
57
context,
6-
server::Serve,
8+
server::{SendServe, Serve},
79
RequestName,
810
};
911

@@ -15,7 +17,6 @@ mod mock;
1517

1618
/// A connection to a remote service.
1719
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
18-
#[allow(async_fn_in_trait)]
1920
pub trait Stub {
2021
/// The service request type.
2122
type Req: RequestName;
@@ -24,8 +25,28 @@ pub trait Stub {
2425
type Resp;
2526

2627
/// Calls a remote service.
27-
async fn call(&self, ctx: context::Context, request: Self::Req)
28-
-> Result<Self::Resp, RpcError>;
28+
fn call(
29+
&self,
30+
ctx: context::Context,
31+
request: Self::Req,
32+
) -> impl Future<Output = Result<Self::Resp, RpcError>>;
33+
}
34+
35+
/// A connection to a remote service.
36+
/// Calls the service with requests of type `Req` and receives responses of type `Resp`.
37+
pub trait SendStub: Send {
38+
/// The service request type.
39+
type Req: RequestName;
40+
41+
/// The service response type.
42+
type Resp;
43+
44+
/// Calls a remote service.
45+
fn call(
46+
&self,
47+
ctx: context::Context,
48+
request: Self::Req,
49+
) -> impl Future<Output = Result<Self::Resp, RpcError>> + Send;
2950
}
3051

3152
impl<Req, Resp> Stub for Channel<Req, Resp>
@@ -40,6 +61,19 @@ where
4061
}
4162
}
4263

64+
impl<Req, Resp> SendStub for Channel<Req, Resp>
65+
where
66+
Req: RequestName + Send,
67+
Resp: Send,
68+
{
69+
type Req = Req;
70+
type Resp = Resp;
71+
72+
async fn call(&self, ctx: context::Context, request: Req) -> Result<Self::Resp, RpcError> {
73+
Self::call(self, ctx, request).await
74+
}
75+
}
76+
4377
impl<S> Stub for S
4478
where
4579
S: Serve + Clone,
@@ -50,3 +84,16 @@ where
5084
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
5185
}
5286
}
87+
88+
impl<S> SendStub for S
89+
where
90+
S: SendServe + Clone + Sync,
91+
S::Req: Send + Sync,
92+
S::Resp: Send,
93+
{
94+
type Req = S::Req;
95+
type Resp = S::Resp;
96+
async fn call(&self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, RpcError> {
97+
self.clone().serve(ctx, req).await.map_err(RpcError::Server)
98+
}
99+
}

tarpc/src/client/stub/load_balance.rs

+18
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,24 @@ mod round_robin {
2828
}
2929
}
3030

31+
impl<Stub> stub::SendStub for RoundRobin<Stub>
32+
where
33+
Stub: stub::SendStub + Send + Sync,
34+
Stub::Req: Send,
35+
{
36+
type Req = Stub::Req;
37+
type Resp = Stub::Resp;
38+
39+
async fn call(
40+
&self,
41+
ctx: context::Context,
42+
request: Self::Req,
43+
) -> Result<Stub::Resp, RpcError> {
44+
let next = self.stubs.next();
45+
next.call(ctx, request).await
46+
}
47+
}
48+
3149
/// A Stub that load-balances across backing stubs by round robin.
3250
#[derive(Clone, Debug)]
3351
pub struct RoundRobin<Stub> {

tarpc/src/client/stub/mock.rs

+26-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,8 @@
11
use crate::{
2-
client::{stub::Stub, RpcError},
2+
client::{
3+
stub::{SendStub, Stub},
4+
RpcError,
5+
},
36
context, RequestName, ServerError,
47
};
58
use std::{collections::HashMap, hash::Hash, io};
@@ -42,3 +45,25 @@ where
4245
})
4346
}
4447
}
48+
49+
impl<Req, Resp> SendStub for Mock<Req, Resp>
50+
where
51+
Req: Eq + Hash + RequestName + Send + Sync,
52+
Resp: Clone + Send + Sync,
53+
{
54+
type Req = Req;
55+
type Resp = Resp;
56+
57+
async fn call(&self, _: context::Context, request: Self::Req) -> Result<Resp, RpcError> {
58+
self.responses
59+
.get(&request)
60+
.cloned()
61+
.map(Ok)
62+
.unwrap_or_else(|| {
63+
Err(RpcError::Server(ServerError {
64+
kind: io::ErrorKind::NotFound,
65+
detail: "mock (request, response) entry not found".into(),
66+
}))
67+
})
68+
}
69+
}

tarpc/src/client/stub/retry.rs

+27
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,33 @@ where
3333
}
3434
}
3535

36+
impl<Stub, Req, F> stub::SendStub for Retry<F, Stub>
37+
where
38+
Req: RequestName + Send + Sync,
39+
Stub: stub::SendStub<Req = Arc<Req>> + Send + Sync,
40+
F: Fn(&Result<Stub::Resp, RpcError>, u32) -> bool + Send + Sync,
41+
{
42+
type Req = Req;
43+
type Resp = Stub::Resp;
44+
45+
async fn call(
46+
&self,
47+
ctx: context::Context,
48+
request: Self::Req,
49+
) -> Result<Stub::Resp, RpcError> {
50+
let request = Arc::new(request);
51+
for i in 1.. {
52+
let result = self.stub.call(ctx, Arc::clone(&request)).await;
53+
if (self.should_retry)(&result, i) {
54+
tracing::trace!("Retrying on attempt {i}");
55+
continue;
56+
}
57+
return result;
58+
}
59+
unreachable!("Wow, that was a lot of attempts!");
60+
}
61+
}
62+
3663
/// A Stub that retries requests based on response contents.
3764
/// Note: to use this stub with Serde serialization, the "rc" feature of Serde needs to be enabled.
3865
#[derive(Clone, Debug)]

tarpc/src/server.rs

+31-6
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ impl Config {
6767
}
6868

6969
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
70-
#[allow(async_fn_in_trait)]
7170
pub trait Serve {
7271
/// Type of request.
7372
type Req: RequestName;
@@ -76,7 +75,33 @@ pub trait Serve {
7675
type Resp;
7776

7877
/// Responds to a single request.
79-
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError>;
78+
fn serve(
79+
self,
80+
ctx: context::Context,
81+
req: Self::Req,
82+
) -> impl Future<Output = Result<Self::Resp, ServerError>>;
83+
}
84+
85+
/// Equivalent to a `FnOnce(Req) -> impl Future<Output = Resp>`.
86+
pub trait SendServe: Send {
87+
/// Type of request.
88+
type Req: RequestName;
89+
/// Type of response.
90+
type Resp;
91+
/// Responds to a single request.
92+
fn serve(
93+
self,
94+
ctx: context::Context,
95+
req: Self::Req,
96+
) -> impl Future<Output = Result<Self::Resp, ServerError>> + Send;
97+
}
98+
99+
impl<S: SendServe> Serve for S {
100+
type Req = <Self as SendServe>::Req;
101+
type Resp = <Self as SendServe>::Resp;
102+
async fn serve(self, ctx: context::Context, req: Self::Req) -> Result<Self::Resp, ServerError> {
103+
<Self as SendServe>::serve(self, ctx, req).await
104+
}
80105
}
81106

82107
/// A Serve wrapper around a Fn.
@@ -113,11 +138,11 @@ where
113138
}
114139
}
115140

116-
impl<Req, Resp, Fut, F> Serve for ServeFn<Req, Resp, F>
141+
impl<Req, Resp, Fut, F> SendServe for ServeFn<Req, Resp, F>
117142
where
118-
Req: RequestName,
119-
F: FnOnce(context::Context, Req) -> Fut,
120-
Fut: Future<Output = Result<Resp, ServerError>>,
143+
Req: RequestName + Send,
144+
F: FnOnce(context::Context, Req) -> Fut + Send,
145+
Fut: Future<Output = Result<Resp, ServerError>> + Send,
121146
{
122147
type Req = Req;
123148
type Resp = Resp;

0 commit comments

Comments
 (0)