Skip to content

Commit 5a5781d

Browse files
authored
Add mcptest package for in-process MCP testing (#149)
* feat(transport): Add the `NewIO` function. This function allows creating a `*transport.Stdio` using provided `io.Reader` and `io.Writer`. This allows creating an MCP client to a server running in the same process, which significanly simplifies testing. * feat(mcptest): Add package help testing. The new `mcptest` package provides functions for setting up an in-process MCP server and an MCP client connected to it. This allows testing MCP implementations end-to-end without spawning any processes. The `mcptest` package itself has a unit test that demonstrates how to use the package. * feat(mcptest): Return an initialized client. This allows to omit the initialization code in tests, making them less verbose. * fix(mcptest): Close pipes on shutdown. * refactor(transport): Rename `startProc` to `spawnCommand`. The new name is more descriptive.
1 parent 33c98f1 commit 5a5781d

File tree

3 files changed

+273
-9
lines changed

3 files changed

+273
-9
lines changed

client/transport/stdio.go

+40-9
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,20 @@ type Stdio struct {
3333
notifyMu sync.RWMutex
3434
}
3535

36+
// NewIO returns a new stdio-based transport using existing input, output, and
37+
// logging streams instead of spawning a subprocess.
38+
// This is useful for testing and simulating client behavior.
39+
func NewIO(input io.Reader, output io.WriteCloser, logging io.ReadCloser) *Stdio {
40+
return &Stdio{
41+
stdin: output,
42+
stdout: bufio.NewReader(input),
43+
stderr: logging,
44+
45+
responses: make(map[int64]chan *JSONRPCResponse),
46+
done: make(chan struct{}),
47+
}
48+
}
49+
3650
// NewStdio creates a new stdio transport to communicate with a subprocess.
3751
// It launches the specified command with given arguments and sets up stdin/stdout pipes for communication.
3852
// Returns an error if the subprocess cannot be started or the pipes cannot be created.
@@ -55,6 +69,26 @@ func NewStdio(
5569
}
5670

5771
func (c *Stdio) Start(ctx context.Context) error {
72+
if err := c.spawnCommand(ctx); err != nil {
73+
return err
74+
}
75+
76+
ready := make(chan struct{})
77+
go func() {
78+
close(ready)
79+
c.readResponses()
80+
}()
81+
<-ready
82+
83+
return nil
84+
}
85+
86+
// spawnCommand spawns a new process running c.command.
87+
func (c *Stdio) spawnCommand(ctx context.Context) error {
88+
if c.command == "" {
89+
return nil
90+
}
91+
5892
cmd := exec.CommandContext(ctx, c.command, c.args...)
5993

6094
mergedEnv := os.Environ()
@@ -86,14 +120,6 @@ func (c *Stdio) Start(ctx context.Context) error {
86120
return fmt.Errorf("failed to start command: %w", err)
87121
}
88122

89-
// Start reading responses in a goroutine and wait for it to be ready
90-
ready := make(chan struct{})
91-
go func() {
92-
close(ready)
93-
c.readResponses()
94-
}()
95-
<-ready
96-
97123
return nil
98124
}
99125

@@ -114,7 +140,12 @@ func (c *Stdio) Close() error {
114140
if err := c.stderr.Close(); err != nil {
115141
return fmt.Errorf("failed to close stderr: %w", err)
116142
}
117-
return c.cmd.Wait()
143+
144+
if c.cmd != nil {
145+
return c.cmd.Wait()
146+
}
147+
148+
return nil
118149
}
119150

120151
// OnNotification registers a handler function to be called when notifications are received.

mcptest/mcptest.go

+154
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
// Package mcptest implements helper functions for testing MCP servers.
2+
package mcptest
3+
4+
import (
5+
"bytes"
6+
"context"
7+
"fmt"
8+
"io"
9+
"log"
10+
"sync"
11+
"testing"
12+
13+
"github.com/mark3labs/mcp-go/client"
14+
"github.com/mark3labs/mcp-go/client/transport"
15+
"github.com/mark3labs/mcp-go/mcp"
16+
"github.com/mark3labs/mcp-go/server"
17+
)
18+
19+
// Server encapsulates an MCP server and manages resources like pipes and context.
20+
type Server struct {
21+
name string
22+
tools []server.ServerTool
23+
24+
ctx context.Context
25+
cancel func()
26+
27+
serverReader *io.PipeReader
28+
serverWriter *io.PipeWriter
29+
clientReader *io.PipeReader
30+
clientWriter *io.PipeWriter
31+
32+
logBuffer bytes.Buffer
33+
34+
transport transport.Interface
35+
client *client.Client
36+
37+
wg sync.WaitGroup
38+
}
39+
40+
// NewServer starts a new MCP server with the provided tools and returns the server instance.
41+
func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) {
42+
server := NewUnstartedServer(t)
43+
server.AddTools(tools...)
44+
45+
if err := server.Start(); err != nil {
46+
return nil, err
47+
}
48+
49+
return server, nil
50+
}
51+
52+
// NewUnstartedServer creates a new MCP server instance with the given name, but does not start the server.
53+
// Useful for tests where you need to add tools before starting the server.
54+
func NewUnstartedServer(t *testing.T) *Server {
55+
server := &Server{
56+
name: t.Name(),
57+
}
58+
59+
// Use t.Context() once we switch to go >= 1.24
60+
ctx := context.TODO()
61+
62+
// Set up context with cancellation, used to stop the server
63+
server.ctx, server.cancel = context.WithCancel(ctx)
64+
65+
// Set up pipes for client-server communication
66+
server.serverReader, server.clientWriter = io.Pipe()
67+
server.clientReader, server.serverWriter = io.Pipe()
68+
69+
// Return the configured server
70+
return server
71+
}
72+
73+
// AddTools adds multiple tools to an unstarted server.
74+
func (s *Server) AddTools(tools ...server.ServerTool) {
75+
s.tools = append(s.tools, tools...)
76+
}
77+
78+
// AddTool adds a tool to an unstarted server.
79+
func (s *Server) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) {
80+
s.tools = append(s.tools, server.ServerTool{
81+
Tool: tool,
82+
Handler: handler,
83+
})
84+
}
85+
86+
// Start starts the server in a goroutine. Make sure to defer Close() after Start().
87+
// When using NewServer(), the returned server is already started.
88+
func (s *Server) Start() error {
89+
s.wg.Add(1)
90+
91+
// Start the MCP server in a goroutine
92+
go func() {
93+
defer s.wg.Done()
94+
95+
mcpServer := server.NewMCPServer(s.name, "1.0.0")
96+
97+
mcpServer.AddTools(s.tools...)
98+
99+
logger := log.New(&s.logBuffer, "", 0)
100+
101+
stdioServer := server.NewStdioServer(mcpServer)
102+
stdioServer.SetErrorLogger(logger)
103+
104+
if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil {
105+
logger.Println("StdioServer.Listen failed:", err)
106+
}
107+
}()
108+
109+
s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer))
110+
if err := s.transport.Start(s.ctx); err != nil {
111+
return fmt.Errorf("transport.Start(): %w", err)
112+
}
113+
114+
s.client = client.NewClient(s.transport)
115+
116+
var initReq mcp.InitializeRequest
117+
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
118+
if _, err := s.client.Initialize(s.ctx, initReq); err != nil {
119+
return fmt.Errorf("client.Initialize(): %w", err)
120+
}
121+
122+
return nil
123+
}
124+
125+
// Close stops the server and cleans up resources like temporary directories.
126+
func (s *Server) Close() {
127+
if s.transport != nil {
128+
s.transport.Close()
129+
s.transport = nil
130+
s.client = nil
131+
}
132+
133+
if s.cancel != nil {
134+
s.cancel()
135+
s.cancel = nil
136+
}
137+
138+
// Wait for server goroutine to finish
139+
s.wg.Wait()
140+
141+
s.serverWriter.Close()
142+
s.serverReader.Close()
143+
s.serverReader, s.serverWriter = nil, nil
144+
145+
s.clientWriter.Close()
146+
s.clientReader.Close()
147+
s.clientReader, s.clientWriter = nil, nil
148+
}
149+
150+
// Client returns an MCP client connected to the server.
151+
// The client is already initialized, i.e. you do _not_ need to call Client.Initialize().
152+
func (s *Server) Client() *client.Client {
153+
return s.client
154+
}

