Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions pkg/runtime/toolexec/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -724,13 +724,21 @@ func buildMultiContent(text string, images []tools.MediaContent) []chat.MessageP
parts := make([]chat.MessagePart, 0, 1+len(images))
parts = append(parts, chat.MessagePart{Type: chat.MessagePartTypeText, Text: text})
for _, img := range images {
parts = append(parts, chat.MessagePart{
Type: chat.MessagePartTypeImageURL,
ImageURL: &chat.MessageImageURL{
URL: "data:" + img.MimeType + ";base64," + img.Data,
Detail: chat.ImageURLDetailAuto,
},
})
switch {
case img.FilePath != "":
parts = append(parts, chat.MessagePart{
Type: chat.MessagePartTypeText,
Text: fmt.Sprintf("[image saved to %s (%s)]", img.FilePath, img.MimeType),
})
case img.Data != "":
parts = append(parts, chat.MessagePart{
Type: chat.MessagePartTypeImageURL,
ImageURL: &chat.MessageImageURL{
URL: "data:" + img.MimeType + ";base64," + img.Data,
Detail: chat.ImageURLDetailAuto,
},
})
}
}
return parts
}
Expand Down
2 changes: 0 additions & 2 deletions pkg/tools/builtin/filesystem/filesystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,6 @@ type ReadFileArgs struct {

type ReadFileMeta struct {
Path string `json:"path"`
Content string `json:"content"`
LineCount int `json:"lineCount"`
Error string `json:"error,omitempty"`
}
Expand Down Expand Up @@ -1086,7 +1085,6 @@ func (t *ToolSet) handleReadMultipleFiles(ctx context.Context, args ReadMultiple
Path: path,
Content: text,
})
entry.Content = text
entry.LineCount = strings.Count(text, "\n") + 1
meta.Files = append(meta.Files, entry)
}
Expand Down
2 changes: 2 additions & 0 deletions pkg/tools/builtin/filesystem/filesystem_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ func TestFilesystemTool_ReadFile_TildePath(t *testing.T) {
require.NoError(t, err)
assert.False(t, result.IsError)
assert.Equal(t, content, result.Output)
assert.Equal(t, ReadFileMeta{LineCount: 1}, result.Meta)
}

func TestFilesystemTool_WriteFile(t *testing.T) {
Expand Down Expand Up @@ -166,6 +167,7 @@ func TestFilesystemTool_ReadFile(t *testing.T) {
})
require.NoError(t, err)
assert.Equal(t, content, result.Output)
assert.Equal(t, ReadFileMeta{LineCount: 1}, result.Meta)

