Skip to content

feat: add progress notification handling and related structures #282

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
Jun 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions crates/rmcp-macros/src/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -207,12 +207,22 @@ pub fn tool(attr: TokenStream, input: TokenStream) -> syn::Result<TokenStream> {
// 2. make return type: `std::pin::Pin<Box<dyn Future<Output = #ReturnType> + Send + '_>>`
// 3. make body: { Box::pin(async move { #body }) }
let new_output = syn::parse2::<ReturnType>({
let mut lt = quote! { 'static };
if let Some(receiver) = fn_item.sig.receiver() {
if let Some((_, receiver_lt)) = receiver.reference.as_ref() {
if let Some(receiver_lt) = receiver_lt {
lt = quote! { #receiver_lt };
} else {
lt = quote! { '_ };
}
}
}
match &fn_item.sig.output {
syn::ReturnType::Default => {
quote! { -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + '_>> }
quote! { -> std::pin::Pin<Box<dyn Future<Output = ()> + Send + #lt>> }
}
syn::ReturnType::Type(_, ty) => {
quote! { -> std::pin::Pin<Box<dyn Future<Output = #ty> + Send + '_>> }
quote! { -> std::pin::Pin<Box<dyn Future<Output = #ty> + Send + #lt>> }
}
}
})?;
Expand Down
7 changes: 6 additions & 1 deletion crates/rmcp/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ chrono = { version = "0.4.38", default-features = false, features = ["serde", "c

[features]
default = ["base64", "macros", "server"]
client = []
client = ["dep:tokio-stream"]
server = ["transport-async-rw", "dep:schemars"]
macros = ["dep:rmcp-macros", "dep:paste"]

Expand Down Expand Up @@ -191,3 +191,8 @@ path = "tests/test_message_protocol.rs"
name = "test_message_schema"
required-features = ["server", "client", "schemars"]
path = "tests/test_message_schema.rs"

[[test]]
name = "test_progress_subscriber"
required-features = ["server", "client", "macros"]
path = "tests/test_progress_subscriber.rs"
1 change: 1 addition & 0 deletions crates/rmcp/src/handler/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub mod progress;
use crate::{
error::Error as McpError,
model::*,
Expand Down
100 changes: 100 additions & 0 deletions crates/rmcp/src/handler/client/progress.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
use std::{collections::HashMap, sync::Arc};

use futures::{Stream, StreamExt};
use tokio::sync::RwLock;
use tokio_stream::wrappers::ReceiverStream;

use crate::model::{ProgressNotificationParam, ProgressToken};
type Dispatcher =
Arc<RwLock<HashMap<ProgressToken, tokio::sync::mpsc::Sender<ProgressNotificationParam>>>>;

/// A dispatcher for progress notifications.
#[derive(Debug, Clone, Default)]
pub struct ProgressDispatcher {
pub(crate) dispatcher: Dispatcher,
}

impl ProgressDispatcher {
const CHANNEL_SIZE: usize = 16;
pub fn new() -> Self {
Self::default()
}

/// Handle a progress notification by sending it to the appropriate subscriber
pub async fn handle_notification(&self, notification: ProgressNotificationParam) {
let token = &notification.progress_token;
if let Some(sender) = self.dispatcher.read().await.get(token).cloned() {
let send_result = sender.send(notification).await;
if let Err(e) = send_result {
tracing::warn!("Failed to send progress notification: {e}");
}
}
}

/// Subscribe to progress notifications for a specific token.
///
/// If you drop the returned `ProgressSubscriber`, it will automatically unsubscribe from notifications for that token.
pub async fn subscribe(&self, progress_token: ProgressToken) -> ProgressSubscriber {
let (sender, receiver) = tokio::sync::mpsc::channel(Self::CHANNEL_SIZE);
self.dispatcher
.write()
.await
.insert(progress_token.clone(), sender);
let receiver = ReceiverStream::new(receiver);
ProgressSubscriber {
progress_token,
receiver,
dispacher: self.dispatcher.clone(),
}
}

/// Unsubscribe from progress notifications for a specific token.
pub async fn unsubscribe(&self, token: &ProgressToken) {
self.dispatcher.write().await.remove(token);
}

/// Clear all dispachter.
pub async fn clear(&self) {
let mut dispacher = self.dispatcher.write().await;
dispacher.clear();
}
}

pub struct ProgressSubscriber {
pub(crate) progress_token: ProgressToken,
pub(crate) receiver: ReceiverStream<ProgressNotificationParam>,
pub(crate) dispacher: Dispatcher,
}

impl ProgressSubscriber {
pub fn progress_token(&self) -> &ProgressToken {
&self.progress_token
}
}

impl Stream for ProgressSubscriber {
type Item = ProgressNotificationParam;

fn poll_next(
mut self: std::pin::Pin<&mut Self>,
cx: &mut std::task::Context<'_>,
) -> std::task::Poll<Option<Self::Item>> {
self.receiver.poll_next_unpin(cx)
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.receiver.size_hint()
}
}

impl Drop for ProgressSubscriber {
fn drop(&mut self) {
let token = self.progress_token.clone();
self.receiver.close();
let dispatcher = self.dispacher.clone();
tokio::spawn(async move {
let mut dispacher = dispatcher.write_owned().await;
dispacher.remove(&token);
});
}
}
14 changes: 9 additions & 5 deletions crates/rmcp/src/handler/server/tool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -99,11 +99,6 @@ pub trait FromToolCallContextPart<S>: Sized {
pub trait IntoCallToolResult {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error>;
}
impl IntoCallToolResult for () {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
Ok(CallToolResult::success(vec![]))
}
}

impl<T: IntoContents> IntoCallToolResult for T {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
Expand All @@ -120,6 +115,15 @@ impl<T: IntoContents, E: IntoContents> IntoCallToolResult for Result<T, E> {
}
}

impl<T: IntoCallToolResult> IntoCallToolResult for Result<T, crate::Error> {
fn into_call_tool_result(self) -> Result<CallToolResult, crate::Error> {
match self {
Ok(value) => value.into_call_tool_result(),
Err(error) => Err(error),
}
}
}

pin_project_lite::pin_project! {
#[project = IntoCallToolResultFutProj]
pub enum IntoCallToolResultFut<F, R> {
Expand Down
40 changes: 35 additions & 5 deletions crates/rmcp/src/model.rs
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ pub type RequestId = NumberOrString;
#[serde(transparent)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct ProgressToken(pub NumberOrString);
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Request<M = String, P = JsonObject> {
pub method: M,
Expand All @@ -255,6 +255,16 @@ pub struct Request<M = String, P = JsonObject> {
pub extensions: Extensions,
}

impl<M: Default, P> Request<M, P> {
pub fn new(params: P) -> Self {
Self {
method: Default::default(),
params,
extensions: Extensions::default(),
}
}
}

impl<M, P> GetExtensions for Request<M, P> {
fn extensions(&self) -> &Extensions {
&self.extensions
Expand All @@ -264,7 +274,7 @@ impl<M, P> GetExtensions for Request<M, P> {
}
}

#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct RequestOptionalParam<M = String, P = JsonObject> {
pub method: M,
Expand All @@ -277,7 +287,17 @@ pub struct RequestOptionalParam<M = String, P = JsonObject> {
pub extensions: Extensions,
}

#[derive(Debug, Clone)]
impl<M: Default, P> RequestOptionalParam<M, P> {
pub fn with_param(params: P) -> Self {
Self {
method: Default::default(),
params: Some(params),
extensions: Extensions::default(),
}
}
}

#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct RequestNoParam<M = String> {
pub method: M,
Expand All @@ -296,7 +316,7 @@ impl<M> GetExtensions for RequestNoParam<M> {
&mut self.extensions
}
}
#[derive(Debug, Clone)]
#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct Notification<M = String, P = JsonObject> {
pub method: M,
Expand All @@ -308,7 +328,17 @@ pub struct Notification<M = String, P = JsonObject> {
pub extensions: Extensions,
}

#[derive(Debug, Clone)]
impl<M: Default, P> Notification<M, P> {
pub fn new(params: P) -> Self {
Self {
method: Default::default(),
params,
extensions: Extensions::default(),
}
}
}

#[derive(Debug, Clone, Default)]
#[cfg_attr(feature = "schemars", derive(schemars::JsonSchema))]
pub struct NotificationNoParam<M = String> {
pub method: M,
Expand Down
6 changes: 6 additions & 0 deletions crates/rmcp/src/model/content.rs
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,9 @@ impl IntoContents for String {
vec![Content::text(self)]
}
}

impl IntoContents for () {
fn into_contents(self) -> Vec<Content> {
vec![]
}
}
3 changes: 2 additions & 1 deletion crates/rmcp/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -745,8 +745,9 @@ where
let mut extensions = Extensions::new();
let mut meta = Meta::new();
// avoid clone
std::mem::swap(&mut extensions, request.extensions_mut());
// swap meta firstly, otherwise progress token will be lost
std::mem::swap(&mut meta, request.get_meta_mut());
std::mem::swap(&mut extensions, request.extensions_mut());
let context = RequestContext {
ct: context_ct,
id: id.clone(),
Expand Down
Loading
Loading