Skip to content

enhance: send MCP errors back to the LLM so it can correct if possible #974

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func compressEnv(envs []string) (result []string) {
return
}

func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) {
func (e *Engine) runCommand(ctx Context, tool types.Tool, input string) (cmdOut string, cmdErr error) {
id := counter.Next()

var combinedOutput string
Expand Down Expand Up @@ -128,7 +128,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate

cmd, stop, err := e.newCommand(commandCtx, extraEnv, tool, input, true)
if err != nil {
if toolCategory == NoCategory && ctx.Parent != nil {
if ctx.ToolCategory == NoCategory && ctx.Parent != nil {
return fmt.Sprintf("ERROR: got (%v) while parsing command", err), nil
}
return "", fmt.Errorf("got (%v) while parsing command", err)
Expand Down Expand Up @@ -167,7 +167,7 @@ func (e *Engine) runCommand(ctx Context, tool types.Tool, input string, toolCate

if err := cmd.Run(); err != nil && (commandCtx.Err() == nil || ctx.Ctx.Err() != nil) {
// If the command failed and the context hasn't been canceled, then return the error.
if toolCategory == NoCategory && ctx.Parent != nil {
if ctx.ToolCategory == NoCategory && ctx.Parent != nil {
// If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry.
return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, stdoutAndErr), nil
}
Expand Down
6 changes: 3 additions & 3 deletions pkg/engine/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ type Engine struct {
}

type MCPRunner interface {
Run(ctx context.Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error)
Run(ctx Context, progress chan<- types.CompletionStatus, tool types.Tool, input string) (string, error)
}

type State struct {
Expand Down Expand Up @@ -313,7 +313,7 @@ func populateMessageParams(ctx Context, completion *types.CompletionRequest, too
}

func (e *Engine) runMCPInvoke(ctx Context, tool types.Tool, input string) (*Return, error) {
output, err := e.MCPRunner.Run(ctx.Ctx, e.Progress, tool, input)
output, err := e.MCPRunner.Run(ctx, e.Progress, tool, input)
if err != nil {
return nil, fmt.Errorf("failed to run MCP invoke: %w", err)
}
Expand All @@ -335,7 +335,7 @@ func (e *Engine) runCommandTools(ctx Context, tool types.Tool, input string) (*R
} else if tool.IsCall() {
return e.runCall(ctx, tool, input)
}
s, err := e.runCommand(ctx, tool, input, ctx.ToolCategory)
s, err := e.runCommand(ctx, tool, input)
return &Return{
Result: &s,
}, err
Expand Down
14 changes: 11 additions & 3 deletions pkg/mcp/runner.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
package mcp

import (
"context"
"encoding/json"
"fmt"
"strings"

"github.com/gptscript-ai/gptscript/pkg/engine"
"github.com/gptscript-ai/gptscript/pkg/types"
"github.com/mark3labs/mcp-go/mcp"
)

func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) {
func (l *Local) Run(ctx engine.Context, _ chan<- types.CompletionStatus, tool types.Tool, input string) (string, error) {
fields := strings.Fields(tool.Instructions)
if len(fields) < 2 {
return "", fmt.Errorf("invalid mcp call, invalid number of fields in %s", tool.Instructions)
Expand Down Expand Up @@ -41,8 +41,16 @@ func (l *Local) Run(ctx context.Context, _ chan<- types.CompletionStatus, tool t
request.Params.Name = toolName
request.Params.Arguments = arguments

result, err := session.Client.CallTool(ctx, request)
result, err := session.Client.CallTool(ctx.Ctx, request)
if err != nil {
if ctx.ToolCategory == engine.NoCategory && ctx.Parent != nil {
var output []byte
if result != nil {
output, _ = json.Marshal(result)
}
// If this is a sub-call, then don't return the error; return the error as a message so that the LLM can retry.
return fmt.Sprintf("ERROR: got (%v) while running tool, OUTPUT: %s", err, string(output)), nil
}
return "", fmt.Errorf("failed to call tool %s: %w", toolName, err)
}

Expand Down