Skip to content

Commit a87330a

Browse files
feat: output filters
1 parent d6fc958 commit a87330a

19 files changed

+542
-49
lines changed

pkg/engine/engine.go

+1
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ const (
9999
CredentialToolCategory ToolCategory = "credential"
100100
ContextToolCategory ToolCategory = "context"
101101
InputToolCategory ToolCategory = "input"
102+
OutputToolCategory ToolCategory = "output"
102103
NoCategory ToolCategory = ""
103104
)
104105

pkg/parser/parser.go

+5
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ func isParam(line string, tool *types.Tool) (_ bool, err error) {
109109
tool.Parameters.InputFilters = append(tool.Parameters.InputFilters, csv(value)...)
110110
case "shareinputfilter", "shareinputfilters":
111111
tool.Parameters.ExportInputFilters = append(tool.Parameters.ExportInputFilters, csv(value)...)
112+
case "outputfilter", "outputfilters":
113+
tool.Parameters.OutputFilters = append(tool.Parameters.OutputFilters, csv(value)...)
114+
case "shareoutputfilter", "shareoutputfilters":
115+
tool.Parameters.ExportOutputFilters = append(tool.Parameters.ExportOutputFilters, csv(value)...)
112116
case "agent", "agents":
113117
tool.Parameters.Agents = append(tool.Parameters.Agents, csv(value)...)
114118
case "globaltool", "globaltools":
@@ -194,6 +198,7 @@ func (c *context) finish(tools *[]Node) {
194198
c.tool.GlobalModelName != "" ||
195199
len(c.tool.GlobalTools) > 0 ||
196200
len(c.tool.ExportInputFilters) > 0 ||
201+
len(c.tool.ExportOutputFilters) > 0 ||
197202
c.tool.Chat {
198203
*tools = append(*tools, Node{
199204
ToolNode: &ToolNode{

pkg/parser/parser_test.go

+24
Original file line numberDiff line numberDiff line change
@@ -215,3 +215,27 @@ share input filters: shared
215215
}},
216216
}}).Equal(t, out)
217217
}
218+
219+
func TestParseOutput(t *testing.T) {
220+
output := `
221+
output filters: output
222+
share output filters: shared
223+
`
224+
out, err := Parse(strings.NewReader(output))
225+
require.NoError(t, err)
226+
autogold.Expect(Document{Nodes: []Node{
227+
{ToolNode: &ToolNode{
228+
Tool: types.Tool{
229+
ToolDef: types.ToolDef{
230+
Parameters: types.Parameters{
231+
OutputFilters: []string{
232+
"output",
233+
},
234+
ExportOutputFilters: []string{"shared"},
235+
},
236+
},
237+
Source: types.ToolSource{LineNo: 1},
238+
},
239+
}},
240+
}}).Equal(t, out)
241+
}

pkg/runner/input.go

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package runner
22

33
import (
4+
"encoding/json"
45
"fmt"
56

67
"github.com/gptscript-ai/gptscript/pkg/engine"
@@ -13,7 +14,13 @@ func (r *Runner) handleInput(callCtx engine.Context, monitor Monitor, env []stri
1314
}
1415

1516
for _, inputToolRef := range inputToolRefs {
16-
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, input, "", engine.InputToolCategory)
17+
inputData, err := json.Marshal(map[string]any{
18+
"input": input,
19+
})
20+
if err != nil {
21+
return "", fmt.Errorf("failed to marshal input: %w", err)
22+
}
23+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, inputToolRef.ToolID, string(inputData), "", engine.InputToolCategory)
1724
if err != nil {
1825
return "", err
1926
}

pkg/runner/output.go

