Skip to content

Commit 88fc0d8

Browse files
committed
feat: add progress notification handling and related structures
also fix some lifetime bugs of tool macro
1 parent 893969a commit 88fc0d8

File tree

10 files changed

+300
-18
lines changed

10 files changed

+300
-18
lines changed

crates/rmcp-macros/src/tool.rs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -207,12 +207,22 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
207207
// 2. make return type: `std::pin::Pin<Box<dyn Future<Output = #ReturnType> + Send + '_>>`
208208
// 3. make body: { Box::pin(async move { #body }) }
209209
let new_output = syn::parse2::<ReturnType>({
210+
let mut lt = quote! { 'static };
211+
if let Some(receiver) = fn_item.sig.receiver() {
212+
if let Some((_, receiver_lt)) = receiver.reference.as_ref() {
213+
if let Some(receiver_lt) = receiver_lt {
214+
lt = quote! { #receiver_lt };
215+
} else {
216+
lt = quote! { '_ };
217+
}
218+
}
219+
}
210220
match &fn_item.sig.output {
211221
syn::ReturnType::Default => {
212-
quote! { -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + '_>> }
222+
quote! { -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + #lt>> }
213223
}
214224
syn::ReturnType::Type(_, ty) => {
215-
quote! { -> std::pin::Pin<Box<dyn Future<Output = #ty> + Send + '_>> }
225+
quote! { -> std::pin::Pin<Box<dyn Future<Output = #ty> + Send + #lt>> }
216226
}
217227
}
218228
})?;

crates/rmcp/Cargo.toml

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ chrono = { version = "0.4.38", default-features = false, features = ["serde", "c
7171

7272
[features]
7373
default = ["base64", "macros", "server"]
74-
client = []
74+
client = ["dep:tokio-stream"]
7575
server = ["transport-async-rw", "dep:schemars"]
7676
macros = ["dep:rmcp-macros", "dep:paste"]
7777

@@ -191,3 +191,8 @@ path = "tests/test_message_protocol.rs"
191191
name = "test_message_schema"
192192
required-features = ["server", "client", "schemars"]
193193
path = "tests/test_message_schema.rs"
194+
195+
[[test]]
196+
name = "test_progress_subscriber"
197+
required-features = ["server", "client", "macros"]
198+
path = "tests/test_progress_subscriber.rs"

crates/rmcp/src/handler/client.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod progress;
12
use crate::{
23
error::Error as McpError,
34
model::*,
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
use std::{collections::HashMap, sync::Arc};
2+
3+
use futures::{Stream, StreamExt};
4+
use tokio::sync::RwLock;
5+
use tokio_stream::wrappers::ReceiverStream;
6+
7+
use crate::model::{ProgressNotificationParam, ProgressToken};
8+
type Dispatcher =
9+
Arc<RwLock<HashMap<ProgressToken, tokio::sync::mpsc::Sender<ProgressNotificationParam>>>>;
10+
11+
/// A dispatcher for progress notifications.
12+
#[derive(Debug, Clone, Default)]
13+
pub struct ProgressDispatcher {
14+
pub(crate) dispatcher: Dispatcher,
15+
}
16+
17+
impl ProgressDispatcher {
18+
const CHANNEL_SIZE: usize = 16;
19+
pub fn new() -> Self {
20+
Self::default()
21+
}
22+
23+
/// Handle a progress notification by sending it to the appropriate subscriber
24+
pub async fn handle_notification(&self, notification: ProgressNotificationParam) {
25+
let token = &notification.progress_token;
26+
if let Some(sender) = self.dispatcher.read().await.get(token).cloned() {
27+
let send_result = sender.send(notification).await;
28+
if let Err(e) = send_result {
29+
tracing::warn!("Failed to send progress notification: {e}");
30+
}
31+
}
32+
}
33+
34+
/// Subscribe to progress notifications for a specific token.
35+
///
36+
/// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token.
37+
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber {
38+
let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE);
39+
self.dispatcher
40+
.write()
41+
.await
42+
.insert(progress_token.clone(), sender);
43+
let receiver = ReceiverStream::new(receiver);
44+
ProgressSubscriber {
45+
progress_token,
46+
receiver,
47+
dispacher: self.dispatcher.clone(),
48+
}
49+
}
50+
51+
/// Unsubscribe from progress notifications for a specific token.
52+
pub async fn unsubscribe(&self, token: &ProgressToken) {
53+
self.dispatcher.write().await.remove(token);
54+
}
55+
56+
/// Clear all dispachter.
57+
pub async fn clear(&self) {
58+
let mut dispacher = self.dispatcher.write().await;
59+
dispacher.clear();
60+
}
61+
}
62+
63+
pub struct ProgressSubscriber {
64+
pub(crate) progress_token: ProgressToken,
65+
pub(crate) receiver: ReceiverStream<ProgressNotificationParam>,
66+
pub(crate) dispacher: Dispatcher,
67+
}
68+
69+
impl ProgressSubscriber {
70+
pub fn progress_token(&self) -> &ProgressToken {
71+
&self.progress_token
72+
}
73+
}
74+
75+
impl Stream for ProgressSubscriber {
76+
type Item = ProgressNotificationParam;
77+
78+
fn poll_next(
79+
mut self: std::pin::Pin<&mut Self>,
80+
cx: &mut std::task::Context<'_>,
81+
) -> std::task::Poll<Option<Self::Item>> {
82+
self.receiver.poll_next_unpin(cx)
83+
}
84+
85+
fn size_hint(&self) -> (usize, Option<usize>) {
86+
self.receiver.size_hint()
87+
}
88+
}
89+
90+
impl Drop for ProgressSubscriber {
91+
fn drop(&mut self) {
92+
let token = self.progress_token.clone();
93+
self.receiver.close();
94+
let dispatcher = self.dispacher.clone();
95+
tokio::spawn(async move {
96+
let mut dispacher = dispatcher.write_owned().await;
97+
dispacher.remove(&token);
98+
});
99+
}
100+
}

crates/rmcp/src/handler/server/tool.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,6 @@ pub trait FromToolCallContextPart<S>: Sized {
9999
pub trait IntoCallToolResult {
100100
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error>;
101101
}
102-
impl IntoCallToolResult for () {
103-
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
104-
Ok(CallToolResult::success(vec![]))
105-
}
106-
}
107102

108103
impl<T: IntoContents> IntoCallToolResult for T {
109104
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
@@ -120,6 +115,15 @@ impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
120115
}
121116
}
122117

118+
impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::Error> {
119+
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
120+
match self {
121+
Ok(value) => value.into_call_tool_result(),
122+
Err(error) => Err(error),
123+
}
124+
}
125+
}
126+
123127
pin_project_lite::pin_project! {
124128
#[project = IntoCallToolResultFutProj]
125129
pub enum IntoCallToolResultFut<F, R> {

crates/rmcp/src/model.rs

Lines changed: 35 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -242,7 +242,7 @@ pub type RequestId = NumberOrString;
242242
#[serde(transparent)]
243243
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
244244
pub struct ProgressToken(pub NumberOrString);
245-
#[derive(Debug, Clone)]
245+
#[derive(Debug, Clone, Default)]
246246
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
247247
pub struct Request<M = String, P = JsonObject> {
248248
pub method: M,
@@ -255,6 +255,16 @@ pub struct Request<M = String, P = JsonObject> {
255255
pub extensions: Extensions,
256256
}
257257

258+
impl<M: Default, P> Request<M, P> {
259+
pub fn new(params: P) -> Self {
260+
Self {
261+
method: Default::default(),
262+
params,
263+
extensions: Extensions::default(),
264+
}
265+
}
266+
}
267+
258268
impl<M, P> GetExtensions for Request<M, P> {
259269
fn extensions(&self) -> &Extensions {
260270
&self.extensions
@@ -264,7 +274,7 @@ impl<M, P> GetExtensions for Request<M, P> {
264274
}
265275
}
266276

267-
#[derive(Debug, Clone)]
277+
#[derive(Debug, Clone, Default)]
268278
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
269279
pub struct RequestOptionalParam<M = String, P = JsonObject> {
270280
pub method: M,
@@ -277,7 +287,17 @@ pub struct RequestOptionalParam<M = String, P = JsonObject> {
277287
pub extensions: Extensions,
278288
}
279289

280-
#[derive(Debug, Clone)]
290+
impl<M: Default, P> RequestOptionalParam<M, P> {
291+
pub fn with_param(params: P) -> Self {
292+
Self {
293+
method: Default::default(),
294+
params: Some(params),
295+
extensions: Extensions::default(),
296+
}
297+
}
298+
}
299+
300+
#[derive(Debug, Clone, Default)]
281301
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
282302
pub struct RequestNoParam<M = String> {
283303
pub method: M,
@@ -296,7 +316,7 @@ impl<M> GetExtensions for RequestNoParam<M> {
296316
&mut self.extensions
297317
}
298318
}
299-
#[derive(Debug, Clone)]
319+
#[derive(Debug, Clone, Default)]
300320
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
301321
pub struct Notification<M = String, P = JsonObject> {
302322
pub method: M,
@@ -308,7 +328,17 @@ pub struct Notification<M = String, P = JsonObject> {
308328
pub extensions: Extensions,
309329
}
310330

311-
#[derive(Debug, Clone)]
331+
impl<M: Default, P> Notification<M, P> {
332+
pub fn new(params: P) -> Self {
333+
Self {
334+
method: Default::default(),
335+
params,
336+
extensions: Extensions::default(),
337+
}
338+
}
339+
}
340+
341+
#[derive(Debug, Clone, Default)]
312342
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
313343
pub struct NotificationNoParam<M = String> {
314344
pub method: M,

crates/rmcp/src/model/content.rs

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,9 @@ impl IntoContents for String {
165165
vec![Content::text(self)]
166166
}
167167
}
168+
169+
impl IntoContents for () {
170+
fn into_contents(self) -> Vec<Content> {
171+
vec![]
172+
}
173+
}

crates/rmcp/src/service.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -745,8 +745,9 @@ where
745745
let mut extensions = Extensions::new();
746746
let mut meta = Meta::new();
747747
// avoid clone
748-
std::mem::swap(&mut extensions, request.extensions_mut());
748+
// swap meta firstly, otherwise progress token will be lost
749749
std::mem::swap(&mut meta, request.get_meta_mut());
750+
std::mem::swap(&mut extensions, request.extensions_mut());
750751
let context = RequestContext {
751752
ct: context_ct,
752753
id: id.clone(),

0 commit comments

Comments
 (0)