Skip to content

Commit

Permalink
first pass, still need to test
Browse files Browse the repository at this point in the history
  • Loading branch information
ahau-square committed Jan 17, 2025
1 parent 1b5442d commit a64bef2
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 66 deletions.
72 changes: 35 additions & 37 deletions crates/goose-server/src/routes/prompts.rs
Original file line number Diff line number Diff line change
@@ -1,27 +1,31 @@
use mcp_core::protocol::{GetPromptResult, ListPromptsResult};

use crate::state::AppState;
use axum::{
extract::{State, Path},
extract::{Path, State},
http::{HeaderMap, StatusCode},
routing::post,
Json, Router,
};
use serde::Serialize;
use serde::{Deserialize, Serialize};

#[derive(Serialize)]
struct ListPromptsResponse {
prompts: Vec<String>,
#[derive(Serialize, Deserialize)]
struct Prompt {
name: String,
description: Option<String>,
required: Option<bool>,
}

#[derive(Serialize)]
struct GetPromptResponse {
name: String,
content: String,
#[derive(Serialize, Deserialize)]
struct PromptRequest {
system: String,
}

async fn list_prompts(
async fn list_prompts_handler(
State(state): State<AppState>,
headers: HeaderMap,
) -> Result<Json<ListPromptsResponse>, StatusCode> {
Json(request): Json<PromptRequest>,
) -> Result<Json<ListPromptsResult>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
Expand All @@ -34,26 +38,25 @@ async fn list_prompts(

let agent = state.agent.lock().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;

// Get prompts through agent passthrough
let result = agent
.passthrough("prompts", serde_json::json!({ "method": "list" }))
.passthrough(&request.system, "list_prompts", serde_json::json!({}))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let prompts = result
.as_array()
.map(|arr| arr.iter().filter_map(|v| v.as_str().map(String::from)).collect())
.unwrap_or_default();

Ok(Json(ListPromptsResponse { prompts }))
// Deserialize the result to ListPromptsResult
let prompts_result: ListPromptsResult =
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(prompts_result))
}

async fn get_prompt(
async fn get_prompt_handler(
Path(prompt_name): Path<String>,
State(state): State<AppState>,
headers: HeaderMap,
Path(prompt_name): Path<String>,
) -> Result<Json<GetPromptResponse>, StatusCode> {
Json(payload): Json<PromptRequest>,
) -> Result<Json<GetPromptResult>, StatusCode> {
// Verify secret key
let secret_key = headers
.get("X-Secret-Key")
Expand All @@ -66,33 +69,28 @@ async fn get_prompt(

let agent = state.agent.lock().await;
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?;

// Get prompt through agent passthrough
let result = agent
.passthrough(
"prompts",
&payload.system,
"get_prompt",
serde_json::json!({
"method": "get",
"name": prompt_name
})
}),
)
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;

let content = result
.as_str()
.map(String::from)
.ok_or(StatusCode::INTERNAL_SERVER_ERROR)?;

Ok(Json(GetPromptResponse {
name: prompt_name,
content,
}))
// Deserialize the result to GetPromptResult
let prompt_result: GetPromptResult =
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
Ok(Json(prompt_result))
}

pub fn routes(state: AppState) -> Router {
Router::new()
.route("/prompts/list", post(list_prompts))
.route("/prompts/get/:prompt_name", post(get_prompt))
.route("/prompts/list", post(list_prompts_handler))
.route("/prompts/get/:prompt_name", post(get_prompt_handler))
.with_state(state)
}
2 changes: 1 addition & 1 deletion crates/goose/src/agents/agent.rs
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ pub trait Agent: Send + Sync {
async fn list_systems(&self) -> Vec<String>;

/// Pass through a JSON-RPC request to a specific system
async fn passthrough(&self, system: &str, request: Value) -> SystemResult<Value>;
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value>;

/// Get the total usage of the agent
async fn usage(&self) -> Vec<ProviderUsage>;
Expand Down
3 changes: 1 addition & 2 deletions crates/goose/src/agents/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -316,9 +316,8 @@ impl Capabilities {
}

/// Retrieve a client by name, if it exists
pub async fn system(&self, system: &str) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
pub async fn get_system(&self, system: &str) -> Option<Arc<Mutex<Box<dyn McpClientTrait>>>> {
// Use the clients hashmap to get the client associated with the system name
self.clients.get(system).cloned()
}
}

19 changes: 14 additions & 5 deletions crates/goose/src/agents/default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::{Capabilities, ResourceItem};
use crate::agents::system::{SystemConfig, SystemResult};
use crate::agents::system::{SystemConfig, SystemError, SystemResult};
use crate::message::{Message, MessageContent, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand Down Expand Up @@ -147,11 +147,20 @@ impl Agent for DefaultAgent {
.expect("Failed to list systems")
}

async fn passthrough(&self, system: &str, request: Value) -> SystemResult<Value> {
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> {
let capabilities = self.capabilities.lock().await;

// Get the client/system that corresponds to the given system string
let system = capabilities.get_system(system).await?;
let client = capabilities
.get_system(system)
.await
.unwrap_or_else(|| panic!("System not found: {}", system));
let client = client.lock().await;

let result: Value = client
.forward_request(method, params)
.await
.map_err(SystemError::Client)?;

Ok(result)
}

#[instrument(skip(self, messages), fields(user_message))]
Expand Down
25 changes: 12 additions & 13 deletions crates/goose/src/agents/reference.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use tracing::{debug, instrument};

use super::Agent;
use crate::agents::capabilities::Capabilities;
use crate::agents::system::{SystemConfig, SystemResult};
use crate::agents::system::{SystemConfig, SystemError, SystemResult};
use crate::message::{Message, ToolRequest};
use crate::providers::base::Provider;
use crate::providers::base::ProviderUsage;
Expand Down Expand Up @@ -52,21 +52,20 @@ impl Agent for ReferenceAgent {
.expect("Failed to list systems")
}

async fn passthrough(&self, system: &str, request: Value) -> SystemResult<Value> {
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> {
let capabilities = self.capabilities.lock().await;

// Get the client by name
let client = capabilities
.get_client_for_tool(&format!("{}__dummy", system))
.ok_or_else(|| SystemError::NotFound(system.to_string()))?;

let client_guard = client.lock().await;

// Forward the request to the client's execute endpoint
client_guard
.execute(request)
.get_system(system)
.await
.map_err(|e| SystemError::ExecutionError(e.to_string()))
.unwrap_or_else(|| panic!("System not found: {}", system));
let client = client.lock().await;

let result: Value = client
.forward_request(method, params)
.await
.map_err(SystemError::Client)?;

Ok(result)
}

#[instrument(skip(self, messages), fields(user_message))]
Expand Down
28 changes: 20 additions & 8 deletions crates/mcp-client/src/client.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
use mcp_core::{prompt::Prompt, protocol::{
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError, JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult, ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities, METHOD_NOT_FOUND
}};
use mcp_core::{
prompt::Prompt,

Check failure on line 2 in crates/mcp-client/src/client.rs

View workflow job for this annotation

GitHub Actions / build

unused import: `prompt::Prompt`
protocol::{
CallToolResult, GetPromptResult, Implementation, InitializeResult, JsonRpcError,
JsonRpcMessage, JsonRpcNotification, JsonRpcRequest, JsonRpcResponse, ListPromptsResult,
ListResourcesResult, ListToolsResult, ReadResourceResult, ServerCapabilities,
METHOD_NOT_FOUND,
},
};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::sync::atomic::{AtomicU64, Ordering};
Expand Down Expand Up @@ -96,6 +102,8 @@ pub trait McpClientTrait: Send + Sync {
async fn list_prompts(&self) -> Result<ListPromptsResult, Error>;

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error>;

async fn forward_request(&self, method: &str, params: Value) -> Result<Value, Error>;
}

/// The MCP client is the interface for MCP operations.
Expand Down Expand Up @@ -352,10 +360,11 @@ where
message: "Server does not support 'prompts' capability".to_string(),
});
}

self.send_request("prompts/list", serde_json::json!({})).await

self.send_request("prompts/list", serde_json::json!({}))
.await
}

async fn get_prompt(&self, name: &str, arguments: Value) -> Result<GetPromptResult, Error> {
if !self.completed_initialization() {
return Err(Error::NotInitialized);
Expand All @@ -369,9 +378,12 @@ where
}

let params = serde_json::json!({ "name": name, "arguments": arguments });
// TODO ERROR: check that if there is an error, we send back is_error: true with msg
// TODO ERROR: check that if there is an error, we send back is_error: true with msg
// https://modelcontextprotocol.io/docs/concepts/tools#error-handling-2
self.send_request("prompts/get", params).await

}

async fn forward_request(&self, method: &str, params: Value) -> Result<Value, Error> {
self.send_request(method, params).await
}
}

0 comments on commit a64bef2

Please sign in to comment.