|
| 1 | +// Package mcptest implements helper functions for testing MCP servers. |
| 2 | +package mcptest |
| 3 | + |
| 4 | +import ( |
| 5 | + "bytes" |
| 6 | + "context" |
| 7 | + "fmt" |
| 8 | + "io" |
| 9 | + "log" |
| 10 | + "sync" |
| 11 | + "testing" |
| 12 | + |
| 13 | + "github.com/mark3labs/mcp-go/client" |
| 14 | + "github.com/mark3labs/mcp-go/client/transport" |
| 15 | + "github.com/mark3labs/mcp-go/mcp" |
| 16 | + "github.com/mark3labs/mcp-go/server" |
| 17 | +) |
| 18 | + |
| 19 | +// Server encapsulates an MCP server and manages resources like pipes and context. |
| 20 | +type Server struct { |
| 21 | + name string |
| 22 | + tools []server.ServerTool |
| 23 | + |
| 24 | + ctx context.Context |
| 25 | + cancel func() |
| 26 | + |
| 27 | + serverReader *io.PipeReader |
| 28 | + serverWriter *io.PipeWriter |
| 29 | + clientReader *io.PipeReader |
| 30 | + clientWriter *io.PipeWriter |
| 31 | + |
| 32 | + logBuffer bytes.Buffer |
| 33 | + |
| 34 | + transport transport.Interface |
| 35 | + client *client.Client |
| 36 | + |
| 37 | + wg sync.WaitGroup |
| 38 | +} |
| 39 | + |
| 40 | +// NewServer starts a new MCP server with the provided tools and returns the server instance. |
| 41 | +func NewServer(t *testing.T, tools ...server.ServerTool) (*Server, error) { |
| 42 | + server := NewUnstartedServer(t) |
| 43 | + server.AddTools(tools...) |
| 44 | + |
| 45 | + if err := server.Start(); err != nil { |
| 46 | + return nil, err |
| 47 | + } |
| 48 | + |
| 49 | + return server, nil |
| 50 | +} |
| 51 | + |
| 52 | +// NewUnstartedServer creates a new MCP server instance with the given name, but does not start the server. |
| 53 | +// Useful for tests where you need to add tools before starting the server. |
| 54 | +func NewUnstartedServer(t *testing.T) *Server { |
| 55 | + server := &Server{ |
| 56 | + name: t.Name(), |
| 57 | + } |
| 58 | + |
| 59 | + // Use t.Context() once we switch to go >= 1.24 |
| 60 | + ctx := context.TODO() |
| 61 | + |
| 62 | + // Set up context with cancellation, used to stop the server |
| 63 | + server.ctx, server.cancel = context.WithCancel(ctx) |
| 64 | + |
| 65 | + // Set up pipes for client-server communication |
| 66 | + server.serverReader, server.clientWriter = io.Pipe() |
| 67 | + server.clientReader, server.serverWriter = io.Pipe() |
| 68 | + |
| 69 | + // Return the configured server |
| 70 | + return server |
| 71 | +} |
| 72 | + |
| 73 | +// AddTools adds multiple tools to an unstarted server. |
| 74 | +func (s *Server) AddTools(tools ...server.ServerTool) { |
| 75 | + s.tools = append(s.tools, tools...) |
| 76 | +} |
| 77 | + |
| 78 | +// AddTool adds a tool to an unstarted server. |
| 79 | +func (s *Server) AddTool(tool mcp.Tool, handler server.ToolHandlerFunc) { |
| 80 | + s.tools = append(s.tools, server.ServerTool{ |
| 81 | + Tool: tool, |
| 82 | + Handler: handler, |
| 83 | + }) |
| 84 | +} |
| 85 | + |
| 86 | +// Start starts the server in a goroutine. Make sure to defer Close() after Start(). |
| 87 | +// When using NewServer(), the returned server is already started. |
| 88 | +func (s *Server) Start() error { |
| 89 | + s.wg.Add(1) |
| 90 | + |
| 91 | + // Start the MCP server in a goroutine |
| 92 | + go func() { |
| 93 | + defer s.wg.Done() |
| 94 | + |
| 95 | + mcpServer := server.NewMCPServer(s.name, "1.0.0") |
| 96 | + |
| 97 | + mcpServer.AddTools(s.tools...) |
| 98 | + |
| 99 | + logger := log.New(&s.logBuffer, "", 0) |
| 100 | + |
| 101 | + stdioServer := server.NewStdioServer(mcpServer) |
| 102 | + stdioServer.SetErrorLogger(logger) |
| 103 | + |
| 104 | + if err := stdioServer.Listen(s.ctx, s.serverReader, s.serverWriter); err != nil { |
| 105 | + logger.Println("StdioServer.Listen failed:", err) |
| 106 | + } |
| 107 | + }() |
| 108 | + |
| 109 | + s.transport = transport.NewIO(s.clientReader, s.clientWriter, io.NopCloser(&s.logBuffer)) |
| 110 | + if err := s.transport.Start(s.ctx); err != nil { |
| 111 | + return fmt.Errorf("transport.Start(): %w", err) |
| 112 | + } |
| 113 | + |
| 114 | + s.client = client.NewClient(s.transport) |
| 115 | + |
| 116 | + var initReq mcp.InitializeRequest |
| 117 | + initReq.Params.ProtocolVersion = mcp.LATEST_PROTOCOL_VERSION |
| 118 | + if _, err := s.client.Initialize(s.ctx, initReq); err != nil { |
| 119 | + return fmt.Errorf("client.Initialize(): %w", err) |
| 120 | + } |
| 121 | + |
| 122 | + return nil |
| 123 | +} |
| 124 | + |
| 125 | +// Close stops the server and cleans up resources like temporary directories. |
| 126 | +func (s *Server) Close() { |
| 127 | + if s.transport != nil { |
| 128 | + s.transport.Close() |
| 129 | + s.transport = nil |
| 130 | + s.client = nil |
| 131 | + } |
| 132 | + |
| 133 | + if s.cancel != nil { |
| 134 | + s.cancel() |
| 135 | + s.cancel = nil |
| 136 | + } |
| 137 | + |
| 138 | + // Wait for server goroutine to finish |
| 139 | + s.wg.Wait() |
| 140 | + |
| 141 | + s.serverWriter.Close() |
| 142 | + s.serverReader.Close() |
| 143 | + s.serverReader, s.serverWriter = nil, nil |
| 144 | + |
| 145 | + s.clientWriter.Close() |
| 146 | + s.clientReader.Close() |
| 147 | + s.clientReader, s.clientWriter = nil, nil |
| 148 | +} |
| 149 | + |
| 150 | +// Client returns an MCP client connected to the server. |
| 151 | +// The client is already initialized, i.e. you do _not_ need to call Client.Initialize(). |
| 152 | +func (s *Server) Client() *client.Client { |
| 153 | + return s.client |
| 154 | +} |
0 commit comments