Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ enum StatusCode(val statusType: StatusType):
case TEST_FAILED extends StatusCode(StatusType.UserError)
case INVALID_MODEL_CONFIG extends StatusCode(StatusType.UserError)
case INVALID_MESSAGE_TYPE extends StatusCode(StatusType.UserError)
case JSON_RPC_ERROR extends StatusCode(StatusType.UserError)

// Internal errors
case INTERNAL_ERROR extends StatusCode(StatusType.InternalError)
Expand Down
92 changes: 92 additions & 0 deletions ai-agent/src/main/scala/wvlet/ai/agent/mcp/JsonRpc.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package wvlet.ai.agent.mcp

import wvlet.airframe.codec.MessageCodec
import wvlet.ai.agent.core.{AIException, StatusCode}

/**
* JSON-RPC 2.0 protocol messages for MCP communication.
*
* Based on: https://www.jsonrpc.org/specification
*/
object JsonRpc:

/**
* JSON-RPC request message
*/
case class Request(
jsonrpc: String = "2.0",
id: Option[Any],
method: String,
params: Option[Map[String, Any]] = None
):
def withParams(params: Map[String, Any]): Request = copy(params = Some(params))
def noParams: Request = copy(params = None)

/**
* JSON-RPC response message
*/
case class Response(
jsonrpc: String = "2.0",
id: Option[Any],
result: Option[Any] = None,
error: Option[ErrorObject] = None
)

/**
* JSON-RPC error object
*/
case class ErrorObject(code: Int, message: String, data: Option[Any] = None)

/**
* JSON-RPC notification (request without id)
*/
case class Notification(
jsonrpc: String = "2.0",
method: String,
params: Option[Map[String, Any]] = None
):
def withParams(params: Map[String, Any]): Notification = copy(params = Some(params))
def noParams: Notification = copy(params = None)

/**
* Standard JSON-RPC error codes
*/
object ErrorCode:
val ParseError = -32700
val InvalidRequest = -32600
val MethodNotFound = -32601
val InvalidParams = -32602
val InternalError = -32603

/**
* Parse a JSON string into a JSON-RPC message
* @throws AIException
* if parsing fails
*/
def parse(json: String): Request | Response | Notification =
try
val parsed = MessageCodec.fromJson[Map[String, Any]](json)
// Check if it has an id field to distinguish request/response from notification
val hasId = parsed.contains("id")
val hasMethod = parsed.contains("method")
val hasResult = parsed.contains("result")
val hasError = parsed.contains("error")

if hasMethod && hasId then
val codec = MessageCodec.of[Request]
codec.fromMap(parsed)
else if hasMethod && !hasId then
val codec = MessageCodec.of[Notification]
codec.fromMap(parsed)
else if hasResult || hasError then
val codec = MessageCodec.of[Response]
codec.fromMap(parsed)
else
throw StatusCode.JSON_RPC_ERROR.newException("Invalid JSON-RPC message format")
catch
case e: AIException =>
throw e
case e: Exception =>
throw StatusCode.JSON_RPC_ERROR.newException(s"Failed to parse JSON: ${e.getMessage}", e)

end JsonRpc
61 changes: 61 additions & 0 deletions ai-agent/src/main/scala/wvlet/ai/agent/mcp/MCPClient.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
package wvlet.ai.agent.mcp

import wvlet.airframe.rx.Rx
import wvlet.ai.agent.mcp.MCPMessages.*

/**
* Client interface for communicating with MCP servers.
*/
trait MCPClient:
/**
* Initialize the connection with the MCP server.
*
* @return
* Server capabilities and information
*/
def initialize(): Rx[InitializeResult]

/**
* List available tools from the MCP server.
*
* @return
* List of available tools
*/
def listTools(): Rx[ListToolsResult]

/**
* Call a tool on the MCP server.
*
* @param toolName
* Name of the tool to call
* @param arguments
* Arguments to pass to the tool
* @return
* Tool execution result
*/
def callTool(toolName: String, arguments: Map[String, Any]): Rx[CallToolResult]