result, err = tool.handleReadFile(t.Context(), ReadFileArgs{
Path: "nonexistent.txt",
Expand Down
38 changes: 33 additions & 5 deletions pkg/tools/lifecycle/supervisor.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,10 @@ type Supervisor struct {
// fresh channel by Start when transitioning out of a terminal state.
done chan struct{}

// watchDone is closed by the current watcher goroutine. Stop waits on it
// after closing the session so no transport goroutines are left behind.
watchDone chan struct{}

// randFloat is the jitter source; tests may override.
randFloat func() float64
}
Expand Down Expand Up @@ -214,6 +218,9 @@ func (s *Supervisor) Start(ctx context.Context) error {
}
s.session = sess
spawnWatcher := !s.watcherAlive
if spawnWatcher {
s.watchDone = make(chan struct{})
}
s.watcherAlive = true
// Recovering from a terminal state (Failed → Start, or a watcher
// that previously exited): refresh `done` so RestartAndWait callers
Expand Down Expand Up @@ -244,24 +251,40 @@ func (s *Supervisor) Start(ctx context.Context) error {
func (s *Supervisor) Stop(ctx context.Context) error {
s.mu.Lock()
if s.stopping {
watchDone := s.watchDone
s.mu.Unlock()
return nil
return waitForWatcher(ctx, watchDone)
}
s.stopping = true
sess := s.session
s.session = nil
watchDone := s.watchDone
s.mu.Unlock()

s.tracker.Set(StateStopped)
s.signalDone()

if sess == nil {
var closeErr error
if sess != nil {
closeErr = sess.Close(context.WithoutCancel(ctx))
}
waitErr := waitForWatcher(ctx, watchDone)
if closeErr != nil && ctx.Err() == nil {
return closeErr
}
return waitErr
}

func waitForWatcher(ctx context.Context, done <-chan struct{}) error {
if done == nil {
return nil
}
if err := sess.Close(context.WithoutCancel(ctx)); err != nil && ctx.Err() == nil {
return err
select {
case <-done:
return nil
case <-ctx.Done():
return ctx.Err()
}
return nil
}

// RestartAndWait closes the current session (if any) so the watcher
Expand Down Expand Up @@ -326,7 +349,12 @@ func (s *Supervisor) watch(ctx context.Context) {
defer func() {
s.mu.Lock()
s.watcherAlive = false
watchDone := s.watchDone
s.watchDone = nil
s.mu.Unlock()
if watchDone != nil {
close(watchDone)
}
}()

log := s.policy.logger()
Expand Down
80 changes: 76 additions & 4 deletions pkg/tools/lifecycle/supervisor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,20 +17,41 @@ import (
// fakeSession is a controllable session: its Wait blocks until either
// Close is called or fail is invoked.
type fakeSession struct {
mu sync.Mutex
closed bool
failCh chan error
mu sync.Mutex
closed bool
waitDone atomic.Bool // set true after Wait returns
waiting chan struct{} // closed once Wait has parked on failCh
waitOnce sync.Once
failCh chan error
}

func newFakeSession() *fakeSession {
return &fakeSession{failCh: make(chan error, 1)}
return &fakeSession{
waiting: make(chan struct{}),
failCh: make(chan error, 1),
}
}

// Wait blocks until a value arrives on failCh and returns that value.
// On first entry it closes the waiting channel so observers can tell a
// caller has reached the blocking receive; waitDone is flipped to true
// just before Wait hands the result back.
func (f *fakeSession) Wait() error {
	// sync.Once guards the close so repeated Wait calls cannot double-close.
	f.waitOnce.Do(func() { close(f.waiting) })
	// Runs after the receive below has produced the return value.
	defer f.waitDone.Store(true)
	return <-f.failCh
}

// waitParked returns once some goroutine has entered Wait on this
// session, i.e. once the waiting channel has been closed. It fails the
// test if that has not happened within one second. Tests call it before
// Stop so they exercise a watcher that is actually blocked in Wait
// instead of racing against a watcher that may not have started yet.
func (f *fakeSession) waitParked(t *testing.T) {
	t.Helper()
	timeout := time.After(time.Second)
	select {
	case <-timeout:
		t.Fatal("watcher did not enter Wait()")
	case <-f.waiting:
		return
	}
}

func (f *fakeSession) Close(context.Context) error {
f.mu.Lock()
if !f.closed {
Expand Down Expand Up @@ -458,3 +479,54 @@ func TestBackoff_Jitter(t *testing.T) {
d = lifecycle.ExportedBackoffDelay(b, 0, func() float64 { return 0 })
assert.Check(t, d == 50*time.Millisecond)
}

func TestSupervisor_StopWaitsForWatcher(t *testing.T) {
t.Parallel()

sess := newFakeSession()
c := newScriptedConnector(scriptStep{session: sess})
s := lifecycle.New("test", c, lifecycle.Policy{})

assert.NilError(t, s.Start(t.Context()))
sess.waitParked(t)

assert.NilError(t, s.Stop(t.Context()))
assert.Check(t, is.Equal(s.State().State, lifecycle.StateStopped))
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[Non-blocking] Brittle timing assertion and weak concurrency coverage.

This < time.Second bound will be flaky under heavy CI load. More importantly, the test only covers the sequential Start→Stop path, not the concurrent-Stop path that the s.stopping guard was added to fix. A test with two goroutines calling Stop concurrently while the watcher is alive (e.g. blocked in sess.Wait()) would be a stronger regression guard. Consider using goleak to assert no goroutines are left behind.


// Stop must not return until the watcher has observed Wait() unblock.
assert.Check(t, sess.waitDone.Load(), "Stop returned before watcher's Wait() completed")
}

// TestSupervisor_StopConcurrent covers the s.stopping guard. With the
// watcher blocked in sess.Wait(), four goroutines call Stop at the same
// time. Every call must return a nil error, the supervisor must end up
// in StateStopped, and the session's Wait must have completed by the
// time all Stop calls have returned.
func TestSupervisor_StopConcurrent(t *testing.T) {
	t.Parallel()

	sess := newFakeSession()
	connector := newScriptedConnector(scriptStep{session: sess})
	sup := lifecycle.New("test", connector, lifecycle.Policy{})

	assert.NilError(t, sup.Start(t.Context()))
	sess.waitParked(t)

	const callers = 4
	// One slot per caller; each goroutine writes only its own index.
	results := make([]error, callers)
	var wg sync.WaitGroup
	for i := range callers {
		wg.Add(1)
		go func() {
			defer wg.Done()
			results[i] = sup.Stop(t.Context())
		}()
	}
	wg.Wait()

	for _, stopErr := range results {
		assert.NilError(t, stopErr)
	}
	assert.Check(t, is.Equal(sup.State().State, lifecycle.StateStopped))
	assert.Check(t, sess.waitDone.Load(), "a Stop returned before watcher's Wait() completed")
}
114 changes: 104 additions & 10 deletions pkg/tools/mcp/mcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ type Toolset struct {

supervisor *lifecycle.Supervisor

// mediaDir is the toolset-scoped temp dir holding spooled media
// payloads. Created lazily on first spool, removed by Stop.
mediaMu sync.Mutex
mediaDir string

mu sync.Mutex

// Cached tools and prompts, invalidated via MCP notifications and
Expand Down Expand Up @@ -426,6 +431,7 @@ func (ts *Toolset) Start(ctx context.Context) error {
// Stop tears the supervisor down. Idempotent.
func (ts *Toolset) Stop(ctx context.Context) error {
slog.DebugContext(ctx, "Stopping MCP toolset", "server", ts.logID)
defer ts.cleanupMediaDir()
if ts.supervisor == nil {
return nil
}
Expand Down Expand Up @@ -694,7 +700,7 @@ func (ts *Toolset) callTool(ctx context.Context, toolCall tools.ToolCall) (*tool
return nil, fmt.Errorf("failed to call tool: %w", err)
}

result := processMCPContent(resp)
result := ts.processMCPContent(resp)
slog.DebugContext(ctx, "MCP tool call completed", "tool", toolCall.Function.Name, "output_length", len(result.Output))
slog.DebugContext(ctx, result.Output)
return result, nil
Expand All @@ -714,7 +720,13 @@ func isInitNotificationSendError(err error) bool {
return false
}

func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult {
const maxInlineMediaBytes = 256 * 1024

// writeMediaFile is a package-level indirection so tests can simulate
// disk failures without manipulating the filesystem.
var writeMediaFile = defaultWriteMediaFile

func (ts *Toolset) processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult {
var text strings.Builder
var images, audios []tools.MediaContent

Expand All @@ -723,9 +735,9 @@ func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult {
case *mcp.TextContent:
text.WriteString(c.Text)
case *mcp.ImageContent:
images = append(images, encodeMedia(c.Data, c.MIMEType))
images = append(images, ts.encodeMedia(c.Data, c.MIMEType))
case *mcp.AudioContent:
audios = append(audios, encodeMedia(c.Data, c.MIMEType))
audios = append(audios, ts.encodeMedia(c.Data, c.MIMEType))
case *mcp.ResourceLink:
if c.Name != "" {
// Escape ] in name and ) in URI to prevent broken markdown links.
Expand Down Expand Up @@ -760,12 +772,94 @@ func processMCPContent(toolResult *mcp.CallToolResult) *tools.ToolCallResult {
}
}

// encodeMedia re-encodes raw bytes (as decoded by the MCP SDK) back to base64
// for our internal MediaContent representation.
func encodeMedia(data []byte, mimeType string) tools.MediaContent {
return tools.MediaContent{
Data: base64.StdEncoding.EncodeToString(data),
MimeType: mimeType,
// encodeMedia converts a raw media payload into a tools.MediaContent.
//
// Payloads of at most maxInlineMediaBytes are returned inline as
// standard base64 in Data. Larger payloads are written to a file in the
// toolset's media directory and returned by FilePath, with Data left
// empty. If the directory cannot be obtained or the write fails, a
// warning is logged and the payload is returned inline instead, so the
// caller always gets usable content.
func (ts *Toolset) encodeMedia(data []byte, mimeType string) tools.MediaContent {
	inline := func() tools.MediaContent {
		return tools.MediaContent{
			MimeType: mimeType,
			Data:     base64.StdEncoding.EncodeToString(data),
		}
	}
	// Spooling failed: log the cause and fall back to the inline form.
	inlineAfterFailure := func(cause error) tools.MediaContent {
		slog.Warn("failed to spool MCP media to disk", "mime_type", mimeType, "bytes", len(data), "error", cause)
		return inline()
	}

	if len(data) <= maxInlineMediaBytes {
		return inline()
	}

	spoolDir, dirErr := ts.ensureMediaDir()
	if dirErr != nil {
		return inlineAfterFailure(dirErr)
	}
	spooled, writeErr := writeMediaFile(spoolDir, data, mimeType)
	if writeErr != nil {
		return inlineAfterFailure(writeErr)
	}
	return tools.MediaContent{MimeType: mimeType, FilePath: spooled}
}

// ensureMediaDir lazily creates the toolset-scoped temp dir for spooled
// media payloads. The directory is removed by Stop.
func (ts *Toolset) ensureMediaDir() (string, error) {
ts.mediaMu.Lock()
defer ts.mediaMu.Unlock()
if ts.mediaDir != "" {
return ts.mediaDir, nil
}
dir, err := os.MkdirTemp("", "docker-agent-mcp-media-*")
if err != nil {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[BLOCKING] Disk leak — these temp directories are never removed.

os.MkdirTemp creates a directory in os.TempDir() on every call to writeMediaFile, but nothing cleans it up when the ToolCallResult is discarded or the session ends. Over a long session with many large media tool responses this will exhaust disk space.

Suggested fix: introduce a session-scoped temp directory (created once, removed on Stop) instead of a per-payload directory. Alternatively, register a cleanup closure on the Toolset lifecycle the same way GatewayToolset does today.

return "", err
}
ts.mediaDir = dir
return dir, nil
}

// cleanupMediaDir deletes the media spool directory recorded on the
// toolset, if one was created, and clears the recorded path. A failure
// to delete is logged as a warning and not returned.
func (ts *Toolset) cleanupMediaDir() {
	// Take and clear the path under the lock; do the removal outside it.
	spool := func() string {
		ts.mediaMu.Lock()
		defer ts.mediaMu.Unlock()
		current := ts.mediaDir
		ts.mediaDir = ""
		return current
	}()

	if spool != "" {
		if rmErr := os.RemoveAll(spool); rmErr != nil {
			slog.Warn("failed to remove MCP media spool directory", "dir", spool, "error", rmErr)
		}
	}
}

// defaultWriteMediaFile writes data to a newly created temporary file
// in dir and returns the file's path. The file name follows the pattern
// "media-*" plus the extension chosen by mediaExtension for mimeType.
//
// If creating, writing or closing the file fails, the error is returned
// with an empty path. After a write or close failure the partially
// written file is removed on a best-effort basis.
func defaultWriteMediaFile(dir string, data []byte, mimeType string) (string, error) {
	file, err := os.CreateTemp(dir, "media-*"+mediaExtension(mimeType))
	if err != nil {
		return "", err
	}
	name := file.Name()

	_, writeErr := file.Write(data)
	// Always close, whether or not the write succeeded.
	closeErr := file.Close()

	// A write error takes precedence over a close error.
	failure := writeErr
	if failure == nil {
		failure = closeErr
	}
	if failure != nil {
		_ = os.Remove(name) // best effort: the original error is what matters
		return "", failure
	}
	return name, nil
}

// mediaExtension returns the file extension, including the leading dot,
// to use when storing a payload of the given MIME type on disk.
//
// Only the bare media type is matched: anything from the first ';'
// onward (parameters such as "charset=...") is dropped, surrounding
// whitespace is trimmed, and the comparison ignores letter case because
// media type names are case-insensitive. Empty or unrecognised types
// map to ".bin".
func mediaExtension(mimeType string) string {
	bare, _, _ := strings.Cut(mimeType, ";")
	switch strings.ToLower(strings.TrimSpace(bare)) {
	case "image/png":
		return ".png"
	case "image/jpeg":
		return ".jpg"
	case "image/gif":
		return ".gif"
	case "image/webp":
		return ".webp"
	case "audio/wav", "audio/wave", "audio/x-wav":
		return ".wav"
	case "audio/mpeg", "audio/mp3":
		return ".mp3"
	default:
		return ".bin"
	}
}

Expand Down
Loading
Loading