Skip to content

Commit 9845f06

Browse files
Merge pull request #214 from ibuildthecloud/completiontext
chore: add program.GetCompletionTools()
2 parents fede61d + 0e4fe10 commit 9845f06

File tree

7 files changed

+106
-84
lines changed

7 files changed

+106
-84
lines changed

pkg/engine/engine.go

+9-78
Original file line numberDiff line numberDiff line change
@@ -7,21 +7,12 @@ import (
77
"sync"
88
"sync/atomic"
99

10-
"github.com/gptscript-ai/gptscript/pkg/system"
1110
"github.com/gptscript-ai/gptscript/pkg/types"
1211
"github.com/gptscript-ai/gptscript/pkg/version"
1312
)
1413

1514
var completionID int64
1615

17-
type ErrToolNotFound struct {
18-
ToolName string
19-
}
20-
21-
func (e *ErrToolNotFound) Error() string {
22-
return fmt.Sprintf("tool not found: %s", e.ToolName)
23-
}
24-
2516
type Model interface {
2617
Call(ctx context.Context, messageRequest types.CompletionRequest, status chan<- types.CompletionStatus) (*types.CompletionMessage, error)
2718
}
@@ -62,12 +53,11 @@ type CallResult struct {
6253
}
6354

6455
type Context struct {
65-
ID string
66-
Ctx context.Context
67-
Parent *Context
68-
Program *types.Program
69-
Tool types.Tool
70-
toolNames map[string]struct{}
56+
ID string
57+
Ctx context.Context
58+
Parent *Context
59+
Program *types.Program
60+
Tool types.Tool
7161
}
7262

7363
func (c *Context) ParentID() string {
@@ -119,65 +109,6 @@ func (c *Context) SubCall(ctx context.Context, toolID, callID string) (Context,
119109
}, nil
120110
}
121111