/**
* Send a raw JSON-RPC request to the server.
*
* @param request
* The JSON-RPC request
* @return
* The JSON-RPC response
*/
def sendRequest(request: JsonRpc.Request): Rx[JsonRpc.Response]

/**
* Close the connection to the MCP server.
*/
def close(): Unit

/**
* Check if the client is connected to the server.
*
* @return
* true if connected, false otherwise
*/
def isConnected: Boolean

end MCPClient
139 changes: 139 additions & 0 deletions ai-agent/src/main/scala/wvlet/ai/agent/mcp/MCPMessages.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
package wvlet.ai.agent.mcp

/**
* MCP-specific message types that extend JSON-RPC protocol.
*
* Based on: https://modelcontextprotocol.io/specification/basic/messages/
*/
object MCPMessages:

/**
* MCP protocol version
*/
val PROTOCOL_VERSION = "2024-11-05"

/**
* Client capabilities
*/
case class ClientCapabilities(
experimental: Option[Map[String, Any]] = None,
sampling: Option[Map[String, Any]] = None
)

/**
* Server capabilities
*/
case class ServerCapabilities(
experimental: Option[Map[String, Any]] = None,
logging: Option[Map[String, Any]] = None,
prompts: Option[PromptsCapability] = None,
resources: Option[ResourcesCapability] = None,
tools: Option[ToolsCapability] = None
)

case class PromptsCapability(listChanged: Option[Boolean] = None)
case class ResourcesCapability(
subscribe: Option[Boolean] = None,
listChanged: Option[Boolean] = None
)

case class ToolsCapability(listChanged: Option[Boolean] = None)

/**
* Implementation details
*/
case class Implementation(name: String, version: String)

/**
* Initialize request parameters
*/
case class InitializeParams(
protocolVersion: String = PROTOCOL_VERSION,
capabilities: ClientCapabilities,
clientInfo: Implementation
)

/**
* Initialize result
*/
case class InitializeResult(
protocolVersion: String = PROTOCOL_VERSION,
capabilities: ServerCapabilities,
serverInfo: Implementation
)

/**
* Tool definition from MCP server
*/
case class MCPTool(
name: String,
description: Option[String] = None,
inputSchema: Map[String, Any]
)

/**
* List tools result
*/
case class ListToolsResult(tools: Seq[MCPTool])

/**
* Tool call request
*/
case class CallToolParams(name: String, arguments: Option[Map[String, Any]] = None)

/**
* Tool call result
*/
case class CallToolResult(content: Seq[Map[String, Any]], isError: Option[Boolean] = None)

/**
* Content types for tool results
*/
sealed trait ToolResultContent
case class TextContent(`type`: String = "text", text: String) extends ToolResultContent
case class ImageContent(`type`: String = "image", data: String, mimeType: String)
extends ToolResultContent

case class ResourceContent(`type`: String = "resource", resource: ResourceReference)
extends ToolResultContent

/**
* Resource reference
*/
case class ResourceReference(uri: String, mimeType: Option[String] = None)

/**
* Create an initialize request
*/
def createInitializeRequest(clientName: String, clientVersion: String): JsonRpc.Request = JsonRpc
.Request(
id = Some("init"),
method = "initialize",
params = Some(
Map(
"protocolVersion" -> PROTOCOL_VERSION,
"capabilities" -> Map.empty[String, Any],
"clientInfo" -> Map("name" -> clientName, "version" -> clientVersion)
)
)
)

/**
* Create a list tools request
*/
def createListToolsRequest(): JsonRpc.Request = JsonRpc.Request(
id = Some("list-tools"),
method = "tools/list"
)

/**
* Create a tool call request
*/
def createCallToolRequest(toolName: String, arguments: Map[String, Any]): JsonRpc.Request =
JsonRpc.Request(
id = Some(s"call-$toolName-${java.util.UUID.randomUUID()}"),
method = "tools/call",
params = Some(Map("name" -> toolName, "arguments" -> arguments))
)

end MCPMessages
Loading
Loading