Skip to content

Commit 5d2505a

Browse files
Merge pull request #825 from ibuildthecloud/context-with
chore: add with * syntax to context tools
2 parents 61592f1 + 93e7706 commit 5d2505a

25 files changed

+111
-913
lines changed

pkg/repos/runtimes/node/node.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ var releasesData []byte
2626
const (
2727
downloadURL = "https://nodejs.org/dist/%s/"
2828
packageJSON = "package.json"
29+
nodeModules = "node_modules"
2930
)
3031

3132
type Runtime struct {
@@ -64,8 +65,15 @@ func (r *Runtime) supports(testCmd string, cmd []string) bool {
6465

6566
func (r *Runtime) GetHash(tool types.Tool) (string, error) {
6667
if !tool.Source.IsGit() && tool.WorkingDir != "" {
68+
var prefix string
69+
// This hashes if the node_modules directory was deleted
70+
if s, err := os.Stat(filepath.Join(tool.WorkingDir, nodeModules)); err == nil {
71+
prefix = hash.Digest(tool.WorkingDir + s.ModTime().String())[:7]
72+
} else if s, err := os.Stat(tool.WorkingDir); err == nil {
73+
prefix = hash.Digest(tool.WorkingDir + s.ModTime().String())[:7]
74+
}
6775
if s, err := os.Stat(filepath.Join(tool.WorkingDir, packageJSON)); err == nil {
68-
return hash.Digest(tool.WorkingDir + s.ModTime().String())[:7], nil
76+
return prefix + hash.Digest(tool.WorkingDir + s.ModTime().String())[:7], nil
6977
}
7078
}
7179
return "", nil

pkg/runner/runner.go

Lines changed: 23 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -172,11 +172,7 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
172172
return resp, err
173173
}
174174

175-
if state == nil || state.StartContinuation {
176-
if state != nil {
177-
state = state.WithResumeInput(&input)
178-
input = state.InputContextContinuationInput
179-
}
175+
if state == nil {
180176
state, err = r.start(callCtx, state, monitor, env, input)
181177
if err != nil {
182178
return resp, err
@@ -186,11 +182,9 @@ func (r *Runner) Chat(ctx context.Context, prevState ChatState, prg types.Progra
186182
state.ResumeInput = &input
187183
}
188184

189-
if !state.StartContinuation {
190-
state, err = r.resume(callCtx, monitor, env, state)
191-
if err != nil {
192-
return resp, err
193-
}
185+
state, err = r.resume(callCtx, monitor, env, state)
186+
if err != nil {
187+
return resp, err
194188
}
195189

196190
if state.Result != nil {
@@ -260,6 +254,10 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
260254
targetArgs := prg.ToolSet[ref.ToolID].Arguments
261255
targetKeys := map[string]string{}
262256

257+
if ref.Arg == "*" {
258+
return input, nil
259+
}
260+
263261
if targetArgs == nil {
264262
return "", nil
265263
}
@@ -331,24 +329,10 @@ func getToolRefInput(prg *types.Program, ref types.ToolReference, input string)
331329
return string(output), err
332330
}
333331

334-
func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ *State, _ error) {
332+
func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monitor, env []string, input string) (result []engine.InputContext, _ error) {
335333
toolRefs, err := callCtx.Tool.GetContextTools(*callCtx.Program)
336334
if err != nil {
337-
return nil, nil, err
338-
}
339-
340-
var newState *State
341-
if state != nil {
342-
cp := *state
343-
newState = &cp
344-
if newState.InputContextContinuation != nil {
345-
newState.InputContexts = nil
346-
newState.InputContextContinuation = nil
347-
newState.InputContextContinuationInput = ""
348-
newState.ResumeInput = state.InputContextContinuationResumeInput
349-
350-
input = state.InputContextContinuationInput
351-
}
335+
return nil, err
352336
}
353337

354338
for i, toolRef := range toolRefs {
@@ -359,47 +343,31 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
359343

360344
contextInput, err := getToolRefInput(callCtx.Program, toolRef, input)
361345
if err != nil {
362-
return nil, nil, err
346+
return nil, err
363347
}
364348

365349
var content *State
366-
if state != nil && state.InputContextContinuation != nil {
367-
content, err = r.subCallResume(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, "", state.InputContextContinuation.WithResumeInput(state.ResumeInput), engine.ContextToolCategory)
368-
} else {
369-
content, err = r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
370-
}
350+
content, err = r.subCall(callCtx.Ctx, callCtx, monitor, env, toolRef.ToolID, contextInput, "", engine.ContextToolCategory)
371351
if err != nil {
372-
return nil, nil, err
352+
return nil, err
373353
}
374354
if content.Continuation != nil {
375-
if newState == nil {
376-
newState = &State{}
377-
}
378-
newState.InputContexts = result
379-
newState.InputContextContinuation = content
380-
newState.InputContextContinuationInput = input
381-
if state != nil {
382-
newState.InputContextContinuationResumeInput = state.ResumeInput
383-
}
384-
return nil, newState, nil
355+
return nil, fmt.Errorf("invalid state: context tool [%s] can not result in a continuation", toolRef.ToolID)
385356
}
386357
result = append(result, engine.InputContext{
387358
ToolID: toolRef.ToolID,
388359
Content: *content.Result,
389360
})
390361
}
391362

392-
return result, newState, nil
363+
return result, nil
393364
}
394365

395366
func (r *Runner) call(callCtx engine.Context, monitor Monitor, env []string, input string) (*State, error) {
396367
result, err := r.start(callCtx, nil, monitor, env, input)
397368
if err != nil {
398369
return nil, err
399370
}
400-
if result.StartContinuation {
401-
return result, nil
402-
}
403371
return r.resume(callCtx, monitor, env, result)
404372
}
405373

@@ -431,15 +399,10 @@ func (r *Runner) start(callCtx engine.Context, state *State, monitor Monitor, en
431399
}
432400
}
433401

434-
var newState *State
435-
callCtx.InputContext, newState, err = r.getContext(callCtx, state, monitor, env, input)
402+
callCtx.InputContext, err = r.getContext(callCtx, state, monitor, env, input)
436403
if err != nil {
437404
return nil, err
438405
}
439-
if newState != nil && newState.InputContextContinuation != nil {
440-
newState.StartContinuation = true
441-
return newState, nil
442-
}
443406

444407
e := engine.Engine{
445408
Model: r.c,
@@ -489,11 +452,7 @@ type State struct {
489452
SubCalls []SubCallResult `json:"subCalls,omitempty"`
490453
SubCallID string `json:"subCallID,omitempty"`
491454

492-
InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
493-
InputContextContinuation *State `json:"inputContextContinuation,omitempty"`
494-
InputContextContinuationInput string `json:"inputContextContinuationInput,omitempty"`
495-
InputContextContinuationResumeInput *string `json:"inputContextContinuationResumeInput,omitempty"`
496-
StartContinuation bool `json:"startContinuation,omitempty"`
455+
InputContexts []engine.InputContext `json:"inputContexts,omitempty"`
497456
}
498457

499458
func (s State) WithResumeInput(input *string) *State {
@@ -506,10 +465,6 @@ func (s State) ContinuationContentToolID() (string, error) {
506465
return s.ContinuationToolID, nil
507466
}
508467

509-
if s.InputContextContinuation != nil {
510-
return s.InputContextContinuation.ContinuationContentToolID()
511-
}
512-
513468
for _, subCall := range s.SubCalls {
514469
if s.SubCallID == subCall.CallID {
515470
return subCall.State.ContinuationContentToolID()
@@ -523,10 +478,6 @@ func (s State) ContinuationContent() (string, error) {
523478
return *s.Continuation.Result, nil
524479
}
525480

526-
if s.InputContextContinuation != nil {
527-
return s.InputContextContinuation.ContinuationContent()
528-
}
529-
530481
for _, subCall := range s.SubCalls {
531482
if s.SubCallID == subCall.CallID {
532483
return subCall.State.ContinuationContent()
@@ -545,10 +496,6 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
545496
retState, retErr = r.handleOutput(callCtx, monitor, env, retState, retErr)
546497
}()
547498

548-
if state.StartContinuation {
549-
return nil, fmt.Errorf("invalid state, resume should not have StartContinuation set to true")
550-
}
551-
552499
if state.Continuation == nil {
553500
return nil, errors.New("invalid state, resume should have Continuation data")
554501
}
@@ -653,8 +600,12 @@ func (r *Runner) resume(callCtx engine.Context, monitor Monitor, env []string, s
653600
contentInput = state.Continuation.State.Input
654601
}
655602

656-
callCtx.InputContext, state, err = r.getContext(callCtx, state, monitor, env, contentInput)
657-
if err != nil || state.InputContextContinuation != nil {
603+
if state.ResumeInput != nil {
604+
contentInput = *state.ResumeInput
605+
}
606+
607+
callCtx.InputContext, err = r.getContext(callCtx, state, monitor, env, contentInput)
608+
if err != nil {
658609
return state, err
659610
}
660611

@@ -764,10 +715,6 @@ func (r *Runner) subCalls(callCtx engine.Context, monitor Monitor, env []string,
764715
callCtx.LastReturn = state.Continuation
765716
}
766717

767-
if state.InputContextContinuation != nil {
768-
return state, nil, nil
769-
}
770-
771718
if state.SubCallID != "" {
772719
if state.ResumeInput == nil {
773720
return nil, nil, fmt.Errorf("invalid state, input must be set for sub call continuation on callID [%s]", state.SubCallID)

pkg/tests/runner2_test.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package tests
2+
3+
import (
4+
"context"
5+
"testing"
6+
7+
"github.com/gptscript-ai/gptscript/pkg/loader"
8+
"github.com/gptscript-ai/gptscript/pkg/tests/tester"
9+
"github.com/stretchr/testify/require"
10+
)
11+
12+
func TestContextWithAsterick(t *testing.T) {
13+
r := tester.NewRunner(t)
14+
prg, err := loader.ProgramFromSource(context.Background(), `
15+
chat: true
16+
context: foo with *
17+
18+
Say hi
19+
20+
---
21+
name: foo
22+
23+
#!/bin/bash
24+
25+
echo This is the input: ${GPTSCRIPT_INPUT}
26+
`, "")
27+
require.NoError(t, err)
28+
29+
resp, err := r.Chat(context.Background(), nil, prg, nil, "input 1")
30+
r.AssertStep(t, resp, err)
31+
32+
resp, err = r.Chat(context.Background(), resp.State, prg, nil, "input 2")
33+
r.AssertStep(t, resp, err)
34+
}

pkg/tests/runner_test.go

Lines changed: 2 additions & 76 deletions
Original file line numberDiff line numberDiff line change
@@ -213,82 +213,8 @@ func TestContextSubChat(t *testing.T) {
213213
prg, err := r.Load("")
214214
require.NoError(t, err)
215215

216-
resp, err := r.Chat(context.Background(), nil, prg, os.Environ(), "User 1")
217-
require.NoError(t, err)
218-
r.AssertResponded(t)
219-
assert.False(t, resp.Done)
220-
autogold.Expect("Assistant Response 1 - from chatbot1").Equal(t, resp.Content)
221-
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step1"))
222-
223-
r.RespondWith(tester.Result{
224-
Content: []types.ContentPart{
225-
{
226-
ToolCall: &types.CompletionToolCall{
227-
ID: "call_2",
228-
Function: types.CompletionFunctionCall{
229-
Name: types.ToolNormalizer("sys.chat.finish"),
230-
Arguments: "Response from context chatbot",
231-
},
232-
},
233-
},
234-
},
235-
}, tester.Result{
236-
Text: "Assistant Response 2 - from context tool",
237-
}, tester.Result{
238-
Text: "Assistant Response 3 - from main chat tool",
239-
})
240-
resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 2")
241-
require.NoError(t, err)
242-
r.AssertResponded(t)
243-
assert.False(t, resp.Done)
244-
autogold.Expect("Assistant Response 3 - from main chat tool").Equal(t, resp.Content)
245-
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step2"))
246-
247-
r.RespondWith(tester.Result{
248-
Content: []types.ContentPart{
249-
{
250-
ToolCall: &types.CompletionToolCall{
251-
ID: "call_3",
252-
Function: types.CompletionFunctionCall{
253-
Name: "chatbot",
254-
Arguments: "Input to chatbot1 on resume",
255-
},
256-
},
257-
},
258-
},
259-
}, tester.Result{
260-
Text: "Assistant Response 4 - from chatbot1",
261-
})
262-
resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 3")
263-
require.NoError(t, err)
264-
r.AssertResponded(t)
265-
assert.False(t, resp.Done)
266-
autogold.Expect("Assistant Response 3 - from main chat tool").Equal(t, resp.Content)
267-
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step3"))
268-
269-
r.RespondWith(tester.Result{
270-
Content: []types.ContentPart{
271-
{
272-
ToolCall: &types.CompletionToolCall{
273-
ID: "call_4",
274-
Function: types.CompletionFunctionCall{
275-
Name: types.ToolNormalizer("sys.chat.finish"),
276-
Arguments: "Response from context chatbot after resume",
277-
},
278-
},
279-
},
280-
},
281-
}, tester.Result{
282-
Text: "Assistant Response 5 - from context tool resume",
283-
}, tester.Result{
284-
Text: "Assistant Response 6 - from main chat tool resume",
285-
})
286-
resp, err = r.Chat(context.Background(), resp.State, prg, os.Environ(), "User 4")
287-
require.NoError(t, err)
288-
r.AssertResponded(t)
289-
assert.False(t, resp.Done)
290-
autogold.Expect("Assistant Response 6 - from main chat tool resume").Equal(t, resp.Content)
291-
autogold.ExpectFile(t, toJSONString(t, resp), autogold.Name(t.Name()+"/step4"))
216+
_, err = r.Chat(context.Background(), nil, prg, os.Environ(), "User 1")
217+
autogold.Expect("invalid state: context tool [testdata/TestContextSubChat/test.gpt:subtool] can not result in a continuation").Equal(t, err.Error())
292218
}
293219

294220
func TestSubChat(t *testing.T) {

pkg/tests/testdata/TestContextSubChat/call10-resp.golden

Lines changed: 0 additions & 9 deletions
This file was deleted.

pkg/tests/testdata/TestContextSubChat/call3-resp.golden

Lines changed: 0 additions & 16 deletions
This file was deleted.

0 commit comments

Comments
 (0)