+72
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
package runner
2+
3+
import (
4+
"encoding/json"
5+
"errors"
6+
"fmt"
7+
8+
"github.com/gptscript-ai/gptscript/pkg/engine"
9+
)
10+
11+
func (r *Runner) handleOutput(callCtx engine.Context, monitor Monitor, env []string, state *State, retErr error) (*State, error) {
12+
outputToolRefs, err := callCtx.Tool.GetOutputFilterTools(*callCtx.Program)
13+
if err != nil {
14+
return nil, err
15+
}
16+
17+
if len(outputToolRefs) == 0 {
18+
return state, retErr
19+
}
20+
21+
var (
22+
continuation bool
23+
chatFinish bool
24+
output string
25+
)
26+
27+
if errMessage := (*engine.ErrChatFinish)(nil); errors.As(retErr, &errMessage) && callCtx.Tool.Chat {
28+
chatFinish = true
29+
output = errMessage.Message
30+
} else if retErr != nil {
31+
return state, retErr
32+
} else if state.Continuation != nil && state.Continuation.Result != nil {
33+
continuation = true
34+
output = *state.Continuation.Result
35+
} else if state.Result != nil {
36+
output = *state.Result
37+
} else {
38+
return state, nil
39+
}
40+
41+
for _, outputToolRef := range outputToolRefs {
42+
inputData, err := json.Marshal(map[string]any{
43+
"output": output,
44+
"chatFinish": chatFinish,
45+
"continuation": continuation,
46+
"chat": callCtx.Tool.Chat,
47+
})
48+
if err != nil {
49+
return nil, fmt.Errorf("marshaling input for output filter: %w", err)
50+
}
51+
res, err := r.subCall(callCtx.Ctx, callCtx, monitor, env, outputToolRef.ToolID, string(inputData), "", engine.OutputToolCategory)
52+
if err != nil {
53+
return nil, err
54+
}
55+
if res.Result == nil {
56+
return nil, fmt.Errorf("invalid state: output tool [%s] can not result in a chat continuation", outputToolRef.Reference)
57+
}
58+
output = *res.Result
59+
}
60+
61+
if chatFinish {
62+
return state, &engine.ErrChatFinish{
63+
Message: output,
64+
}
65+
} else if continuation {
66+
state.Continuation.Result = &output
67+
} else {
68+
state.Result = &output
69+
}
70+
71+
return state, nil
72+
}

pkg/runner/runner.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -536,7 +536,11 @@ type Needed struct {
536536
Input string `json:"input,omitempty"`
537537
}
538538

539-
func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (*State, error) {
539+
func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, state *State) (retState *State, retErr error) {
540+
defer func() {
541+
retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr)
542+
}()
543+
540544
if state.StartContinuation {
541545
return nil, fmt.Errorf("invalid state, resume should not have StartContinuation set to true")
542546
}

pkg/tests/runner_test.go

+46
Original file line numberDiff line numberDiff line change
@@ -849,6 +849,52 @@ func TestInput(t *testing.T) {
849849
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))
850850
}
851851

