-
Notifications
You must be signed in to change notification settings - Fork 606
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[feat] MCP prompts #639
[feat] MCP prompts #639
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
use mcp_core::protocol::{GetPromptResult, ListPromptsResult}; | ||
|
||
use crate::state::AppState; | ||
use axum::{ | ||
extract::State, | ||
http::{HeaderMap, StatusCode}, | ||
routing::post, | ||
Json, Router, | ||
}; | ||
use serde::{Deserialize, Serialize}; | ||
use serde_json::Value; | ||
|
||
#[derive(Serialize, Deserialize)] | ||
struct Prompt { | ||
name: String, | ||
description: Option<String>, | ||
required: Option<bool>, | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
struct ListPromptRequest { | ||
system: String | ||
} | ||
|
||
#[derive(Serialize, Deserialize)] | ||
struct PromptRequest { | ||
system: String, | ||
name: String, | ||
arguments: Value | ||
} | ||
|
||
async fn list_prompts_handler( | ||
State(state): State<AppState>, | ||
headers: HeaderMap, | ||
Json(request): Json<ListPromptRequest>, | ||
) -> Result<Json<ListPromptsResult>, StatusCode> { | ||
// Verify secret key | ||
let secret_key = headers | ||
.get("X-Secret-Key") | ||
.and_then(|value| value.to_str().ok()) | ||
.ok_or(StatusCode::UNAUTHORIZED)?; | ||
|
||
if secret_key != state.secret_key { | ||
return Err(StatusCode::UNAUTHORIZED); | ||
} | ||
|
||
let agent = state.agent.lock().await; | ||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; | ||
|
||
// Get prompts through agent passthrough | ||
let result = agent | ||
.passthrough(&request.system, "prompts/list", serde_json::json!({})) | ||
.await | ||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
||
// Deserialize the result to ListPromptsResult | ||
let prompts_result: ListPromptsResult = | ||
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
||
Ok(Json(prompts_result)) | ||
} | ||
|
||
async fn get_prompt_handler( | ||
State(state): State<AppState>, | ||
headers: HeaderMap, | ||
Json(request): Json<PromptRequest>, | ||
) -> Result<Json<GetPromptResult>, StatusCode> { | ||
// Verify secret key | ||
let secret_key = headers | ||
.get("X-Secret-Key") | ||
.and_then(|value| value.to_str().ok()) | ||
.ok_or(StatusCode::UNAUTHORIZED)?; | ||
|
||
if secret_key != state.secret_key { | ||
return Err(StatusCode::UNAUTHORIZED); | ||
} | ||
|
||
let agent = state.agent.lock().await; | ||
let agent = agent.as_ref().ok_or(StatusCode::NOT_FOUND)?; | ||
|
||
// Get prompt through agent passthrough | ||
let result = agent | ||
.passthrough( | ||
&request.system, | ||
"prompts/get", | ||
serde_json::json!({ | ||
"name": &request.name, | ||
"arguments": &request.arguments, | ||
}), | ||
) | ||
.await | ||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
|
||
// Deserialize the result to GetPromptResult | ||
let prompt_result: GetPromptResult = | ||
serde_json::from_value(result).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; | ||
Ok(Json(prompt_result)) | ||
} | ||
|
||
pub fn routes(state: AppState) -> Router { | ||
Router::new() | ||
.route("/prompts/list", post(list_prompts_handler)) | ||
.route("/prompts/get", post(get_prompt_handler)) | ||
.with_state(state) | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,7 +6,7 @@ use tracing::{debug, instrument}; | |
|
||
use super::Agent; | ||
use crate::agents::capabilities::{Capabilities, ResourceItem}; | ||
use crate::agents::system::{SystemConfig, SystemResult}; | ||
use crate::agents::system::{SystemConfig, SystemError, SystemResult}; | ||
use crate::message::{Message, MessageContent, ToolRequest}; | ||
use crate::providers::base::Provider; | ||
use crate::providers::base::ProviderUsage; | ||
|
@@ -147,9 +147,21 @@ impl Agent for DefaultAgent { | |
.expect("Failed to list systems") | ||
} | ||
|
||
async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> { | ||
// TODO implement | ||
Ok(Value::Null) | ||
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> { | ||
println!("in pass through"); | ||
let capabilities = self.capabilities.lock().await; | ||
let client = capabilities | ||
.get_system(system) | ||
.await | ||
.unwrap_or_else(|| panic!("System not found: {}", system)); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think we should return an Err (SystemError::UnknownSystem) here instead of panicing? |
||
let client = client.lock().await; | ||
|
||
let result: Value = client | ||
.forward_request(method, params) | ||
.await | ||
.map_err(SystemError::Client)?; | ||
|
||
Ok(result) | ||
} | ||
|
||
#[instrument(skip(self, messages), fields(user_message))] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,7 +7,7 @@ use tracing::{debug, instrument}; | |
|
||
use super::Agent; | ||
use crate::agents::capabilities::Capabilities; | ||
use crate::agents::system::{SystemConfig, SystemResult}; | ||
use crate::agents::system::{SystemConfig, SystemError, SystemResult}; | ||
use crate::message::{Message, ToolRequest}; | ||
use crate::providers::base::Provider; | ||
use crate::providers::base::ProviderUsage; | ||
|
@@ -52,9 +52,22 @@ impl Agent for ReferenceAgent { | |
.expect("Failed to list systems") | ||
} | ||
|
||
async fn passthrough(&self, _system: &str, _request: Value) -> SystemResult<Value> { | ||
// TODO implement | ||
Ok(Value::Null) | ||
async fn passthrough(&self, system: &str, method: &str, params: Value) -> SystemResult<Value> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. meganit: maybe implement self.capabilities.passthrough helper - i think this code would be the same throughout |
||
let capabilities = self.capabilities.lock().await; | ||
tracing::info!("passthrough" = system, method); | ||
let client = capabilities | ||
.get_system(system) | ||
.await | ||
.unwrap_or_else(|| panic!("System not found: {}", system)); | ||
let client = client.lock().await; | ||
|
||
tracing::info!("forwarding request" = method); | ||
let result: Value = client | ||
.forward_request(method, params) | ||
.await | ||
.map_err(SystemError::Client)?; | ||
|
||
Ok(result) | ||
} | ||
|
||
#[instrument(skip(self, messages), fields(user_message))] | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: for clarity
get_client
to matchget_client_for_tool
?