Skip to content

Commit 0013a0e

Browse files
authored
feat(contrib/mcp-go): Trace MCP session initializations with MLObs spans (#4101)
Co-authored-by: julian.boilen <[email protected]>
1 parent 854c849 commit 0013a0e

File tree

4 files changed

+180
-5
lines changed

4 files changed

+180
-5
lines changed

contrib/mark3labs/mcp-go/README.md

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
This integration provides Datadog tracing for the [mark3labs/mcp-go](https://github.com/mark3labs/mcp-go) library.
44

5+
Both hooks and middleware are used.
6+
57
## Usage
68

79
```go
@@ -15,13 +17,18 @@ func main() {
1517
tracer.Start()
1618
defer tracer.Stop()
1719

18-
srv := server.NewMCPServer("my-server", "1.0.0",
19-
server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware()))
20-
_ = srv
20+
// Add tracing to your server hooks
21+
hooks := &server.Hooks{}
22+
mcpgotrace.AddServerHooks(hooks)
23+
24+
srv := server.NewMCPServer("my-server", "1.0.0",
25+
server.WithHooks(hooks),
26+
server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware()))
2127
}
2228
```
2329

2430
## Features
2531

2632
The integration automatically traces:
2733
- **Tool calls**: Creates LLMObs tool spans with input/output annotation for all tool invocations
34+
- **Session initialization**: Create LLMObs task spans for session initialization, including client information.

contrib/mark3labs/mcp-go/example_test.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,12 @@ func Example() {
1515
tracer.Start()
1616
defer tracer.Stop()
1717

18+
// Create server hooks and add Datadog tracing
19+
hooks := &server.Hooks{}
20+
mcpgotrace.AddServerHooks(hooks)
21+
1822
srv := server.NewMCPServer("my-server", "1.0.0",
23+
server.WithHooks(hooks),
1924
server.WithToolHandlerMiddleware(mcpgotrace.NewToolHandlerMiddleware()))
2025
_ = srv
2126
}

contrib/mark3labs/mcp-go/mcpgo.go

Lines changed: 89 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ package mcpgo // import "github.com/DataDog/dd-trace-go/contrib/mark3labs/mcp-go
88
import (
99
"context"
1010
"encoding/json"
11+
"sync"
1112

1213
"github.com/DataDog/dd-trace-go/v2/instrumentation"
1314
"github.com/DataDog/dd-trace-go/v2/llmobs"
@@ -22,17 +23,35 @@ func init() {
2223
instr = instrumentation.Load(instrumentation.PackageMark3LabsMCPGo)
2324
}
2425

26+
type hooks struct {
27+
spanCache *sync.Map
28+
}
29+
30+
// AddServerHooks appends Datadog tracing hooks to an existing server.Hooks object.
31+
func AddServerHooks(hooks *server.Hooks) {
32+
ddHooks := newHooks()
33+
hooks.AddBeforeInitialize(ddHooks.onBeforeInitialize)
34+
hooks.AddAfterInitialize(ddHooks.onAfterInitialize)
35+
hooks.AddOnError(ddHooks.onError)
36+
}
37+
2538
func NewToolHandlerMiddleware() server.ToolHandlerMiddleware {
2639
return func(next server.ToolHandlerFunc) server.ToolHandlerFunc {
2740
return func(ctx context.Context, request mcp.CallToolRequest) (*mcp.CallToolResult, error) {
2841
toolSpan, ctx := llmobs.StartToolSpan(ctx, request.Params.Name, llmobs.WithIntegration(string(instrumentation.PackageMark3LabsMCPGo)))
2942

3043
result, err := next(ctx, request)
3144

32-
inputJSON, _ := json.Marshal(request)
45+
inputJSON, marshalErr := json.Marshal(request)
46+
if marshalErr != nil {
47+
instr.Logger().Warn("mcp-go: failed to marshal tool request: %v", marshalErr)
48+
}
3349
var outputText string
3450
if result != nil {
35-
resultJSON, _ := json.Marshal(result)
51+
resultJSON, marshalErr := json.Marshal(result)
52+
if marshalErr != nil {
53+
instr.Logger().Warn("mcp-go: failed to marshal tool result: %v", marshalErr)
54+
}
3655
outputText = string(resultJSON)
3756
}
3857

@@ -48,3 +67,71 @@ func NewToolHandlerMiddleware() server.ToolHandlerMiddleware {
4867
}
4968
}
5069
}
70+
71+
func newHooks() *hooks {
72+
return &hooks{
73+
spanCache: &sync.Map{},
74+
}
75+
}
76+
77+
func (h *hooks) onBeforeInitialize(ctx context.Context, id any, request *mcp.InitializeRequest) {
78+
taskSpan, _ := llmobs.StartTaskSpan(ctx, "mcp.initialize", llmobs.WithIntegration("mark3labs/mcp-go"))
79+
80+
clientName := request.Params.ClientInfo.Name
81+
clientVersion := request.Params.ClientInfo.Version
82+
taskSpan.Annotate(llmobs.WithAnnotatedTags(map[string]string{"client_name": clientName, "client_version": clientName + "_" + clientVersion}))
83+
84+
h.spanCache.Store(id, taskSpan)
85+
}
86+
87+
func (h *hooks) onAfterInitialize(ctx context.Context, id any, request *mcp.InitializeRequest, result *mcp.InitializeResult) {
88+
finishSpanWithIO(h, id, request, result)
89+
}
90+
91+
func (h *hooks) onError(ctx context.Context, id any, method mcp.MCPMethod, message any, err error) {
92+
if method != mcp.MethodInitialize {
93+
return
94+
}
95+
value, ok := h.spanCache.LoadAndDelete(id)
96+
if !ok {
97+
return
98+
}
99+
100+
span, ok := value.(*llmobs.TaskSpan)
101+
if !ok {
102+
return
103+
}
104+
105+
defer span.Finish(llmobs.WithError(err))
106+
107+
inputJSON, marshalErr := json.Marshal(message)
108+
if marshalErr != nil {
109+
instr.Logger().Warn("mcp-go: failed to marshal error message: %v", marshalErr)
110+
}
111+
span.AnnotateTextIO(string(inputJSON), err.Error())
112+
113+
}
114+
115+
func finishSpanWithIO[Req any, Res any](h *hooks, id any, request Req, result Res) {
116+
value, ok := h.spanCache.LoadAndDelete(id)
117+
if !ok {
118+
return
119+
}
120+
span, ok := value.(*llmobs.TaskSpan)
121+
if !ok {
122+
return
123+
}
124+
125+
defer span.Finish()
126+
127+
inputJSON, marshalErr := json.Marshal(request)
128+
if marshalErr != nil {
129+
instr.Logger().Warn("mcp-go: failed to marshal request: %v", marshalErr)
130+
}
131+
resultJSON, marshalErr := json.Marshal(result)
132+
if marshalErr != nil {
133+
instr.Logger().Warn("mcp-go: failed to marshal result: %v", marshalErr)
134+
}
135+
136+
span.AnnotateTextIO(string(inputJSON), string(resultJSON))
137+
}

contrib/mark3labs/mcp-go/mcpgo_test.go

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"fmt"
1313
"testing"
1414

15+
"github.com/DataDog/dd-trace-go/v2/ddtrace/mocktracer"
1516
"github.com/DataDog/dd-trace-go/v2/ddtrace/tracer"
1617
"github.com/DataDog/dd-trace-go/v2/instrumentation/testutils/testtracer"
1718
"github.com/mark3labs/mcp-go/mcp"
@@ -20,6 +21,81 @@ import (
2021
"github.com/stretchr/testify/require"
2122
)
2223

24+
func TestNewToolHandlerMiddleware(t *testing.T) {
25+
mt := mocktracer.Start()
26+
defer mt.Stop()
27+
28+
middleware := NewToolHandlerMiddleware()
29+
assert.NotNil(t, middleware)
30+
}
31+
32+
func TestAddServerHooks(t *testing.T) {
33+
mt := mocktracer.Start()
34+
defer mt.Stop()
35+
36+
serverHooks := &server.Hooks{}
37+
AddServerHooks(serverHooks)
38+
39+
assert.Len(t, serverHooks.OnBeforeInitialize, 1)
40+
assert.Len(t, serverHooks.OnAfterInitialize, 1)
41+
assert.Len(t, serverHooks.OnError, 1)
42+
}
43+
44+
func TestIntegrationSessionInitialize(t *testing.T) {
45+
tt := testTracer(t)
46+
defer tt.Stop()
47+
48+
hooks := &server.Hooks{}
49+
AddServerHooks(hooks)
50+
51+
srv := server.NewMCPServer("test-server", "1.0.0",
52+
server.WithHooks(hooks))
53+
54+
ctx := context.Background()
55+
initRequest := `{"jsonrpc":"2.0","id":1,"method":"initialize","params":{"protocolVersion":"2024-11-05","capabilities":{},"clientInfo":{"name":"test-client","version":"1.0.0"}}}`
56+
57+
response := srv.HandleMessage(ctx, []byte(initRequest))
58+
assert.NotNil(t, response)
59+
60+
responseBytes, err := json.Marshal(response)
61+
require.NoError(t, err)
62+
63+
var resp map[string]interface{}
64+
err = json.Unmarshal(responseBytes, &resp)
65+
require.NoError(t, err)
66+
assert.Equal(t, "2.0", resp["jsonrpc"])
67+
assert.Equal(t, float64(1), resp["id"])
68+
assert.NotNil(t, resp["result"])
69+
70+
spans := tt.WaitForLLMObsSpans(t, 1)
71+
require.Len(t, spans, 1)
72+
73+
taskSpan := spans[0]
74+
assert.Equal(t, "mcp.initialize", taskSpan.Name)
75+
assert.Equal(t, "task", taskSpan.Meta["span.kind"])
76+
77+
assert.Contains(t, taskSpan.Tags, "client_name:test-client")
78+
assert.Contains(t, taskSpan.Tags, "client_version:test-client_1.0.0")
79+
80+
assert.Contains(t, taskSpan.Meta, "input")
81+
assert.Contains(t, taskSpan.Meta, "output")
82+
83+
inputMeta := taskSpan.Meta["input"]
84+
assert.NotNil(t, inputMeta)
85+
inputJSON, err := json.Marshal(inputMeta)
86+
require.NoError(t, err)
87+
inputStr := string(inputJSON)
88+
assert.Contains(t, inputStr, "2024-11-05")
89+
assert.Contains(t, inputStr, "test-client")
90+
91+
outputMeta := taskSpan.Meta["output"]
92+
assert.NotNil(t, outputMeta)
93+
outputJSON, err := json.Marshal(outputMeta)
94+
require.NoError(t, err)
95+
outputStr := string(outputJSON)
96+
assert.Contains(t, outputStr, "serverInfo")
97+
}
98+
2399
// Test tool spans are recorded on a successful tool call
24100
func TestIntegrationToolCallSuccess(t *testing.T) {
25101
tt := testTracer(t)

0 commit comments

Comments
 (0)