Skip to content

feat: add hooks for sse client #221

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions client/transport/sse.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,18 @@ import (
"github.com/mark3labs/mcp-go/mcp"
)

// OnBeforeRequestFunc is called before sending the request, with context.
type OnBeforeRequestFunc func(ctx context.Context, req *http.Request)

// OnAfterResponseFunc is called after receiving the response, with context. (Regardless of error, when err is not nil resp may be nil.) The req parameter is included.
type OnAfterResponseFunc func(ctx context.Context, req *http.Request, resp *http.Response, err error)

// SSEHooks supports multiple before and after processing functions.
type SSEHooks struct {
OnBeforeRequest []OnBeforeRequestFunc
OnAfterResponse []OnAfterResponseFunc
}

// SSE implements the transport layer of the MCP protocol using Server-Sent Events (SSE).
// It maintains a persistent HTTP connection to receive server-pushed events
// while sending requests over regular HTTP POST calls. The client handles
Expand All @@ -32,6 +44,8 @@ type SSE struct {
endpointChan chan struct{}
headers map[string]string

hooks SSEHooks

started atomic.Bool
closed atomic.Bool
cancelSSEStream context.CancelFunc
Expand All @@ -45,6 +59,13 @@ func WithHeaders(headers map[string]string) ClientOption {
}
}

// Register a set of hooks (overwrites existing hooks)
func WithSSEHooks(hooks SSEHooks) ClientOption {
return func(sc *SSE) {
sc.hooks = hooks
}
}

// NewSSE creates a new SSE-based MCP client with the given base URL.
// Returns an error if the URL is invalid.
func NewSSE(baseURL string, options ...ClientOption) (*SSE, error) {
Expand Down Expand Up @@ -261,6 +282,13 @@ func (c *SSE) SendRequest(
req.Header.Set(k, v)
}

// hooks: before request
for _, hook := range c.hooks.OnBeforeRequest {
if hook != nil {
hook(ctx, req)
}
}

// Register response channel
responseChan := make(chan *JSONRPCResponse, 1)
c.mu.Lock()
Expand All @@ -274,6 +302,14 @@ func (c *SSE) SendRequest(

// Send request
resp, err := c.httpClient.Do(req)

// hooks: after response
for _, hook := range c.hooks.OnAfterResponse {
if hook != nil {
hook(ctx, req, resp, err)
}
}

if err != nil {
deleteResponseChan()
return nil, fmt.Errorf("failed to send request: %w", err)
Expand Down Expand Up @@ -348,7 +384,22 @@ func (c *SSE) SendNotification(ctx context.Context, notification mcp.JSONRPCNoti
req.Header.Set(k, v)
}

// hooks: before request
for _, hook := range c.hooks.OnBeforeRequest {
if hook != nil {
hook(ctx, req)
}
}

resp, err := c.httpClient.Do(req)

// hooks: after response
for _, hook := range c.hooks.OnAfterResponse {
if hook != nil {
hook(ctx, req, resp, err)
}
}

if err != nil {
return fmt.Errorf("failed to send notification: %w", err)
}
Expand Down