diff --git a/crates/goose-server/src/routes/prompts.rs b/crates/goose-server/src/routes/prompts.rs index 1eae29651..8d3ebbc99 100644 --- a/crates/goose-server/src/routes/prompts.rs +++ b/crates/goose-server/src/routes/prompts.rs @@ -1,27 +1,31 @@ +use mcp_core::protocol::{GetPromptResult, ListPromptsResult}; + use crate::state::AppState; use axum::{ - extract::{State, Path}, + extract::{Path, State}, http::{HeaderMap, StatusCode}, routing::post, Json, Router, }; -use serde::Serialize; +use serde::{Deserialize, Serialize}; -#[derive(Serialize)] -struct ListPromptsResponse { - prompts: Vec, +#[derive(Serialize, Deserialize)] +struct Prompt { + name: String, + description: Option, + required: Option, } -#[derive(Serialize)] -struct GetPromptResponse { - name: String, - content: String, +#[derive(Serialize, Deserialize)] +struct PromptRequest { + system: String, } -async fn list_prompts( +async fn list_prompts_handler( State(state): State, headers: HeaderMap, -) -> Result, StatusCode> { + Json(request): Json, +) -> Result, StatusCode> { // Verify secret key let secret_key = headers .get("X-Secret-Key") @@ -34,26 +38,25 @@ async fn list_prompts( let agent = state.agent.lock().await; let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; - + // Get prompts through agent passthrough let result = agent - .passthrough("prompts", serde_json::json!({ "method": "list" })) + .passthrough(&request.system, "list_prompts", serde_json::json!({})) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - let prompts = result - .as_array() - .map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect()) - .unwrap_or_default(); - Ok(Json(ListPromptsResponse { prompts })) + // Deserialize the result to ListPromptsResult + let prompts_result: ListPromptsResult = + serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(prompts_result)) } -async fn get_prompt( +async fn get_prompt_handler( + Path(prompt_name): Path, State(state): State, headers: HeaderMap, - Path(prompt_name): Path, -) -> Result, StatusCode> { + Json(payload): Json, +) -> Result, StatusCode> { // Verify secret key let secret_key = headers .get("X-Secret-Key") @@ -66,33 +69,28 @@ async fn get_prompt( let agent = state.agent.lock().await; let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; - + // Get prompt through agent passthrough let result = agent .passthrough( - "prompts", + &payload.system, + "get_prompt", serde_json::json!({ - "method": "get", "name": prompt_name - }) + }), ) .await .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; - - let content = result - .as_str() - .map(String::from) - .ok_or(StatusCode::INTERNAL_SERVER_ERROR)?; - Ok(Json(GetPromptResponse { - name: prompt_name, - content, - })) + // Deserialize the result to GetPromptResult + let prompt_result: GetPromptResult = + serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + Ok(Json(prompt_result)) } pub fn routes(state: AppState) -> Router { Router::new() - .route("/prompts/list", post(list_prompts)) - .route("/prompts/get/:prompt_name", post(get_prompt)) + .route("/prompts/list", post(list_prompts_handler)) + .route("/prompts/get/:prompt_name", post(get_prompt_handler)) .with_state(state) } diff --git a/crates/goose/src/agents/agent.rs b/crates/goose/src/agents/agent.rs index a735c5784..82898c6e0 100644 --- a/crates/goose/src/agents/agent.rs +++ b/crates/goose/src/agents/agent.rs @@ -24,7 +24,7 @@ pub trait Agent: Send + Sync { async fn list_systems(&self) -> Vec; /// Pass through a JSON-RPC request to a specific system - async fn passthrough(&self, system: &str, request: Value) -> SystemResult; + async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult; /// Get the total usage of the agent async fn usage(&self) -> Vec; diff --git a/crates/goose/src/agents/capabilities.rs b/crates/goose/src/agents/capabilities.rs index 9d22bf33f..058413df9 100644 --- a/crates/goose/src/agents/capabilities.rs +++ b/crates/goose/src/agents/capabilities.rs @@ -316,9 +316,8 @@ impl Capabilities { } /// Retrieve a client by name, if it exists - pub async fn system(&self, system: &str) -> Option>>> { + pub async fn get_system(&self, system: &str) -> Option>>> { // Use the clients hashmap to get the client associated with the system name self.clients.get(system).cloned() } } - diff --git a/crates/goose/src/agents/default.rs b/crates/goose/src/agents/default.rs index 6215c7773..749777bbe 100644 --- a/crates/goose/src/agents/default.rs +++ b/crates/goose/src/agents/default.rs @@ -6,7 +6,7 @@ use tracing::{debug, instrument}; use super::Agent; use crate::agents::capabilities::{Capabilities, ResourceItem}; -use crate::agents::system::{SystemConfig, SystemResult}; +use crate::agents::system::{SystemConfig, SystemError, SystemResult}; use crate::message::{Message, MessageContent, ToolRequest}; use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; @@ -147,11 +147,20 @@ impl Agent for DefaultAgent { .expect("Failed to list systems") } - async fn passthrough(&self, system: &str, request: Value) -> SystemResult { + async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult { let capabilities = self.capabilities.lock().await; - - // Get the client/system that corresponds to the given system string - let system = capabilities.get_system(system).await?; + let client = capabilities + .get_system(system) + .await + .unwrap_or_else(|| panic!("System not found: {}", system)); + let client = client.lock().await; + + let result: Value = client + .forward_request(method, params) + .await + .map_err(SystemError::Client)?; + + Ok(result) } #[instrument(skip(self, messages), fields(user_message))] diff --git a/crates/goose/src/agents/reference.rs b/crates/goose/src/agents/reference.rs index f5cfa2afa..696c3848b 100644 --- a/crates/goose/src/agents/reference.rs +++ b/crates/goose/src/agents/reference.rs @@ -7,7 +7,7 @@ use tracing::{debug, instrument}; use super::Agent; use crate::agents::capabilities::Capabilities; -use crate::agents::system::{SystemConfig, SystemResult}; +use crate::agents::system::{SystemConfig, SystemError, SystemResult}; use crate::message::{Message, ToolRequest}; use crate::providers::base::Provider; use crate::providers::base::ProviderUsage; @@ -52,21 +52,20 @@ impl Agent for ReferenceAgent { .expect("Failed to list systems") } - async fn passthrough(&self, system: &str, request: Value) -> SystemResult { + async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult { let capabilities = self.capabilities.lock().await; - - // Get the client by name let client = capabilities - .get_client_for_tool(&format!("{}__dummy", system)) - .ok_or_else(|| SystemError::NotFound(system.to_string()))?; - - let client_guard = client.lock().await; - - // Forward the request to the client's execute endpoint - client_guard - .execute(request) + .get_system(system) .await - .map_err(|e| SystemError::ExecutionError(e.to_string())) + .unwrap_or_else(|| panic!("System not found: {}", system)); + let client = client.lock().await; + + let result: Value = client + .forward_request(method, params) + .await + .map_err(SystemError::Client)?; + + Ok(result) } #[instrument(skip(self, messages), fields(user_message))] diff --git a/crates/mcp-client/src/client.rs b/crates/mcp-client/src/client.rs index f1d6b5acb..a9d1d7aad 100644 --- a/crates/mcp-client/src/client.rs +++ b/crates/mcp-client/src/client.rs @@ -1,6 +1,12 @@ -use mcp_core::{prompt::Prompt, protocol::{ - CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND -}}; +use mcp_core::{ + prompt::Prompt, + protocol::{ + CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, + JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, + ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, + METHOD_NOT_FOUND, + }, +}; use serde::{Deserialize, Serialize}; use serde_json::Value; use std::sync::atomic::{AtomicU64, Ordering}; @@ -96,6 +102,8 @@ pub trait McpClientTrait: Send + Sync { async fn list_prompts(&self) -> Result; async fn get_prompt(&self, name: &str, arguments: Value) -> Result; + + async fn forward_request(&self, method: &str, params: Value) -> Result; } /// The MCP client is the interface for MCP operations. @@ -352,10 +360,11 @@ where message: "Server does not support 'prompts' capability".to_string(), }); } - - self.send_request("prompts/list", serde_json::json!({})).await + + self.send_request("prompts/list", serde_json::json!({})) + .await } - + async fn get_prompt(&self, name: &str, arguments: Value) -> Result { if !self.completed_initialization() { return Err(Error::NotInitialized); @@ -369,9 +378,12 @@ where } let params = serde_json::json!({ "name": name, "arguments": arguments }); - // TODO ERROR: check that if there is an error, we send back is_error: true with msg + // TODO ERROR: check that if there is an error, we send back is_error: true with msg // https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2 self.send_request("prompts/get", params).await - + } + + async fn forward_request(&self, method: &str, params: Value) -> Result { + self.send_request(method, params).await } }