From 11998d2bdf1161912e29d69cdae084bf3aca2f47 Mon Sep 17 00:00:00 2001 From: rincel <0xrinegade@gmail.com> Date: Thu, 23 Jan 2025 11:40:17 +0300 Subject: [PATCH] fucking finally this shit works --- .gitignore | 3 +- Cargo.toml | 4 + config.json | 7 +- src/config.rs | 38 +++ src/lib.rs | 3 + src/main.rs | 18 ++ src/protocol.rs | 222 ++++++++++++++++ src/server/mod.rs | 69 ++++- src/tools/mod.rs | 648 ++++++++++++++++++++++++++++++++-------------- src/transport.rs | 61 ++++- test.sh | 2 + test_request.json | 2 +- test_sequence.sh | 4 + tests/e2e.rs | 20 +- 14 files changed, 875 insertions(+), 226 deletions(-) create mode 100644 src/config.rs create mode 100644 src/protocol.rs create mode 100755 test.sh create mode 100755 test_sequence.sh diff --git a/.gitignore b/.gitignore index 795166e..4211dd3 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,8 @@ # Generated by Cargo /target/ Cargo.lock - +logs/ +logs-runtime.log # Remove Cargo.lock from gitignore if creating an executable, leave it for libraries # More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html diff --git a/Cargo.toml b/Cargo.toml index 0a86eb4..b2ab05d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,10 @@ version = "0.1.2" edition = "2021" [dependencies] +log = "0.4" +env_logger = "0.10" +chrono = "0.4" +url = { version = "2.4.1", features = ["serde"] } anyhow = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" diff --git a/config.json b/config.json index a508cc7..c2b90f9 100644 --- a/config.json +++ b/config.json @@ -1,6 +1,5 @@ { - "url": "https://svmai.com", - "apikey": "svmASEodmoemdwoe242424", - "iamAgent": false, - "tools": ["thinkchain"] + "rpc_url": "https://api.mainnet-beta.solana.com", + "commitment": "confirmed", + "protocol_version": "2024-11-05" } diff --git a/src/config.rs b/src/config.rs new file mode 100644 index 0000000..7c48565 --- /dev/null +++ b/src/config.rs @@ -0,0 +1,38 @@ +use anyhow::{Result, Context}; +use serde::Deserialize; +use std::{env, fs}; +use crate::protocol::LATEST_PROTOCOL_VERSION; + +#[derive(Debug, Deserialize)] +pub struct Config { + pub rpc_url: String, + pub commitment: String, + pub protocol_version: String, +} + +impl Config { + pub fn load() -> Result { + // Try to load from config file first + if let Ok(content) = fs::read_to_string("config.json") { + let config: Config = serde_json::from_str(&content) + .context("Failed to parse config.json")?; + return Ok(config); + } + + // Fall back to environment variables + let rpc_url = env::var("SOLANA_RPC_URL") + .unwrap_or_else(|_| "http://api.opensvm.com".to_string()); + + let commitment = env::var("SOLANA_COMMITMENT") + .unwrap_or_else(|_| "confirmed".to_string()); + + let protocol_version = env::var("SOLANA_PROTOCOL_VERSION") + .unwrap_or_else(|_| LATEST_PROTOCOL_VERSION.to_string()); + + Ok(Config { + rpc_url, + commitment, + protocol_version, + }) + } +} diff --git a/src/lib.rs b/src/lib.rs index e9efe33..0e864d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,10 @@ +pub mod config; +pub mod protocol; pub mod rpc; pub mod server; pub mod tools; pub mod transport; +pub use config::Config; pub use server::start_server; pub use transport::CustomStdioTransport; diff --git a/src/main.rs b/src/main.rs index 68034a1..9fe4ecf 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,6 +1,24 @@ use anyhow::Result; +use env_logger::Builder; +use log::LevelFilter; +use std::io::Write; #[tokio::main] async fn main() -> Result<()> { + // Set up logging to stderr + Builder::new() + .filter_level(LevelFilter::Error) + .format(|buf, record| { + writeln!( + buf, + "{} [{}] {}", + chrono::Local::now().format("%Y-%m-%d %H:%M:%S"), + record.level(), + record.args() + ) + }) + .init(); + + log::info!("Starting Solana MCP server..."); solana_mcp_server::server::start_server().await } diff --git a/src/protocol.rs b/src/protocol.rs new file mode 100644 index 0000000..334d556 --- /dev/null +++ b/src/protocol.rs @@ -0,0 +1,222 @@ +use std::collections::HashMap; +use serde::{Deserialize, Serialize}; +use url::Url; + +pub const LATEST_PROTOCOL_VERSION: &str = "2024-11-05"; + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct Implementation { + pub name: String, + pub version: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct InitializeRequest { + pub protocol_version: String, + pub capabilities: ClientCapabilities, + pub client_info: Implementation, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct InitializeResponse { + pub protocol_version: String, + pub capabilities: ServerCapabilities, + pub server_info: Implementation, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct ServerCapabilities { + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub experimental: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub logging: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub prompts: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub resources: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct PromptCapabilities { + pub list_changed: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct ResourceCapabilities { + pub subscribe: Option, + pub list_changed: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct ClientCapabilities { + pub experimental: Option, + pub sampling: Option, + pub roots: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize, Default)] +#[serde(rename_all = "camelCase")] +#[serde(default)] +pub struct RootCapabilities { + pub list_changed: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolDefinition { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub input_schema: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolRequest { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CallToolResponse { + pub content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub is_error: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum ToolResponseContent { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { data: String, mime_type: String }, + #[serde(rename = "resource")] + Resource { resource: ResourceContents }, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourceContents { + pub uri: Url, + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ListRequest { + #[serde(skip_serializing_if = "Option::is_none")] + pub cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct ToolsListResponse { + pub tools: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptsListResponse { + pub prompts: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Prompt { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub arguments: Option>, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PromptArgument { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub required: Option, +} + +#[derive(Debug, Deserialize, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ResourcesListResponse { + pub resources: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub next_cursor: Option, + #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")] + pub meta: Option>, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct Resource { + pub uri: Url, + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub mime_type: Option, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ErrorCode { + // SDK error codes + ConnectionClosed = -1, + RequestTimeout = -2, + + // Standard JSON-RPC error codes + ParseError = -32700, + InvalidRequest = -32600, + MethodNotFound = -32601, + InvalidParams = -32602, + InternalError = -32603, +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_server_capabilities() { + let capabilities = ServerCapabilities::default(); + let json = serde_json::to_string(&capabilities).unwrap(); + assert_eq!(json, "{}"); + } +} diff --git a/src/server/mod.rs b/src/server/mod.rs index 8fdd048..ac6a1f0 100644 --- a/src/server/mod.rs +++ b/src/server/mod.rs @@ -1,34 +1,83 @@ use anyhow::Result; -use crate::transport::{Transport, JsonRpcMessage}; -use crate::CustomStdioTransport; -use serde_json::Value; +use solana_client::nonblocking::rpc_client::RpcClient; +use solana_sdk::commitment_config::CommitmentConfig; +use std::sync::Arc; +use tokio::sync::RwLock; +use crate::transport::{Transport, JsonRpcMessage, JsonRpcNotification, JsonRpcVersion}; +use crate::{Config, CustomStdioTransport}; + +pub struct ServerState { + pub rpc_client: RpcClient, + pub initialized: bool, + pub protocol_version: String, +} + +impl ServerState { + pub fn new(config: &Config) -> Self { + let commitment = match config.commitment.as_str() { + "processed" => CommitmentConfig::processed(), + "confirmed" => CommitmentConfig::confirmed(), + "finalized" => CommitmentConfig::finalized(), + _ => CommitmentConfig::default(), + }; + + let rpc_client = RpcClient::new_with_commitment( + config.rpc_url.clone(), + commitment, + ); + + Self { + rpc_client, + initialized: false, + protocol_version: config.protocol_version.clone() + } + } +} pub async fn start_server() -> Result<()> { - eprintln!("Solana MCP server ready - {} v{}", - env!("CARGO_PKG_NAME"), - env!("CARGO_PKG_VERSION") - ); + log::info!("Starting Solana MCP server..."); + let config = Config::load()?; + log::info!("Loaded config: RPC URL: {}, Protocol Version: {}", config.rpc_url, config.protocol_version); + + let state = Arc::new(RwLock::new(ServerState::new(&config))); + let transport = CustomStdioTransport::new(); transport.open()?; + log::info!("Opened stdio transport"); + + // Send initial protocol version notification + log::info!("Sending protocol version notification: {}", config.protocol_version); + transport.send(&JsonRpcMessage::Notification(JsonRpcNotification { + jsonrpc: JsonRpcVersion::V2, + method: "protocol".to_string(), + params: Some(serde_json::json!({ + "version": config.protocol_version.clone() + })), + }))?; + // Start message loop + log::info!("Starting message loop"); loop { match transport.receive() { Ok(message) => { let message_str = serde_json::to_string(&message)?; - let response = crate::tools::handle_request(&message_str).await?; + log::debug!("Received message: {}", message_str); + let response = crate::tools::handle_request(&message_str, state.clone()).await?; + log::debug!("Sending response: {}", serde_json::to_string(&response)?); transport.send(&response)?; } Err(e) => { if e.to_string().contains("Connection closed") { - eprintln!("Client disconnected"); + log::info!("Client disconnected"); break; } - eprintln!("Error receiving message: {}", e); + log::error!("Error receiving message: {}", e); } } } + log::info!("Closing transport"); transport.close()?; Ok(()) } diff --git a/src/tools/mod.rs b/src/tools/mod.rs index 3764d5a..a248a39 100644 --- a/src/tools/mod.rs +++ b/src/tools/mod.rs @@ -1,9 +1,13 @@ use anyhow::Result; use crate::transport::{JsonRpcMessage, JsonRpcResponse, JsonRpcError, JsonRpcVersion}; -use serde::{Deserialize, Serialize}; +use crate::protocol::{InitializeRequest, InitializeResponse, ServerCapabilities, Implementation, LATEST_PROTOCOL_VERSION, ToolDefinition, ToolsListResponse, Resource, ResourcesListResponse}; +use url::Url; +use serde::Deserialize; use serde_json::Value; +use std::collections::HashMap; pub fn create_success_response(result: Value, id: u64) -> JsonRpcMessage { + log::debug!("Creating success response with id {}", id); JsonRpcMessage::Response(JsonRpcResponse { jsonrpc: JsonRpcVersion::V2, id, @@ -12,175 +16,213 @@ pub fn create_success_response(result: Value, id: u64) -> JsonRpcMessage { }) } -pub fn create_error_response(code: i32, message: String, id: u64) -> JsonRpcMessage { +pub fn create_error_response(code: i32, message: String, id: u64, protocol_version: Option<&str>) -> JsonRpcMessage { + log::error!("Creating error response: {} (code: {})", message, code); + let error = JsonRpcError { + code, + message, + data: protocol_version.map(|v| serde_json::json!({ "protocolVersion": v })), + }; + JsonRpcMessage::Response(JsonRpcResponse { jsonrpc: JsonRpcVersion::V2, id, result: None, - error: Some(JsonRpcError { - code, - message, - data: None, - }), + error: Some(error), }) } -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct InitializeParams { - protocol_version: String, - capabilities: Value, - client_info: ClientInfo, -} - -#[derive(Debug, Deserialize)] -#[serde(rename_all = "camelCase")] -struct ClientInfo { - name: String, - version: String, -} - #[derive(Debug, Deserialize)] #[serde(rename_all = "camelCase")] struct CancelledParams { + #[allow(dead_code)] request_id: i64, + #[allow(dead_code)] reason: String, } -pub async fn handle_initialize(params: Option, id: Option) -> Result { +pub async fn handle_initialize(params: Option, id: Option, state: &ServerState) -> Result { + log::info!("Handling initialize request"); if let Some(params) = params { - let _init_params: InitializeParams = serde_json::from_value(params)?; - Ok(create_success_response( - serde_json::json!({ - "serverInfo": { - "name": "solana-mcp-server", - "version": env!("CARGO_PKG_VERSION") + let init_params = match serde_json::from_value::(params.clone()) { + Ok(params) => params, + Err(e) => { + log::error!("Failed to parse initialize params: {}", e); + return Ok(create_error_response( + -32602, + "Invalid params: protocolVersion is required".to_string(), + id.and_then(|v| v.as_u64()).unwrap_or(0), + Some(state.protocol_version.as_str()), + )); + } + }; + + log::info!( + "Initializing with protocol version: {}, client: {} v{}", + init_params.protocol_version, + init_params.client_info.name, + init_params.client_info.version + ); + + // Validate protocol version + if init_params.protocol_version != state.protocol_version { + log::error!( + "Protocol version mismatch. Server: {}, Client: {}", + state.protocol_version, + init_params.protocol_version + ); + return Ok(create_error_response( + -32002, + format!("Protocol version mismatch. Server: {}, Client: {}", + state.protocol_version, init_params.protocol_version), + id.and_then(|v| v.as_u64()).unwrap_or(0), + Some(state.protocol_version.as_str()), + )); + } + + let response = InitializeResponse { + protocol_version: LATEST_PROTOCOL_VERSION.to_string(), + server_info: Implementation { + name: "solana-mcp-server".to_string(), + version: env!("CARGO_PKG_VERSION").to_string(), + }, + capabilities: ServerCapabilities { + tools: { + let mut tools = HashMap::new(); + tools.insert("getAccountInfo".to_string(), ToolDefinition { + name: "getAccountInfo".to_string(), + description: Some("Returns all information associated with the account".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "pubkey": { + "type": "string", + "description": "Account public key (base58 encoded)" + }, + "commitment": { + "type": "string", + "description": "Commitment level", + "enum": ["processed", "confirmed", "finalized"] + }, + "encoding": { + "type": "string", + "description": "Encoding format", + "enum": ["base58", "base64", "jsonParsed"] + } + }, + "required": ["pubkey"] + }), + }); + tools.insert("getBalance".to_string(), ToolDefinition { + name: "getBalance".to_string(), + description: Some("Returns the balance of the account".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "pubkey": { + "type": "string", + "description": "Account public key (base58 encoded)" + }, + "commitment": { + "type": "string", + "description": "Commitment level", + "enum": ["processed", "confirmed", "finalized"] + } + }, + "required": ["pubkey"] + }), + }); + tools.insert("getProgramAccounts".to_string(), ToolDefinition { + name: "getProgramAccounts".to_string(), + description: Some("Returns all accounts owned by the program".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "programId": { + "type": "string", + "description": "Program public key (base58 encoded)" + }, + "config": { + "type": "object", + "description": "Configuration object", + "properties": { + "encoding": { + "type": "string", + "enum": ["base58", "base64", "jsonParsed"] + }, + "commitment": { + "type": "string", + "enum": ["processed", "confirmed", "finalized"] + } + } + } + }, + "required": ["programId"] + }), + }); + tools.insert("getTransaction".to_string(), ToolDefinition { + name: "getTransaction".to_string(), + description: Some("Returns transaction details".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "signature": { + "type": "string", + "description": "Transaction signature (base58 encoded)" + }, + "commitment": { + "type": "string", + "enum": ["processed", "confirmed", "finalized"] + } + }, + "required": ["signature"] + }), + }); + tools.insert("getHealth".to_string(), ToolDefinition { + name: "getHealth".to_string(), + description: Some("Returns the current health of the node".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": {} + }), + }); + tools.insert("getVersion".to_string(), ToolDefinition { + name: "getVersion".to_string(), + description: Some("Returns the current Solana version".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": {} + }), + }); + Some(tools) }, - "capabilities": { - "tools": { - // Account methods - "getAccountInfo": { - "description": "Returns all information associated with the account", - "parameters": { - "pubkey": "Account public key (base58 encoded)", - "commitment": "Commitment level (optional)", - "encoding": "Encoding format (optional)" - } - }, - "getBalance": { - "description": "Returns the balance of the account", - "parameters": { - "pubkey": "Account public key (base58 encoded)", - "commitment": "Commitment level (optional)" - } - }, - "getProgramAccounts": { - "description": "Returns all accounts owned by the program", - "parameters": { - "programId": "Program public key (base58 encoded)", - "config": "Configuration object (optional)" - } - }, - // Transaction methods - "getTransaction": { - "description": "Returns transaction details", - "parameters": { - "signature": "Transaction signature (base58 encoded)", - "config": "Configuration object (optional)" - } - }, - "getSignaturesForAddress": { - "description": "Returns signatures for transactions involving an address", - "parameters": { - "address": "Account address (base58 encoded)", - "config": "Configuration object (optional)" - } - }, - "getSignatureStatuses": { - "description": "Returns the statuses of a list of signatures", - "parameters": { - "signatures": "Array of transaction signatures", - "config": "Configuration object (optional)" - } - }, - // Token methods - "getTokenAccountBalance": { - "description": "Returns the token balance of an SPL Token account", - "parameters": { - "accountAddress": "Token account address (base58 encoded)", - "commitment": "Commitment level (optional)" - } - }, - "getTokenSupply": { - "description": "Returns the total supply of an SPL Token type", - "parameters": { - "mint": "Token mint address (base58 encoded)", - "commitment": "Commitment level (optional)" - } - }, - "getTokenLargestAccounts": { - "description": "Returns the 20 largest accounts of a particular SPL Token type", - "parameters": { - "mint": "Token mint address (base58 encoded)", - "commitment": "Commitment level (optional)" - } - }, - // Block methods - "getBlock": { - "description": "Returns information about a confirmed block", - "parameters": { - "slot": "Slot number", - "config": "Configuration object (optional)" - } - }, - "getBlockHeight": { - "description": "Returns the current block height", - "parameters": { - "commitment": "Commitment level (optional)" - } - }, - // System methods - "getHealth": { - "description": "Returns the current health of the node", - "parameters": {} - }, - "getVersion": { - "description": "Returns the current Solana version", - "parameters": {} - }, - "getSlot": { - "description": "Returns the current slot", - "parameters": { - "commitment": "Commitment level (optional)" - } - }, - "getEpochInfo": { - "description": "Returns information about the current epoch", - "parameters": { - "commitment": "Commitment level (optional)" - } - } - }, - "resources": { - "solana_docs": { - "description": "Solana documentation", - "uri": "solana://docs/core" - }, - "rpc_docs": { - "description": "RPC API documentation", - "uri": "solana://docs/rpc" - } - } - } - }), + resources: { + let mut resources = HashMap::new(); + resources.insert("docs".to_string(), Resource { + name: "Documentation".to_string(), + description: Some("Solana API documentation".to_string()), + uri: Url::parse("https://docs.solana.com/developing/clients/jsonrpc-api").unwrap(), + mime_type: Some("text/html".to_string()), + }); + Some(resources) + }, + ..Default::default() + }, + }; + + log::info!("Server initialized successfully"); + Ok(create_success_response( + serde_json::to_value(response).unwrap(), id.and_then(|v| v.as_u64()).unwrap_or(0), )) } else { - Ok(create_error_response(-32602, "Invalid params".to_string(), id.and_then(|v| v.as_u64()).unwrap_or(0))) + log::error!("Missing initialization params"); + Ok(create_error_response(-32602, "Invalid params".to_string(), id.and_then(|v| v.as_u64()).unwrap_or(0), Some(state.protocol_version.as_str()))) } } -pub async fn handle_cancelled(params: Option, id: Option) -> Result { +pub async fn handle_cancelled(params: Option, id: Option, state: &ServerState) -> Result { + log::info!("Handling cancelled request"); if let Some(params) = params { let _cancel_params: CancelledParams = serde_json::from_value(params)?; Ok(create_success_response( @@ -188,85 +230,303 @@ pub async fn handle_cancelled(params: Option, id: Option) -> Resul id.and_then(|v| v.as_u64()).unwrap_or(0) )) } else { - Ok(create_error_response(-32602, "Invalid params".to_string(), id.and_then(|v| v.as_u64()).unwrap_or(0))) + log::error!("Missing cancelled params"); + Ok(create_error_response(-32602, "Invalid params".to_string(), id.and_then(|v| v.as_u64()).unwrap_or(0), Some(state.protocol_version.as_str()))) } } -pub async fn handle_tools_list(id: Option) -> Result { - Ok(create_success_response( - serde_json::json!({ - "tools": [ - { - "name": "getAccountInfo", - "description": "Returns all information associated with the account", - "parameters": { - "pubkey": "Account public key (base58 encoded)" +pub async fn handle_tools_list(id: Option, _state: &ServerState) -> Result { + log::info!("Handling tools/list request"); + let tools = vec![ + ToolDefinition { + name: "getAccountInfo".to_string(), + description: Some("Returns all information associated with the account".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "pubkey": { + "type": "string", + "description": "Account public key (base58 encoded)" + }, + "commitment": { + "type": "string", + "description": "Commitment level", + "enum": ["processed", "confirmed", "finalized"] + }, + "encoding": { + "type": "string", + "description": "Encoding format", + "enum": ["base58", "base64", "jsonParsed"] } }, - { - "name": "getBalance", - "description": "Returns the balance of the account", - "parameters": { - "pubkey": "Account public key (base58 encoded)" + "required": ["pubkey"] + }), + }, + ToolDefinition { + name: "getBalance".to_string(), + description: Some("Returns the balance of the account".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "pubkey": { + "type": "string", + "description": "Account public key (base58 encoded)" + }, + "commitment": { + "type": "string", + "description": "Commitment level", + "enum": ["processed", "confirmed", "finalized"] } - } - ], - "resources": [ - { - "name": "solana_docs", - "description": "Solana documentation", - "uri": "solana://docs/core" }, - { - "name": "rpc_docs", - "description": "RPC API documentation", - "uri": "solana://docs/rpc" - } - ] - }), + "required": ["pubkey"] + }), + }, + ToolDefinition { + name: "getProgramAccounts".to_string(), + description: Some("Returns all accounts owned by the program".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "programId": { + "type": "string", + "description": "Program public key (base58 encoded)" + }, + "config": { + "type": "object", + "description": "Configuration object", + "properties": { + "encoding": { + "type": "string", + "enum": ["base58", "base64", "jsonParsed"] + }, + "commitment": { + "type": "string", + "enum": ["processed", "confirmed", "finalized"] + } + } + } + }, + "required": ["programId"] + }), + }, + ToolDefinition { + name: "getTransaction".to_string(), + description: Some("Returns transaction details".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": { + "signature": { + "type": "string", + "description": "Transaction signature (base58 encoded)" + }, + "commitment": { + "type": "string", + "enum": ["processed", "confirmed", "finalized"] + } + }, + "required": ["signature"] + }), + }, + ToolDefinition { + name: "getHealth".to_string(), + description: Some("Returns the current health of the node".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": {} + }), + }, + ToolDefinition { + name: "getVersion".to_string(), + description: Some("Returns the current Solana version".to_string()), + input_schema: serde_json::json!({ + "type": "object", + "properties": {} + }), + }, + ]; + + let tools_len = tools.len(); + log::debug!("Returning {} tools", tools_len); + + let response = ToolsListResponse { + tools, + next_cursor: None, + meta: None, + }; + + Ok(create_success_response( + serde_json::to_value(response).unwrap(), id.and_then(|v| v.as_u64()).unwrap_or(0), )) } -pub async fn handle_request(request: &str) -> Result { - let message: JsonRpcMessage = serde_json::from_str(request).map_err(|_| { - anyhow::anyhow!("Invalid JSON-RPC request") +use std::sync::Arc; +use tokio::sync::RwLock; +use crate::server::ServerState; +use solana_sdk::pubkey::Pubkey; + +pub async fn handle_request(request: &str, state: Arc>) -> Result { + log::debug!("Received request: {}", request); + let message: JsonRpcMessage = serde_json::from_str(request).map_err(|e| { + log::error!("Failed to parse JSON-RPC request: {}", e); + anyhow::anyhow!("Invalid JSON-RPC request: {}", e) })?; match message { JsonRpcMessage::Request(req) => { + let mut state_guard = state.write().await; + let protocol_version = Some(state_guard.protocol_version.as_str()); + if req.jsonrpc != JsonRpcVersion::V2 { + log::error!("Invalid JSON-RPC version: {:?}", req.jsonrpc); return Ok(create_error_response( -32600, "Invalid Request: jsonrpc version must be 2.0".to_string(), req.id, + protocol_version, )); } - match req.method.as_str() { - "initialize" => handle_initialize(req.params, Some(serde_json::Value::Number(req.id.into()))).await, - "cancelled" => handle_cancelled(req.params, Some(serde_json::Value::Number(req.id.into()))).await, - "tools/list" => handle_tools_list(Some(serde_json::Value::Number(req.id.into()))).await, - _ => Ok(create_error_response( - -32601, - "Method not found".to_string(), + // Only allow initialize method if not initialized + if !state_guard.initialized && req.method.as_str() != "initialize" { + log::error!("Server not initialized, received method: {}", req.method); + return Ok(create_error_response( + -32002, + "Server not initialized".to_string(), req.id, - )), + protocol_version, + )); + } + + log::info!("Handling method: {}", req.method); + match req.method.as_str() { + "initialize" => { + let response = handle_initialize( + req.params, + Some(serde_json::Value::Number(req.id.into())), + &state_guard + ).await?; + + if response.is_success() { + state_guard.initialized = true; + log::info!("Server initialized successfully"); + } else { + log::error!("Server initialization failed"); + } + Ok(response) + }, + "cancelled" => handle_cancelled(req.params, Some(serde_json::Value::Number(req.id.into())), &state_guard).await, + "tools/list" => handle_tools_list(Some(serde_json::Value::Number(req.id.into())), &state_guard).await, + + // Account methods + "getAccountInfo" => { + log::info!("Getting account info"); + let params = req.params.ok_or_else(|| anyhow::anyhow!("Missing params"))?; + let pubkey_str = params.get("pubkey") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing pubkey parameter"))?; + let pubkey = Pubkey::try_from(pubkey_str)?; + + let state = state.read().await; + let result = crate::rpc::accounts::get_account_info(&state.rpc_client, &pubkey).await?; + Ok(create_success_response(result, req.id)) + }, + "getBalance" => { + log::info!("Getting balance"); + let params = req.params.ok_or_else(|| anyhow::anyhow!("Missing params"))?; + let pubkey_str = params.get("pubkey") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing pubkey parameter"))?; + let pubkey = Pubkey::try_from(pubkey_str)?; + + let state = state.read().await; + let result = crate::rpc::accounts::get_balance(&state.rpc_client, &pubkey).await?; + Ok(create_success_response(result, req.id)) + }, + "getProgramAccounts" => { + log::info!("Getting program accounts"); + let params = req.params.ok_or_else(|| anyhow::anyhow!("Missing params"))?; + let program_id_str = params.get("programId") + .and_then(|v| v.as_str()) + .ok_or_else(|| anyhow::anyhow!("Missing programId parameter"))?; + let program_id = Pubkey::try_from(program_id_str)?; + + let state = state.read().await; + let result = crate::rpc::accounts::get_program_accounts(&state.rpc_client, &program_id).await?; + Ok(create_success_response(result, req.id)) + }, + + "resources/templates/list" => { + log::info!("Handling resources/templates/list request"); + let response = ResourcesListResponse { + resources: vec![], + next_cursor: None, + meta: None, + }; + + Ok(create_success_response( + serde_json::to_value(response).unwrap(), + req.id + )) + }, + "resources/list" => { + log::info!("Handling resources/list request"); + let resources = vec![ + Resource { + uri: Url::parse("https://docs.solana.com/developing/clients/jsonrpc-api").unwrap(), + name: "Documentation".to_string(), + description: Some("Solana API documentation".to_string()), + mime_type: Some("text/html".to_string()), + } + ]; + + let response = ResourcesListResponse { + resources, + next_cursor: None, + meta: None, + }; + + Ok(create_success_response( + serde_json::to_value(response).unwrap(), + req.id + )) + }, + _ => { + log::error!("Method not found: {}", req.method); + Ok(create_error_response( + -32601, + "Method not found".to_string(), + req.id, + protocol_version, + )) + }, } }, JsonRpcMessage::Response(_) => { + log::error!("Received response message when expecting request"); Ok(create_error_response( -32600, "Invalid Request: expected request message".to_string(), 0, + None, )) }, - JsonRpcMessage::Notification(_) => { - Ok(create_error_response( - -32600, - "Invalid Request: notifications not supported".to_string(), - 0, - )) + JsonRpcMessage::Notification(notification) => { + match notification.method.as_str() { + "notifications/initialized" => { + log::info!("Received initialized notification"); + Ok(JsonRpcMessage::Notification(notification)) + }, + _ => { + log::error!("Unsupported notification: {}", notification.method); + Ok(create_error_response( + -32600, + format!("Unsupported notification: {}", notification.method), + 0, + None, + )) + } + } }, } } diff --git a/src/transport.rs b/src/transport.rs index a22711c..29016fb 100644 --- a/src/transport.rs +++ b/src/transport.rs @@ -18,32 +18,41 @@ impl Default for JsonRpcVersion { } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct JsonRpcRequest { pub jsonrpc: JsonRpcVersion, pub id: u64, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct JsonRpcResponse { pub jsonrpc: JsonRpcVersion, pub id: u64, + #[serde(skip_serializing_if = "Option::is_none")] pub result: Option, + #[serde(skip_serializing_if = "Option::is_none")] pub error: Option, } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct JsonRpcNotification { pub jsonrpc: JsonRpcVersion, pub method: String, + #[serde(skip_serializing_if = "Option::is_none")] pub params: Option, } #[derive(Debug, Serialize, Deserialize)] +#[serde(rename_all = "camelCase")] pub struct JsonRpcError { pub code: i32, pub message: String, + #[serde(skip_serializing_if = "Option::is_none")] pub data: Option, } @@ -55,8 +64,18 @@ pub enum JsonRpcMessage { Notification(JsonRpcNotification), } +impl JsonRpcMessage { + pub fn is_success(&self) -> bool { + match self { + JsonRpcMessage::Response(resp) => resp.error.is_none(), + _ => false, + } + } +} + pub trait Transport { fn send(&self, message: &JsonRpcMessage) -> Result<()>; + fn send_raw(&self, json: &str) -> Result<()>; fn receive(&self) -> Result; fn open(&self) -> Result<()>; fn close(&self) -> Result<()>; @@ -77,40 +96,70 @@ impl CustomStdioTransport { } impl Transport for CustomStdioTransport { - fn send(&self, message: &JsonRpcMessage) -> Result<()> { - let json = serde_json::to_string(&message)?; + fn send_raw(&self, json: &str) -> Result<()> { let mut writer = self.writer.lock().map_err(|_| { io::Error::new(io::ErrorKind::Other, "Failed to acquire writer lock") })?; + let json = json.trim(); writeln!(writer, "{}", json)?; writer.flush()?; Ok(()) } + fn send(&self, message: &JsonRpcMessage) -> Result<()> { + log::debug!("Sending message: {}", serde_json::to_string(message)?); + let mut writer = self.writer.lock().map_err(|_| { + let err = io::Error::new(io::ErrorKind::Other, "Failed to acquire writer lock"); + log::error!("Transport error: {}", err); + err + })?; + let mut buf = Vec::new(); + let mut ser = serde_json::Serializer::new(&mut buf); + message.serialize(&mut ser)?; + writer.write_all(&buf)?; + writer.write_all(b"\n")?; + writer.flush()?; + Ok(()) + } + fn receive(&self) -> Result { let mut line = String::new(); let mut reader = self.reader.lock().map_err(|_| { - io::Error::new(io::ErrorKind::Other, "Failed to acquire reader lock") + let err = io::Error::new(io::ErrorKind::Other, "Failed to acquire reader lock"); + log::error!("Transport error: {}", err); + err })?; match reader.read_line(&mut line) { - Ok(0) => Err(io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed").into()), + Ok(0) => { + let err = io::Error::new(io::ErrorKind::UnexpectedEof, "Connection closed"); + log::info!("Transport connection closed"); + Err(err.into()) + }, Ok(_) => { if line.trim().is_empty() { - return Err(io::Error::new(io::ErrorKind::InvalidData, "Empty message received").into()); + let err = io::Error::new(io::ErrorKind::InvalidData, "Empty message received"); + log::error!("Transport error: {}", err); + return Err(err.into()); } + log::debug!("Received raw message: {}", line.trim()); let message = serde_json::from_str(&line)?; Ok(message) }, - Err(e) => Err(e.into()), + Err(e) => { + log::error!("Transport error: {}", e); + Err(e.into()) + }, } } fn open(&self) -> Result<()> { + log::info!("Opening stdio transport"); Ok(()) } fn close(&self) -> Result<()> { + log::info!("Closing stdio transport"); Ok(()) } } diff --git a/test.sh b/test.sh new file mode 100755 index 0000000..a197901 --- /dev/null +++ b/test.sh @@ -0,0 +1,2 @@ +#!/bin/bash +tr -d '\n' < test_request.json && echo diff --git a/test_request.json b/test_request.json index 19c633f..d2a4237 100644 --- a/test_request.json +++ b/test_request.json @@ -1 +1 @@ -{"jsonrpc": "2.0", "method": "tools/list", "id": 1} +{"jsonrpc":"2.0","id":1,"method":"tools/list"} diff --git a/test_sequence.sh b/test_sequence.sh new file mode 100755 index 0000000..2738e1a --- /dev/null +++ b/test_sequence.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +# Start the server and send both requests through the same pipe +(echo '{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","clientInfo":{"name":"test-client","version":"1.0.0"},"capabilities":{}}}'; sleep 1; echo '{"jsonrpc":"2.0","id":2,"method":"tools/list"}') | ./target/release/solana-mcp-server diff --git a/tests/e2e.rs b/tests/e2e.rs index 7a7c4e6..6488b4b 100644 --- a/tests/e2e.rs +++ b/tests/e2e.rs @@ -35,7 +35,7 @@ impl Transport for TestTransport { let json = match message { JsonRpcMessage::Request(req) => { json!({ - "jsonrpc": "2.0", + "jsonrpc": JsonRpcVersion::V2, "id": req.id, "method": req.method, "params": req.params @@ -43,7 +43,7 @@ impl Transport for TestTransport { }, JsonRpcMessage::Response(resp) => { json!({ - "jsonrpc": "2.0", + "jsonrpc": JsonRpcVersion::V2, "id": resp.id, "result": resp.result, "error": resp.error @@ -114,7 +114,7 @@ fn setup_mock_server() -> TestTransport { match (method, auth) { ("initialize", "test_key_123") => { server_transport.tx.send(json!({ - "jsonrpc": "2.0", + "jsonrpc": JsonRpcVersion::V2, "id": id, "result": { "server": { @@ -130,7 +130,7 @@ fn setup_mock_server() -> TestTransport { }, (_, "invalid_key") => { server_transport.tx.send(json!({ - "jsonrpc": "2.0", + "jsonrpc": JsonRpcVersion::V2, "id": id, "error": { "code": -32601, @@ -143,14 +143,14 @@ fn setup_mock_server() -> TestTransport { match tool_name { "get_slot" => { server_transport.tx.send(json!({ - "jsonrpc": "2.0", + "jsonrpc": JsonRpcVersion::V2, "id": id, "result": 12345 })).unwrap(); }, _ => { server_transport.tx.send(json!({ - "jsonrpc": "2.0", + "jsonrpc": JsonRpcVersion::V2, "id": id, "error": { "code": -32601, @@ -187,7 +187,7 @@ async fn test_server_initialization() { match response { JsonRpcMessage::Response(resp) => { - assert_eq!(serde_json::to_string(&resp.jsonrpc).unwrap(), "\"2.0\""); + assert_eq!(resp.jsonrpc, JsonRpcVersion::V2); assert_eq!(resp.id, 1); let result = resp.result.unwrap(); assert!(result["server"]["name"].as_str().unwrap().contains("solana-mcp")); @@ -215,7 +215,7 @@ async fn test_invalid_api_key() { match response { JsonRpcMessage::Response(resp) => { - assert_eq!(serde_json::to_string(&resp.jsonrpc).unwrap(), "\"2.0\""); + assert_eq!(resp.jsonrpc, JsonRpcVersion::V2); assert_eq!(resp.id, 1); let error = resp.error.unwrap(); assert_eq!(error.code, -32601); @@ -259,7 +259,7 @@ async fn test_tool_execution() { match response { JsonRpcMessage::Response(resp) => { - assert_eq!(serde_json::to_string(&resp.jsonrpc).unwrap(), "\"2.0\""); + assert_eq!(resp.jsonrpc, JsonRpcVersion::V2); assert_eq!(resp.id, 2); assert!(resp.result.is_some()); }, @@ -301,7 +301,7 @@ async fn test_invalid_tool() { match response { JsonRpcMessage::Response(resp) => { - assert_eq!(serde_json::to_string(&resp.jsonrpc).unwrap(), "\"2.0\""); + assert_eq!(resp.jsonrpc, JsonRpcVersion::V2); assert_eq!(resp.id, 2); let error = resp.error.unwrap(); assert_eq!(error.code, -32601);