Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[feat] MCP prompts #639

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/goose-server/src/routes/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
// Export route modules
pub mod agent;
pub mod health;
pub mod prompts;
pub mod reply;
pub mod secrets;
pub mod system;
Expand All @@ -14,5 +15,6 @@ pub fn configure(state: crate::state::AppState) -> Router {
.merge(reply::routes(state.clone()))
.merge(agent::routes(state.clone()))
.merge(system::routes(state.clone()))
.merge(prompts::routes(state.clone()))
.merge(secrets::routes(state))
}
96 changes: 96 additions & 0 deletions crates/goose-server/src/routes/prompts.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
use mcp_core::protocol::{GetPromptResult, ListPromptsResult};

use crate::state::AppState;
use axum::{
extract::{Path, State},
http::{HeaderMap, StatusCode},
routing::post,
Json, Router,
};
use serde::{Deserialize, Serialize};

#[derive(Serialize, Deserialize)]
struct Prompt {
name: String,
description: Option<String>,
required: Option<bool>,
}

#[derive(Serialize, Deserialize)]
struct PromptRequest {
system: String,
}

async fn list_prompts_handler(
State(state): State<AppState>,
headers: HeaderMap,
Json(request): Json<PromptRequest>,
) -> Result<Json<ListPromptsResult>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

let agent = state.agent.lock().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;

// Get prompts through agent passthrough
let result = agent
.passthrough(&request.system, "list_prompts", serde_json::json!({}))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

// Deserialize the result to ListPromptsResult
let prompts_result: ListPromptsResult =
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(prompts_result))
}

async fn get_prompt_handler(
Path(prompt_name): Path<String>,
State(state): State<AppState>,
headers: HeaderMap,
Json(payload): Json<PromptRequest>,
) -> Result<Json<GetPromptResult>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
.and_then(|value| value.to_str().ok())
.ok_or(StatusCode::UNAUTHORIZED)?;

if secret_key != state.secret_key {
return Err(StatusCode::UNAUTHORIZED);
}

let agent = state.agent.lock().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;

// Get prompt through agent passthrough
let result = agent
.passthrough(
&payload.system,
"get_prompt",
serde_json::json!({
"name": prompt_name
}),
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

// Deserialize the result to GetPromptResult
let prompt_result: GetPromptResult =
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(prompt_result))
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/prompts/list", post(list_prompts_handler))
.route("/prompts/get/:prompt_name", post(get_prompt_handler))
.with_state(state)
}
2 changes: 1 addition & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub trait Agent: Send + Sync {
async fn list_systems(&self) -> Vec<String>;

/// Pass through a JSON-RPC request to a specific system
async fn passthrough(&self, system: &str, request: Value) -> SystemResult<Value>;
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value>;

/// Get the total usage of the agent
async fn usage(&self) -> Vec<ProviderUsage>;
Expand Down
6 changes: 6 additions & 0 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -314,4 +314,10 @@ impl Capabilities {

result
}

/// Retrieve a client by name, if it exists
pub async fn get_system(&self, system: &str) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: for clarity get_client to match get_client_for_tool?

// Use the clients hashmap to get the client associated with the system name
self.clients.get(system).cloned()
}
}
19 changes: 15 additions & 4 deletions crates/goose/src/agents/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::{Capabilities, ResourceItem};
use crate::agents::system::{SystemConfig, SystemResult};
use crate::agents::system::{SystemConfig, SystemError, SystemResult};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand Down Expand Up @@ -147,9 +147,20 @@ impl Agent for DefaultAgent {
.expect("Failed to list systems")
}

async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> {
// TODO implement
Ok(Value::Null)
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> {
let capabilities = self.capabilities.lock().await;
let client = capabilities
.get_system(system)
.await
.unwrap_or_else(|| panic!("System not found: {}", system));
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should return an Err (SystemError::UnknownSystem) here instead of panicing?

let client = client.lock().await;

let result: Value = client
.forward_request(method, params)
.await
.map_err(SystemError::Client)?;

Ok(result)
}

#[instrument(skip(self, messages), fields(user_message))]
Expand Down
19 changes: 15 additions & 4 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::system::{SystemConfig, SystemResult};
use crate::agents::system::{SystemConfig, SystemError, SystemResult};
use crate::message::{Message, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand Down Expand Up @@ -52,9 +52,20 @@ impl Agent for ReferenceAgent {
.expect("Failed to list systems")
}

async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> {
// TODO implement
Ok(Value::Null)
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

meganit: maybe implement self.capabilities.passthrough helper - i think this code would be the same throughout

let capabilities = self.capabilities.lock().await;
let client = capabilities
.get_system(system)
.await
.unwrap_or_else(|| panic!("System not found: {}", system));
let client = client.lock().await;

let result: Value = client
.forward_request(method, params)
.await
.map_err(SystemError::Client)?;

Ok(result)
}

#[instrument(skip(self, messages), fields(user_message))]
Expand Down
57 changes: 52 additions & 5 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
use mcp_core::protocol::{
CallToolResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage,
JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListResourcesResult, ListToolsResult,
ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND,
use mcp_core::{
prompt::Prompt,

Check failure on line 2 in crates/mcp-client/src/client.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `prompt::Prompt`
protocol::{
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities,
METHOD_NOT_FOUND,
},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
Expand Down Expand Up @@ -94,6 +98,12 @@
async fn list_tools(&self, next_cursor: Option<String>) -> Result<ListToolsResult, Error>;

async fn call_tool(&self, name: &str, arguments: Value) -> Result<CallToolResult, Error>;

async fn list_prompts(&self) -> Result<ListPromptsResult, Error>;

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;

async fn forward_request(&self, method: &str, params: Value) -> Result<Value, Error>;
}

/// The MCP client is the interface for MCP operations.
Expand Down Expand Up @@ -125,7 +135,7 @@
}

/// Send a JSON-RPC request and check we don't get an error response.
async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
pub async fn send_request<R>(&self, method: &str, params: Value) -> Result<R, Error>
where
R: for<'de> Deserialize<'de>,
{
Expand Down Expand Up @@ -339,4 +349,41 @@
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("tools/call", params).await
}

async fn list_prompts(&self) -> Result<ListPromptsResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}
if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}

self.send_request("prompts/list", serde_json::json!({}))
.await
}

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
}

if self.server_capabilities.as_ref().unwrap().prompts.is_none() {
return Err(Error::RpcError {
code: METHOD_NOT_FOUND,
message: "Server does not support 'prompts' capability".to_string(),
});
}

let params = serde_json::json!({ "name": name, "arguments": arguments });
// TODO ERROR: check that if there is an error, we send back is_error: true with msg
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("prompts/get", params).await
}

async fn forward_request(&self, method: &str, params: Value) -> Result<Value, Error> {
self.send_request(method, params).await
}
}
Loading