Skip to content

Commit b1c1bc3

Browse files
committed
Correct context management
Signed-off-by: Donnie Adams <[email protected]>
1 parent 2591aff commit b1c1bc3

File tree

1 file changed

+20
-5
lines changed

1 file changed

+20
-5
lines changed

pkg/mcp/loader.go

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,10 @@ var (
2727
)
2828

2929
type Local struct {
30-
lock sync.Mutex
31-
sessions map[string]*Session
30+
lock sync.Mutex
31+
sessions map[string]*Session
32+
sessionCtx context.Context
33+
cancel context.CancelFunc
3234
}
3335

3436
type Session struct {
@@ -97,7 +99,7 @@ func (l *Local) Load(ctx context.Context, tool types.Tool) (result []types.Tool,
9799
}
98100

99101
for server := range maps.Keys(servers.MCPServers) {
100-
session, err := l.loadSession(ctx, servers.MCPServers[server])
102+
session, err := l.loadSession(servers.MCPServers[server])
101103
if err != nil {
102104
return nil, fmt.Errorf("failed to load MCP session for server %s: %w", server, err)
103105
}
@@ -117,6 +119,15 @@ func (l *Local) Close() error {
117119
l.lock.Lock()
118120
defer l.lock.Unlock()
119121

122+
if l.sessionCtx == nil {
123+
return nil
124+
}
125+
126+
defer func() {
127+
l.cancel()
128+
l.sessionCtx = nil
129+
}()
130+
120131
var errs []error
121132
for id, session := range l.sessions {
122133
logger.Infof("closing MCP session %s", id)
@@ -222,10 +233,14 @@ func (l *Local) sessionToTools(ctx context.Context, session *Session, toolName s
222233
return toolDefs, nil
223234
}
224235

225-
func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session, error) {
236+
func (l *Local) loadSession(server ServerConfig) (*Session, error) {
226237
id := hash.Digest(server)
227238
l.lock.Lock()
228239
existing, ok := l.sessions[id]
240+
if l.sessionCtx == nil {
241+
l.sessionCtx, l.cancel = context.WithCancel(context.Background())
242+
}
243+
ctx := l.sessionCtx
229244
l.lock.Unlock()
230245

231246
if ok {
@@ -259,7 +274,7 @@ func (l *Local) loadSession(ctx context.Context, server ServerConfig) (*Session,
259274
}
260275

261276
// We expect the client to outlive this one request.
262-
if err = c.Start(context.Background()); err != nil {
277+
if err = c.Start(ctx); err != nil {
263278
return nil, fmt.Errorf("failed to start MCP client: %w", err)
264279
}
265280
}

0 commit comments

Comments
 (0)