mcptest/mcptest_test.go

+79
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
package mcptest_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
"testing"
8+
9+
"github.com/mark3labs/mcp-go/mcp"
10+
"github.com/mark3labs/mcp-go/mcptest"
11+
"github.com/mark3labs/mcp-go/server"
12+
)
13+
14+
func TestServer(t *testing.T) {
15+
ctx := context.Background()
16+
17+
srv, err := mcptest.NewServer(t, server.ServerTool{
18+
Tool: mcp.NewTool("hello",
19+
mcp.WithDescription("Says hello to the provided name, or world."),
20+
mcp.WithString("name", mcp.Description("The name to say hello to.")),
21+
),
22+
Handler: helloWorldHandler,
23+
})
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
defer srv.Close()
28+
29+
client := srv.Client()
30+
31+
var req mcp.CallToolRequest
32+
req.Params.Name = "hello"
33+
req.Params.Arguments = map[string]any{
34+
"name": "Claude",
35+
}
36+
37+
result, err := client.CallTool(ctx, req)
38+
if err != nil {
39+
t.Fatal("CallTool:", err)
40+
}
41+
42+
got, err := resultToString(result)
43+
if err != nil {
44+
t.Fatal(err)
45+
}
46+
47+
want := "Hello, Claude!"
48+
if got != want {
49+
t.Errorf("Got %q, want %q", got, want)
50+
}
51+
}
52+
53+
func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
54+
// Extract name from request arguments
55+
name, ok := request.Params.Arguments["name"].(string)
56+
if !ok {
57+
name = "World"
58+
}
59+
60+
return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil
61+
}
62+
63+
func resultToString(result *mcp.CallToolResult) (string, error) {
64+
var b strings.Builder
65+
66+
for _, content := range result.Content {
67+
text, ok := content.(mcp.TextContent)
68+
if !ok {
69+
return "", fmt.Errorf("unsupported content type: %T", content)
70+
}
71+
b.WriteString(text.Text)
72+
}
73+
74+
if result.IsError {
75+
return "", fmt.Errorf("%s", b.String())
76+
}
77+
78+
return b.String(), nil
79+
}

0 commit comments

Comments
 (0)