Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] MCP prompts #639

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions crates/goose-mcp/src/developer/prompts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub fn create_prompts() -> Vec<Prompt> {
.collect();

prompts.push(Prompt::new(&template.id, &template.template, arguments));
println!("Loaded prompt: {}", template.id);
}

prompts
Expand Down
2 changes: 2 additions & 0 deletions crates/goose-server/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Export route modules
pub mod agent;
pub mod health;
pub mod prompts;
pub mod reply;
pub mod secrets;
pub mod system;
Expand All @@ -14,5 +15,6 @@ pub fn configure(state: crate::state::AppState) -> Router {
.merge(reply::routes(state.clone()))
.merge(agent::routes(state.clone()))
.merge(system::routes(state.clone()))
.merge(prompts::routes(state.clone()))
.merge(secrets::routes(state))
}
105 changes: 105 additions & 0 deletions crates/goose-server/src/routes/prompts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
use mcp_core::protocol::{GetPromptResult, ListPromptsResult};

use crate::state::AppState;
use axum::{
extract::State,
http::{HeaderMap, StatusCode},
routing::post,
Json, Router,
};
use serde::{Deserialize, Serialize};
use serde_json::Value;

#[derive(Serialize, Deserialize)]
struct Prompt {
name: String,
description: Option<String>,
required: Option<bool>,
}

#[derive(Serialize, Deserialize)]
struct ListPromptRequest {
system: String
}

#[derive(Serialize, Deserialize)]
struct PromptRequest {
system: String,
name: String,
arguments: Value
}

async fn list_prompts_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<ListPromptRequest>,
) -> Result<Json<ListPromptsResult>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

let agent = state.agent.lock().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;

// Get prompts through agent passthrough
let result = agent
.passthrough(&request.system, "prompts/list", serde_json::json!({}))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

// Deserialize the result to ListPromptsResult
let prompts_result: ListPromptsResult =
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(Json(prompts_result))
}

async fn get_prompt_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<PromptRequest>,
) -> Result<Json<GetPromptResult>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

let agent = state.agent.lock().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;

// Get prompt through agent passthrough
let result = agent
.passthrough(
&request.system,
"prompts/get",
serde_json::json!({
"name": &request.name,
"arguments": &request.arguments,
}),
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

// Deserialize the result to GetPromptResult
let prompt_result: GetPromptResult =
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(prompt_result))
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/prompts/list", post(list_prompts_handler))
.route("/prompts/get", post(get_prompt_handler))
.with_state(state)
}
2 changes: 1 addition & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub trait Agent: Send + Sync {
async fn list_systems(&self) -> Vec<String>;

/// Pass through a JSON-RPC request to a specific system
async fn passthrough(&self, system: &str, request: Value) -> SystemResult<Value>;
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value>;

/// Get the total usage of the agent
async fn usage(&self) -> Vec<ProviderUsage>;
Expand Down
6 changes: 6 additions & 0 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,10 @@ impl Capabilities {

result
}

/// Retrieve a client by name, if it exists
pub async fn get_system(&self, system: &str) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: for clarity get_client to match get_client_for_tool?

// Use the clients hashmap to get the client associated with the system name
self.clients.get(system).cloned()
}
}
20 changes: 16 additions & 4 deletions crates/goose/src/agents/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::{Capabilities, ResourceItem};
use crate::agents::system::{SystemConfig, SystemResult};
use crate::agents::system::{SystemConfig, SystemError, SystemResult};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand Down Expand Up @@ -147,9 +147,21 @@ impl Agent for DefaultAgent {
.expect("Failed to list systems")
}