122-
func (c *Context) getTool(parent types.Tool, name string) (types.Tool, error) {
123-
toolID, ok := parent.ToolMapping[name]
124-
if !ok {
125-
return types.Tool{}, &ErrToolNotFound{
126-
ToolName: name,
127-
}
128-
}
129-
tool, ok := c.Program.ToolSet[toolID]
130-
if !ok {
131-
return types.Tool{}, &ErrToolNotFound{
132-
ToolName: name,
133-
}
134-
}
135-
return tool, nil
136-
}
137-
138-
func (c *Context) appendTool(completion *types.CompletionRequest, parentTool types.Tool, subToolName string) error {
139-
subTool, err := c.getTool(parentTool, subToolName)
140-
if err != nil {
141-
return err
142-
}
143-
144-
args := subTool.Parameters.Arguments
145-
if args == nil && !subTool.IsCommand() {
146-
args = &system.DefaultToolSchema
147-
}
148-
149-
for _, existingTool := range completion.Tools {
150-
if existingTool.Function.ToolID == subTool.ID {
151-
return nil
152-
}
153-
}
154-
155-
if c.toolNames == nil {
156-
c.toolNames = map[string]struct{}{}
157-
}
158-
159-
if subTool.Instructions == "" {
160-
log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID)
161-
} else {
162-
completion.Tools = append(completion.Tools, types.CompletionTool{
163-
Function: types.CompletionFunctionDefinition{
164-
ToolID: subTool.ID,
165-
Name: PickToolName(subToolName, c.toolNames),
166-
Description: subTool.Parameters.Description,
167-
Parameters: args,
168-
},
169-
})
170-
}
171-
172-
for _, export := range subTool.Export {
173-
if err := c.appendTool(completion, subTool, export); err != nil {
174-
return err
175-
}
176-
}
177-
178-
return nil
179-
}
180-
181112
func (e *Engine) Start(ctx Context, input string) (*Return, error) {
182113
tool := ctx.Tool
183114

@@ -207,10 +138,10 @@ func (e *Engine) Start(ctx Context, input string) (*Return, error) {
207138
InternalSystemPrompt: tool.Parameters.InternalPrompt,
208139
}
209140

210-
for _, subToolName := range tool.Parameters.Tools {
211-
if err := ctx.appendTool(&completion, ctx.Tool, subToolName); err != nil {
212-
return nil, err
213-
}
141+
var err error
142+
completion.Tools, err = tool.GetCompletionTools(*ctx.Program)
143+
if err != nil {
144+
return nil, err
214145
}
215146

216147
if tool.Instructions != "" {

pkg/loader/loader.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@ import (
1717
"github.com/getkin/kin-openapi/openapi3"
1818
"github.com/gptscript-ai/gptscript/pkg/assemble"
1919
"github.com/gptscript-ai/gptscript/pkg/builtin"
20-
"github.com/gptscript-ai/gptscript/pkg/engine"
2120
"github.com/gptscript-ai/gptscript/pkg/parser"
2221
"github.com/gptscript-ai/gptscript/pkg/types"
2322
"gopkg.in/yaml.v3"
@@ -109,7 +108,7 @@ func loadProgram(data []byte, into *types.Program, targetToolName string) (types
109108

110109
tool, ok := into.ToolSet[tool.LocalTools[targetToolName]]
111110
if !ok {
112-
return tool, &engine.ErrToolNotFound{
111+
return tool, &types.ErrToolNotFound{
113112
ToolName: targetToolName,
114113
}
115114
}

pkg/tests/runner_test.go

+1-2
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@ package tests
33
import (
44
"testing"
55

6-
"github.com/gptscript-ai/gptscript/pkg/engine"
76
"github.com/gptscript-ai/gptscript/pkg/tests/tester"
87
"github.com/gptscript-ai/gptscript/pkg/types"
98
"github.com/stretchr/testify/assert"
@@ -15,7 +14,7 @@ func TestCwd(t *testing.T) {
1514

1615
runner.RespondWith(tester.Result{
1716
Func: types.CompletionFunctionCall{
18-
Name: engine.ToolNormalizer("./subtool/test.gpt"),
17+
Name: types.ToolNormalizer("./subtool/test.gpt"),
1918
},
2019
})
2120
runner.RespondWith(tester.Result{

pkg/types/completion.go

-1
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ type CompletionFunctionDefinition struct {
2828
ToolID string `json:"toolID,omitempty"`
2929
Name string `json:"name"`
3030
Description string `json:"description,omitempty"`
31-
Domain string `json:"domain,omitempty"`
3231
Parameters *openapi3.Schema `json:"parameters"`
3332
}
3433

pkg/types/log.go

+5
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
package types
2+
3+
import "github.com/gptscript-ai/gptscript/pkg/mvl"
4+
5+
var log = mvl.Package()

pkg/types/tool.go

+89
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ import (
77
"strings"
88

99
"github.com/getkin/kin-openapi/openapi3"
10+
"github.com/gptscript-ai/gptscript/pkg/system"
1011
"golang.org/x/exp/maps"
1112
)
1213

@@ -16,6 +17,14 @@ const (
1617
CommandPrefix = "#!"
1718
)
1819

20+
type ErrToolNotFound struct {
21+
ToolName string
22+
}
23+
24+
func (e *ErrToolNotFound) Error() string {
25+
return fmt.Sprintf("tool not found: %s", e.ToolName)
26+
}
27+
1928
type ToolSet map[string]Tool
2029

2130
type Program struct {
@@ -24,6 +33,17 @@ type Program struct {
2433
ToolSet ToolSet `json:"toolSet,omitempty"`
2534
}
2635

36+
func (p Program) GetCompletionTools() (result []CompletionTool, err error) {
37+
return Tool{
38+
Parameters: Parameters{
39+
Tools: []string{"main"},
40+
},
41+
ToolMapping: map[string]string{
42+
"main": p.EntryToolID,
43+
},
44+
}.GetCompletionTools(p)
45+
}
46+
2747
func (p Program) TopLevelTools() (result []Tool) {
2848
for _, tool := range p.ToolSet[p.EntryToolID].LocalTools {
2949
result = append(result, p.ToolSet[tool])
@@ -124,6 +144,75 @@ func (t Tool) String() string {
124144
return buf.String()
125145
}
126146

147+
func (t Tool) GetCompletionTools(prg Program) (result []CompletionTool, err error) {
148+
toolNames := map[string]struct{}{}
149+
150+
for _, subToolName := range t.Parameters.Tools {
151+
result, err = appendTool(result, prg, t, subToolName, toolNames)
152+
if err != nil {
153+
return nil, err
154+
}
155+
}
156+
157+
return result, nil
158+
}
159+
160+
func getTool(prg Program, parent Tool, name string) (Tool, error) {
161+
toolID, ok := parent.ToolMapping[name]
162+
if !ok {
163+
return Tool{}, &ErrToolNotFound{
164+
ToolName: name,
165+
}
166+
}
167+
tool, ok := prg.ToolSet[toolID]
168+
if !ok {
169+
return Tool{}, &ErrToolNotFound{
170+
ToolName: name,
171+
}
172+
}
173+
return tool, nil
174+
}
175+
176+
func appendTool(completionTools []CompletionTool, prg Program, parentTool Tool, subToolName string, toolNames map[string]struct{}) ([]CompletionTool, error) {
177+
subTool, err := getTool(prg, parentTool, subToolName)
178+
if err != nil {
179+
return nil, err
180+
}
181+
182+
args := subTool.Parameters.Arguments
183+
if args == nil && !subTool.IsCommand() {
184+
args = &system.DefaultToolSchema
185+
}
186+
187+
for _, existingTool := range completionTools {
188+
if existingTool.Function.ToolID == subTool.ID {
189+
return completionTools, nil
190+
}
191+
}
192+
193+
if subTool.Instructions == "" {
194+
log.Debugf("Skipping zero instruction tool %s (%s)", subToolName, subTool.ID)
195+
} else {
196+
completionTools = append(completionTools, CompletionTool{
197+
Function: CompletionFunctionDefinition{
198+
ToolID: subTool.ID,
199+
Name: PickToolName(subToolName, toolNames),
200+
Description: subTool.Parameters.Description,
201+
Parameters: args,
202+
},
203+
})
204+
}
205+
206+
for _, export := range subTool.Export {
207+
completionTools, err = appendTool(completionTools, prg, subTool, export, toolNames)
208+
if err != nil {
209+
return nil, err
210+
}
211+
}
212+
213+
return completionTools, nil
214+
}
215+
127216
type Repo struct {
128217
// VCS The VCS type, such as "git"
129218
VCS string

pkg/engine/toolname.go renamed to pkg/types/toolname.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
package engine
1+
package types
22

33
import (
44
"path/filepath"

0 commit comments

Comments
 (0)