Skip to content

enhance: Add proper aborting of runs #94

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Feb 25, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ func (g *GPTScript) Run(ctx context.Context, toolPath string, opts Options) (*Ru
}).NextChat(ctx, opts.Input)
}

func (g *GPTScript) AbortRun(ctx context.Context, run *Run) error {
_, err := g.runBasicCommand(ctx, "abort/"+run.id, (map[string]any)(nil))
return err
}

type ParseOptions struct {
DisableCache bool
}
Expand Down
139 changes: 137 additions & 2 deletions gptscript_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"strconv"
"strings"
"testing"
"time"

"github.com/getkin/kin-openapi/openapi3"
"github.com/stretchr/testify/require"
Expand Down Expand Up @@ -134,7 +135,7 @@ func TestListModelsWithDefaultProvider(t *testing.T) {
}
}

func TestAbortRun(t *testing.T) {
func TestCancelRun(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
Expand All @@ -146,7 +147,7 @@ func TestAbortRun(t *testing.T) {
<-run.Events()

if err := run.Close(); err != nil {
t.Errorf("Error aborting run: %v", err)
t.Errorf("Error canceling run: %v", err)
}

if run.State() != Error {
Expand All @@ -158,6 +159,77 @@ func TestAbortRun(t *testing.T) {
}
}

func TestAbortChatCompletionRun(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
if err != nil {
t.Errorf("Error executing tool: %v", err)
}

// Abort the run after the first event from the LLM
for e := range run.Events() {
if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." {
break
}
}

if err := g.AbortRun(context.Background(), run); err != nil {
t.Errorf("Error aborting run: %v", err)
}

// Wait for run to stop
for range run.Events() {
continue
}

if run.State() != Finished {
t.Errorf("Unexpected run state: %s", run.State())
}

if out, err := run.Text(); err != nil {
t.Errorf("Error reading output: %v", err)
} else if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") {
t.Errorf("Unexpected output: %s", out)
}
}

func TestAbortCommandRun(t *testing.T) {
tool := ToolDef{Instructions: "#!/usr/bin/env bash\necho Hello, world!\nsleep 5\necho Hello, again!\nsleep 5"}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
if err != nil {
t.Errorf("Error executing tool: %v", err)
}

// Abort the run after the first event.
for e := range run.Events() {
if e.Call != nil && e.Call.Type == EventTypeChat {
time.Sleep(2 * time.Second)
break
}
}

if err := g.AbortRun(context.Background(), run); err != nil {
t.Errorf("Error aborting run: %v", err)
}

// Wait for run to stop
for range run.Events() {
continue
}

if run.State() != Finished {
t.Errorf("Unexpected run state: %s", run.State())
}

if out, err := run.Text(); err != nil {
t.Errorf("Error reading output: %v", err)
} else if !strings.Contains(out, "Hello, world!") || strings.Contains(out, "Hello, again!") || !strings.HasSuffix(out, "\nABORTED BY USER") {
t.Errorf("Unexpected output: %s", out)
}
}

func TestSimpleEvaluate(t *testing.T) {
tool := ToolDef{Instructions: "What is the capital of the united states?"}

Expand Down Expand Up @@ -844,6 +916,69 @@ func TestToolChat(t *testing.T) {
}
}

func TestAbortChat(t *testing.T) {
tool := ToolDef{
Chat: true,
Instructions: "You are a chat bot. Don't finish the conversation until I say 'bye'.",
Tools: []string{"sys.chat.finish"},
}

run, err := g.Evaluate(context.Background(), Options{DisableCache: true, IncludeEvents: true}, tool)
if err != nil {
t.Fatalf("Error executing tool: %v", err)
}
inputs := []string{
"Tell me a joke.",
"What was my first message?",
}

// Just wait for the chat to start up.
for range run.Events() {
continue
}

for i, input := range inputs {
run, err = run.NextChat(context.Background(), input)
if err != nil {
t.Fatalf("Error sending next input %q: %v", input, err)
}

// Abort the run after the first event from the LLM
for e := range run.Events() {
if e.Call != nil && e.Call.Type == EventTypeCallProgress && len(e.Call.Output) > 0 && e.Call.Output[0].Content != "Waiting for model response..." {
break
}
}

if i == 0 {
if err := g.AbortRun(context.Background(), run); err != nil {
t.Fatalf("Error aborting run: %v", err)
}
}

// Wait for the run to complete
for range run.Events() {
continue
}

out, err := run.Text()
if err != nil {
t.Errorf("Error reading output: %s", run.ErrorOutput())
t.Fatalf("Error reading output: %v", err)
}

if i == 0 {
if strings.TrimSpace(out) != "ABORTED BY USER" && !strings.HasSuffix(out, "\nABORTED BY USER") {
t.Fatalf("Unexpected output: %s", out)
}
} else {
if !strings.Contains(out, "Tell me a joke") {
t.Errorf("Unexpected output: %s", out)
}
}
}
}

func TestFileChat(t *testing.T) {
wd, err := os.Getwd()
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions run.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ type Run struct {
basicCommand bool

program *Program
id string
callsLock sync.RWMutex
calls CallFrames
rawOutput map[string]any
Expand Down Expand Up @@ -400,6 +401,7 @@ func (r *Run) request(ctx context.Context, payload any) (err error) {
if event.Run.Type == EventTypeRunStart {
r.callsLock.Lock()
r.program = &event.Run.Program
r.id = event.Run.ID
r.callsLock.Unlock()
} else if event.Run.Type == EventTypeRunFinish && event.Run.Error != "" {
r.state = Error
Expand Down