async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> {
// TODO implement
Ok(Value::Null)
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> {
println!("in pass through");
let capabilities = self.capabilities.lock().await;
let client = capabilities
.get_system(system)
.await
.unwrap_or_else(|| panic!("System not found: {}", system));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should return an Err (SystemError::UnknownSystem) here instead of panicing?

let client = client.lock().await;

let result: Value = client
.forward_request(method, params)
.await
.map_err(SystemError::Client)?;

Ok(result)
}

#[instrument(skip(self, messages), fields(user_message))]
Expand Down
21 changes: 17 additions & 4 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::system::{SystemConfig, SystemResult};
use crate::agents::system::{SystemConfig, SystemError, SystemResult};
use crate::message::{Message, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand Down Expand Up @@ -52,9 +52,22 @@ impl Agent for ReferenceAgent {
.expect("Failed to list systems")
}

async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> {
// TODO implement
Ok(Value::Null)
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meganit: maybe implement self.capabilities.passthrough helper - i think this code would be the same throughout

let capabilities = self.capabilities.lock().await;
tracing::info!("passthrough" = system, method);
let client = capabilities
.get_system(system)
.await
.unwrap_or_else(|| panic!("System not found: {}", system));
let client = client.lock().await;

tracing::info!("forwarding request" = method);
let result: Value = client
.forward_request(method, params)
.await
.map_err(SystemError::Client)?;

Ok(result)
}

#[instrument(skip(self, messages), fields(user_message))]
Expand Down
55 changes: 51 additions & 4 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use mcp_core::protocol::{
CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage,
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult,
ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities,
METHOD_NOT_FOUND
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -94,6 +95,12 @@ pub trait McpClientTrait: Send + Sync {
async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;

async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;

async fn list_prompts(&self) -> Result<ListPromptsResult, Error>;

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;

async fn forward_request(&self, method: &str, params: Value) -> Result<Value, Error>;
}

/// The MCP client is the interface for MCP operations.
Expand Down Expand Up @@ -129,6 +136,7 @@ where
where
R: for<'de> Deserialize<'de>,
{
println!("send_request: method={:?}, params={:?}", method, params);
let mut service = self.service.lock().await;
service.ready().await.map_err(|_| Error::NotReady)?;

Expand All @@ -149,7 +157,7 @@ where
params: params.clone(),
source: Box::new(e.into()),
})?;

println!("response_msg: {:?}", response_msg);
match response_msg {
JsonRpcMessage::Response(JsonRpcResponse {
id, result, error, ..
Expand Down Expand Up @@ -303,6 +311,7 @@ where
}

async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error> {
println!("list_tools");
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
Expand Down Expand Up @@ -339,4 +348,42 @@ where
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("tools/call", params).await
}

async fn list_prompts(&self) -> Result<ListPromptsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}

self.send_request("prompts/list", serde_json::json!({}))
.await
}

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}

if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}

let params = serde_json::json!({ "name": name, "arguments": arguments });
// TODO ERROR: check that if there is an error, we send back is_error: true with msg
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("prompts/get", params).await
}

async fn forward_request(&self, method: &str, params: Value) -> Result<Value, Error> {
println!("Forwarding request to server: method={}, params={}", method, params);
self.send_request(method, params).await
}
}
13 changes: 13 additions & 0 deletions crates/mcp-server/src/router.rs
Original file line number Diff line number Diff line change
Expand Up @@ -258,6 +258,7 @@ pub trait Router: Send + Sync + 'static {
req: JsonRpcRequest,
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
async move {
println!("handling prompts list");
let prompts = self.list_prompts().unwrap_or_default();

let result = ListPromptsResult { prompts };
Expand All @@ -277,6 +278,7 @@ pub trait Router: Send + Sync + 'static {
req: JsonRpcRequest,
) -> impl Future<Output = Result<JsonRpcResponse, RouterError>> + Send {
async move {
println!("handling prompts get");
// Validate and extract parameters
let params = req
.params
Expand All @@ -294,6 +296,9 @@ pub trait Router: Send + Sync + 'static {
.and_then(Value::as_object)
.ok_or_else(|| RouterError::InvalidParams("Missing arguments object".into()))?;

println!("extracted arguments: {:?}", arguments);
println!("extracted name {:?}", prompt_name);

// Fetch the prompt definition first
let prompt = match self.list_prompts() {
Some(prompts) => prompts
Expand All @@ -305,6 +310,8 @@ pub trait Router: Send + Sync + 'static {
None => return Err(RouterError::PromptNotFound("No prompts available".into())),
};

println!("extracted prompt {:?}", prompt);

// Validate required arguments
for arg in &prompt.arguments {
if arg.required
Expand All @@ -321,13 +328,17 @@ pub trait Router: Send + Sync + 'static {
}
}

println!("validated arguments");

// Now get the prompt content
let description = self
.get_prompt(prompt_name)
.ok_or_else(|| RouterError::PromptNotFound("Prompt not found".into()))?
.await
.map_err(|e| RouterError::Internal(e.to_string()))?;

println!("extracted description {:?}", description);

// Validate prompt arguments for potential security issues from user text input
// Checks:
// - Prompt must be less than 10000 total characters
Expand Down Expand Up @@ -361,6 +372,8 @@ pub trait Router: Send + Sync + 'static {
}
}

println!("validated arguments round 2");

// Validate the prompt description length
if description.len() > 10000 {
return Err(RouterError::Internal(
Expand Down