Skip to content

Add mcptest package for in-process MCP testing #149

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Apr 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 40 additions & 9 deletions client/transport/stdio.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,20 @@ type Stdio struct {
notifyMu sync.RWMutex
}

// NewIO returns a new stdio-based transport using existing input, output, and
// logging streams instead of spawning a subprocess.
// This is useful for testing and simulating client behavior.
func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
return &Stdio{
stdin: output,
stdout: bufio.NewReader(input),
stderr: logging,

responses: make(map[int64]chan *JSONRPCResponse),
done: make(chan struct{}),
}
}

// NewStdio creates a new stdio transport to communicate with a subprocess.
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
// Returns an error if the subprocess cannot be started or the pipes cannot be created.
Expand All @@ -55,6 +69,26 @@ func NewStdio(
}

func (c *Stdio) Start(ctx context.Context) error {
if err := c.spawnCommand(ctx); err != nil {
return err
}

ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()
<-ready

return nil
}

// spawnCommand spawns a new process running c.command.
func (c *Stdio) spawnCommand(ctx context.Context) error {
if c.command == "" {
return nil
}

cmd := exec.CommandContext(ctx, c.command, c.args...)

mergedEnv := os.Environ()
Expand Down Expand Up @@ -86,14 +120,6 @@ func (c *Stdio) Start(ctx context.Context) error {
return fmt.Errorf("failed to start command: %w", err)
}

// Start reading responses in a goroutine and wait for it to be ready
ready := make(chan struct{})
go func() {
close(ready)
c.readResponses()
}()
<-ready

return nil
}

Expand All @@ -107,7 +133,12 @@ func (c *Stdio) Close() error {
if err := c.stderr.Close(); err != nil {
return fmt.Errorf("failed to close stderr: %w", err)
}
return c.cmd.Wait()

if c.cmd != nil {
return c.cmd.Wait()
}

return nil
}

// OnNotification registers a handler function to be called when notifications are received.
Expand Down
154 changes: 154 additions & 0 deletions mcptest/mcptest.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
// Package mcptest implements helper functions for testing MCP servers.
package mcptest

import (
"bytes"
"context"
"fmt"
"io"
"log"
"sync"
"testing"

"github.com/mark3labs/mcp-go/client"
"github.com/mark3labs/mcp-go/client/transport"
"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/server"
)

// Server encapsulates an MCP server and manages resources like pipes and context.
type Server struct {
name string
tools []server.ServerTool

ctx context.Context
cancel func()

serverReader *io.PipeReader
serverWriter *io.PipeWriter
clientReader *io.PipeReader
clientWriter *io.PipeWriter

logBuffer bytes.Buffer

transport transport.Interface
client *client.Client

wg sync.WaitGroup
}

// NewServer starts a new MCP server with the provided tools and returns the server instance.
func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) {
server := NewUnstartedServer(t)
server.AddTools(tools...)

if err := server.Start(); err != nil {
return nil, err
}

return server, nil
}

// NewUnstartedServer creates a new MCP server instance with the given name, but does not start the server.
// Useful for tests where you need to add tools before starting the server.
func NewUnstartedServer(t *testing.T) *Server {
server := &Server{
name: t.Name(),
}

// Use t.Context() once we switch to go >= 1.24
ctx := context.TODO()

// Set up context with cancellation, used to stop the server
server.ctx, server.cancel = context.WithCancel(ctx)

// Set up pipes for client-server communication
server.serverReader, server.clientWriter = io.Pipe()
server.clientReader, server.serverWriter = io.Pipe()

// Return the configured server
return server
}

// AddTools adds multiple tools to an unstarted server.
func (s *Server) AddTools(tools ...server.ServerTool) {
s.tools = append(s.tools, tools...)
}

// AddTool adds a tool to an unstarted server.
func (s *Server) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) {
s.tools = append(s.tools, server.ServerTool{
Tool: tool,
Handler: handler,
})
}

// Start starts the server in a goroutine. Make sure to defer Close() after Start().
// When using NewServer(), the returned server is already started.
func (s *Server) Start() error {
s.wg.Add(1)

// Start the MCP server in a goroutine
go func() {
defer s.wg.Done()

mcpServer := server.NewMCPServer(s.name, "1.0.0")

mcpServer.AddTools(s.tools...)

logger := log.New(&s.logBuffer, "", 0)

stdioServer := server.NewStdioServer(mcpServer)
stdioServer.SetErrorLogger(logger)

if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil {
logger.Println("StdioServer.Listen failed:", err)
}
}()

s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer))
if err := s.transport.Start(s.ctx); err != nil {
return fmt.Errorf("transport.Start(): %w", err)
}

s.client = client.NewClient(s.transport)

var initReq mcp.InitializeRequest
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
if _, err := s.client.Initialize(s.ctx, initReq); err != nil {
return fmt.Errorf("client.Initialize(): %w", err)
}

return nil
}

// Close stops the server and cleans up resources like temporary directories.
func (s *Server) Close() {
if s.transport != nil {
s.transport.Close()
s.transport = nil
s.client = nil
}

if s.cancel != nil {
s.cancel()
s.cancel = nil
}

// Wait for server goroutine to finish
s.wg.Wait()

s.serverWriter.Close()
s.serverReader.Close()
s.serverReader, s.serverWriter = nil, nil

s.clientWriter.Close()
s.clientReader.Close()
s.clientReader, s.clientWriter = nil, nil
}

// Client returns an MCP client connected to the server.
// The client is already initialized, i.e. you do _not_ need to call Client.Initialize().
func (s *Server) Client() *client.Client {
return s.client
}
79 changes: 79 additions & 0 deletions mcptest/mcptest_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package mcptest_test

import (
"context"
"fmt"
"strings"
"testing"

"github.com/mark3labs/mcp-go/mcp"
"github.com/mark3labs/mcp-go/mcptest"
"github.com/mark3labs/mcp-go/server"
)

func TestServer(t *testing.T) {
ctx := context.Background()

srv, err := mcptest.NewServer(t, server.ServerTool{
Tool: mcp.NewTool("hello",
mcp.WithDescription("Says hello to the provided name, or world."),
mcp.WithString("name", mcp.Description("The name to say hello to.")),
),
Handler: helloWorldHandler,
})
if err != nil {
t.Fatal(err)
}
defer srv.Close()

client := srv.Client()

var req mcp.CallToolRequest
req.Params.Name = "hello"
req.Params.Arguments = map[string]any{
"name": "Claude",
}

result, err := client.CallTool(ctx, req)
if err != nil {
t.Fatal("CallTool:", err)
}

got, err := resultToString(result)
if err != nil {
t.Fatal(err)
}

want := "Hello, Claude!"
if got != want {
t.Errorf("Got %q, want %q", got, want)
}
}

func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
// Extract name from request arguments
name, ok := request.Params.Arguments["name"].(string)
if !ok {
name = "World"
}

return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil
}

func resultToString(result *mcp.CallToolResult) (string, error) {
var b strings.Builder

for _, content := range result.Content {
text, ok := content.(mcp.TextContent)
if !ok {
return "", fmt.Errorf("unsupported content type: %T", content)
}
b.WriteString(text.Text)
}

if result.IsError {
return "", fmt.Errorf("%s", b.String())
}

return b.String(), nil
}