Skip to content

Commit 31d4f8d

Browse files
chore: assign provider tool category when launch provider
1 parent 3bc5a6d commit 31d4f8d

File tree

4 files changed

+16
-5
lines changed

4 files changed

+16
-5
lines changed

pkg/engine/engine.go

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,7 @@ type Context struct {
7777
type ToolCategory string
7878

7979
const (
80+
ProviderToolCategory ToolCategory = "provider"
8081
CredentialToolCategory ToolCategory = "credential"
8182
ContextToolCategory ToolCategory = "context"
8283
NoCategory ToolCategory = ""
@@ -120,11 +121,20 @@ func (c *Context) MarshalJSON() ([]byte, error) {
120121
return json.Marshal(c.GetCallContext())
121122
}
122123

124+
type toolCategoryKey struct{}
125+
126+
func WithToolCategory(ctx context.Context, toolCategory ToolCategory) context.Context {
127+
return context.WithValue(ctx, toolCategoryKey{}, toolCategory)
128+
}
129+
123130
func NewContext(ctx context.Context, prg *types.Program) Context {
131+
category, _ := ctx.Value(toolCategoryKey{}).(ToolCategory)
132+
124133
callCtx := Context{
125134
commonContext: commonContext{
126-
ID: counter.Next(),
127-
Tool: prg.ToolSet[prg.EntryToolID],
135+
ID: counter.Next(),
136+
Tool: prg.ToolSet[prg.EntryToolID],
137+
ToolCategory: category,
128138
},
129139
Ctx: ctx,
130140
Program: prg,

pkg/remote/remote.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sync"
1111

1212
"github.com/gptscript-ai/gptscript/pkg/cache"
13+
"github.com/gptscript-ai/gptscript/pkg/engine"
1314
env2 "github.com/gptscript-ai/gptscript/pkg/env"
1415
"github.com/gptscript-ai/gptscript/pkg/loader"
1516
"github.com/gptscript-ai/gptscript/pkg/mvl"
@@ -144,7 +145,7 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
144145
return nil, err
145146
}
146147

147-
url, err := c.runner.Run(ctx, prg.SetBlocking(), c.envs, "")
148+
url, err := c.runner.Run(engine.WithToolCategory(ctx, engine.ProviderToolCategory), prg.SetBlocking(), c.envs, "")
148149
if err != nil {
149150
return nil, err
150151
}

pkg/runner/runner.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -527,7 +527,7 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
527527
err error
528528
)
529529

530-
state, callResults, err = r.subCalls(callCtx, monitor, env, state, engine.NoCategory)
530+
state, callResults, err = r.subCalls(callCtx, monitor, env, state, callCtx.ToolCategory)
531531
if errMessage := (*builtin.ErrChatFinish)(nil); errors.As(err, &errMessage) && callCtx.Tool.Chat {
532532
return &State{
533533
Result: &errMessage.Message,

pkg/types/completion.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ func (in CompletionMessage) String() string {
9898
}
9999
buf.WriteString(content.Text)
100100
if content.ToolCall != nil {
101-
buf.WriteString(fmt.Sprintf("tool call %s -> %s", color.GreenString(content.ToolCall.Function.Name), content.ToolCall.Function.Arguments))
101+
buf.WriteString(fmt.Sprintf("<tool call> %s -> %s", color.GreenString(content.ToolCall.Function.Name), content.ToolCall.Function.Arguments))
102102
}
103103
}
104104
return buf.String()

0 commit comments

Comments
 (0)