852+
func TestOutput(t *testing.T) {
853+
if runtime.GOOS == "windows" {
854+
t.Skip()
855+
}
856+
857+
r := tester.NewRunner(t)
858+
r.RespondWith(tester.Result{
859+
Text: "Response 1",
860+
})
861+
862+
prg, err := r.Load("")
863+
require.NoError(t, err)
864+
865+
resp, err := r.Chat(context.Background(), nil, prg, nil, "Input 1")
866+
require.NoError(t, err)
867+
r.AssertResponded(t)
868+
assert.False(t, resp.Done)
869+
autogold.Expect(`CHAT: true CONTENT: Response 1 CONTINUATION: true FINISH: false suffix
870+
`).Equal(t, resp.Content)
871+
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))
872+
873+
r.RespondWith(tester.Result{
874+
Text: "Response 2",
875+
})
876+
resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 2")
877+
require.NoError(t, err)
878+
r.AssertResponded(t)
879+
assert.False(t, resp.Done)
880+
autogold.Expect(`CHAT: true CONTENT: Response 2 CONTINUATION: true FINISH: false suffix
881+
`).Equal(t, resp.Content)
882+
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))
883+
884+
r.RespondWith(tester.Result{
885+
Err: &engine.ErrChatFinish{
886+
Message: "Chat Done",
887+
},
888+
})
889+
resp, err = r.Chat(context.Background(), resp.State, prg, nil, "Input 3")
890+
require.NoError(t, err)
891+
r.AssertResponded(t)
892+
assert.True(t, resp.Done)
893+
autogold.Expect(`CHAT FINISH: CHAT: true CONTENT: Chat Done CONTINUATION: false FINISH: true suffix
894+
`).Equal(t, resp.Content)
895+
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step3"))
896+
}
897+
852898
func TestSysContext(t *testing.T) {
853899
if runtime.GOOS == "windows" {
854900
t.Skip()

pkg/tests/testdata/TestInput/test.gpt

+4-2
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,10 @@ Tool body
77
---
88
name: taunt
99
args: foo: this is useless
10+
args: input: this is used
1011
#!/bin/bash
1112

12-
echo "No, ${GPTSCRIPT_INPUT}!"
13+
echo "No, ${INPUT}!"
1314

1415
---
1516
name: exporter
@@ -18,6 +19,7 @@ share input filters: taunt2
1819
---
1920
name: taunt2
2021
args: foo: this is useless
22+
args: input: this is used
2123

2224
#!/bin/bash
23-
echo "${GPTSCRIPT_INPUT} ha ha ha"
25+
echo "${INPUT} ha ha ha"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
`{
2+
"role": "assistant",
3+
"content": [
4+
{
5+
"text": "Response 1"
6+
}
7+
],
8+
"usage": {}
9+
}`
+24
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
`{
2+
"model": "gpt-4o",
3+
"internalSystemPrompt": false,
4+
"messages": [
5+
{
6+
"role": "system",
7+
"content": [
8+
{
9+
"text": "\nTool body"
10+
}
11+
],
12+
"usage": {}
13+
},
14+
{
15+
"role": "user",
16+
"content": [
17+
{
18+
"text": "Input 1"
19+
}
20+
],
21+
"usage": {}
22+
}
23+
]
24+
}`
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
`{
2+
"role": "assistant",
3+
"content": [
4+
{
5+
"text": "Response 2"
6+
}
7+
],
8+
"usage": {}
9+
}`
+42
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
`{
2+
"model": "gpt-4o",
3+
"internalSystemPrompt": false,
4+
"messages": [
5+
{
6+
"role": "system",
7+
"content": [
8+
{
9+
"text": "\nTool body"
10+
}
11+
],
12+
"usage": {}
13+
},
14+
{
15+
"role": "user",
16+
"content": [
17+
{
18+
"text": "Input 1"
19+
}
20+
],
21+
"usage": {}
22+
},
23+
{
24+
"role": "assistant",
25+
"content": [
26+
{
27+
"text": "Response 1"
28+
}
29+
],
30+
"usage": {}
31+
},
32+
{
33+
"role": "user",
34+
"content": [
35+
{
36+
"text": "Input 2"
37+
}
38+
],
39+
"usage": {}
40+
}
41+
]
42+
}`
+60
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,60 @@
1+
`{
2+
"model": "gpt-4o",
3+
"internalSystemPrompt": false,
4+
"messages": [
5+
{
6+
"role": "system",
7+
"content": [
8+
{
9+
"text": "\nTool body"
10+
}
11+
],
12+
"usage": {}
13+
},
14+
{
15+
"role": "user",
16+
"content": [
17+
{
18+
"text": "Input 1"
19+
}
20+
],
21+
"usage": {}
22+
},
23+
{
24+
"role": "assistant",
25+
"content": [
26+
{
27+
"text": "Response 1"
28+
}
29+
],
30+
"usage": {}
31+
},
32+
{
33+
"role": "user",
34+
"content": [
35+
{
36+
"text": "Input 2"
37+
}
38+
],
39+
"usage": {}
40+
},
41+
{
42+
"role": "assistant",
43+
"content": [
44+
{
45+
"text": "Response 2"
46+
}
47+
],
48+
"usage": {}
49+
},
50+
{
51+
"role": "user",
52+
"content": [
53+
{
54+
"text": "Input 3"
55+
}
56+
],
57+
"usage": {}
58+
}
59+
]
60+
}`

0 commit comments

Comments
 (0)