Skip to content

Commit fa96194

Browse files
committed
feat(mcptest): Add package help testing.
The new `mcptest` package provides functions for setting up an in-process MCP server and an MCP client connected to it. This allows testing MCP implementations end-to-end without spawning any processes. The `mcptest` package itself has a unit test that demonstrates how to use the package.
1 parent 090e9e3 commit fa96194

File tree

2 files changed

+209
-0
lines changed

2 files changed

+209
-0
lines changed

mcptest/mcptest.go

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
// Package mcptest implements helper functions for testing MCP servers.
2+
package mcptest
3+
4+
import (
5+
"bytes"
6+
"context"
7+
"io"
8+
"log"
9+
"testing"
10+
11+
"github.com/mark3labs/mcp-go/client"
12+
"github.com/mark3labs/mcp-go/client/transport"
13+
"github.com/mark3labs/mcp-go/mcp"
14+
"github.com/mark3labs/mcp-go/server"
15+
)
16+
17+
// Server encapsulates an MCP server and manages resources like pipes and context.
18+
type Server struct {
19+
name string
20+
tools []server.ServerTool
21+
22+
ctx context.Context
23+
cancel func()
24+
25+
serverReader io.Reader
26+
serverWriter io.Writer
27+
clientReader io.Reader
28+
clientWriter io.WriteCloser
29+
30+
logBuffer bytes.Buffer
31+
32+
transport transport.Interface
33+
}
34+
35+
// NewServer starts a new MCP server with the provided tools and returns the server instance.
36+
func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) {
37+
server := NewUnstartedServer(t)
38+
server.AddTools(tools...)
39+
40+
if err := server.Start(); err != nil {
41+
return nil, err
42+
}
43+
44+
return server, nil
45+
}
46+
47+
// NewUnstartedServer creates a new MCP server instance with the given name, but does not start the server.
48+
// Useful for tests where you need to add tools before starting the server.
49+
func NewUnstartedServer(t *testing.T) *Server {
50+
server := &Server{
51+
name: t.Name(),
52+
}
53+
54+
// Use t.Context() once we switch to go >= 1.24
55+
ctx := context.TODO()
56+
57+
// Set up context with cancellation, used to stop the server
58+
server.ctx, server.cancel = context.WithCancel(ctx)
59+
60+
// Set up pipes for client-server communication
61+
server.serverReader, server.clientWriter = io.Pipe()
62+
server.clientReader, server.serverWriter = io.Pipe()
63+
64+
// Return the configured server
65+
return server
66+
}
67+
68+
// AddTools adds multiple tools to an unstarted server.
69+
func (s *Server) AddTools(tools ...server.ServerTool) {
70+
s.tools = append(s.tools, tools...)
71+
}
72+
73+
// AddTool adds a tool to an unstarted server.
74+
func (s *Server) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) {
75+
s.tools = append(s.tools, server.ServerTool{
76+
Tool: tool,
77+
Handler: handler,
78+
})
79+
}
80+
81+
// Start starts the server in a goroutine. Make sure to defer Close() after Start().
82+
// When using NewServer(), the returned server is already started.
83+
func (s *Server) Start() error {
84+
// Start the MCP server in a goroutine
85+
go func() {
86+
mcpServer := server.NewMCPServer(s.name, "1.0.0")
87+
88+
mcpServer.AddTools(s.tools...)
89+
90+
logger := log.New(&s.logBuffer, "", 0)
91+
92+
stdioServer := server.NewStdioServer(mcpServer)
93+
stdioServer.SetErrorLogger(logger)
94+
95+
if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil {
96+
logger.Println("StdioServer.Listen failed:", err)
97+
}
98+
}()
99+
100+
s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer))
101+
102+
return s.transport.Start(s.ctx)
103+
}
104+
105+
// Close stops the server and cleans up resources like temporary directories.
106+
func (s *Server) Close() {
107+
if s.transport != nil {
108+
s.transport.Close()
109+
s.transport = nil
110+
}
111+
112+
if s.cancel != nil {
113+
s.cancel()
114+
s.cancel = nil
115+
}
116+
}
117+
118+
// Client returns an MCP client connected to the server.
119+
func (s *Server) Client() client.MCPClient {
120+
return client.NewClient(s.transport)
121+
}

mcptest/mcptest_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package mcptest_test
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"strings"
7+
"testing"
8+
9+
"github.com/mark3labs/mcp-go/mcp"
10+
"github.com/mark3labs/mcp-go/mcptest"
11+
"github.com/mark3labs/mcp-go/server"
12+
)
13+
14+
func TestServer(t *testing.T) {
15+
ctx := context.Background()
16+
17+
srv, err := mcptest.NewServer(t, server.ServerTool{
18+
Tool: mcp.NewTool("hello",
19+
mcp.WithDescription("Says hello to the provided name, or world."),
20+
mcp.WithString("name", mcp.Description("The name to say hello to.")),
21+
),
22+
Handler: helloWorldHandler,
23+
})
24+
if err != nil {
25+
t.Fatal(err)
26+
}
27+
defer srv.Close()
28+
29+
client := srv.Client()
30+
31+
var initReq mcp.InitializeRequest
32+
33+
initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION
34+
35+
if _, err := client.Initialize(ctx, initReq); err != nil {
36+
t.Fatal("Initialize:", err)
37+
}
38+
39+
var req mcp.CallToolRequest
40+
41+
req.Params.Name = "hello"
42+
req.Params.Arguments = map[string]any{
43+
"name": "Claude",
44+
}
45+
46+
result, err := client.CallTool(ctx, req)
47+
if err != nil {
48+
t.Fatal("CallTool:", err)
49+
}
50+
51+
got, err := resultToString(result)
52+
if err != nil {
53+
t.Fatal(err)
54+
}
55+
56+
want := "Hello, Claude!"
57+
if got != want {
58+
t.Errorf("Got %q, want %q", got, want)
59+
}
60+
}
61+
62+
func helloWorldHandler(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
63+
// Extract name from request arguments
64+
name, ok := request.Params.Arguments["name"].(string)
65+
if !ok {
66+
name = "World"
67+
}
68+
69+
return mcp.NewToolResultText(fmt.Sprintf("Hello, %s!", name)), nil
70+
}
71+
72+
func resultToString(result *mcp.CallToolResult) (string, error) {
73+
var b strings.Builder
74+
75+
for _, content := range result.Content {
76+
text, ok := content.(mcp.TextContent)
77+
if !ok {
78+
return "", fmt.Errorf("unsupported content type: %T", content)
79+
}
80+
b.WriteString(text.Text)
81+
}
82+
83+
if result.IsError {
84+
return "", fmt.Errorf("%s", b.String())
85+
}
86+
87+
return b.String(), nil
88+
}

0 commit comments

Comments
 (0)