Skip to content

feat: add chat #277

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,14 @@ require (
github.com/acorn-io/broadcaster v0.0.0-20240105011354-bfadd4a7b45d
github.com/acorn-io/cmd v0.0.0-20240404013709-34f690bde37b
github.com/adrg/xdg v0.4.0
github.com/chzyer/readline v1.5.1
github.com/docker/cli v26.0.0+incompatible
github.com/docker/docker-credential-helpers v0.8.1
github.com/fatih/color v1.16.0
github.com/getkin/kin-openapi v0.123.0
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
github.com/gptscript-ai/chat-completion-client v0.0.0-20240404013040-49eb8f6affa1
github.com/hexops/autogold/v2 v2.1.0
github.com/hexops/autogold/v2 v2.2.1
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056
github.com/mholt/archiver/v4 v4.0.0-alpha.8
github.com/olahol/melody v1.1.4
Expand All @@ -26,8 +27,8 @@ require (
github.com/stretchr/testify v1.8.4
github.com/tidwall/gjson v1.17.1
golang.org/x/exp v0.0.0-20240103183307-be819d1f06fc
golang.org/x/sync v0.6.0
golang.org/x/term v0.16.0
golang.org/x/sync v0.7.0
golang.org/x/term v0.19.0
gopkg.in/yaml.v3 v3.0.1
)

Expand All @@ -48,7 +49,7 @@ require (
github.com/hashicorp/go-multierror v1.1.1 // indirect
github.com/hexops/autogold v1.3.1 // indirect
github.com/hexops/gotextdiff v1.0.3 // indirect
github.com/hexops/valast v1.4.3 // indirect
github.com/hexops/valast v1.4.4 // indirect
github.com/inconshreveable/mousetrap v1.1.0 // indirect
github.com/invopop/yaml v0.2.0 // indirect
github.com/josharian/intern v1.0.0 // indirect
Expand All @@ -75,11 +76,11 @@ require (
github.com/tidwall/pretty v1.2.0 // indirect
github.com/ulikunitz/xz v0.5.10 // indirect
go4.org v0.0.0-20200411211856-f5505b9728dd // indirect
golang.org/x/mod v0.15.0 // indirect
golang.org/x/net v0.20.0 // indirect
golang.org/x/sys v0.16.0 // indirect
golang.org/x/mod v0.17.0 // indirect
golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect
golang.org/x/tools v0.17.0 // indirect
golang.org/x/tools v0.20.0 // indirect
gotest.tools/v3 v3.5.1 // indirect
mvdan.cc/gofumpt v0.6.0 // indirect
)
67 changes: 52 additions & 15 deletions go.sum

Large diffs are not rendered by default.

29 changes: 29 additions & 0 deletions pkg/builtin/builtin.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,15 @@ var tools = map[string]types.Tool{
},
BuiltinFunc: SysAbort,
},
"sys.chat.finish": {
Parameters: types.Parameters{
Description: "Concludes the conversation. This can not be used to ask a question.",
Arguments: types.ObjectSchema(
"summary", "A summary of the dialog",
),
},
BuiltinFunc: SysChatFinish,
},
"sys.http.post": {
Parameters: types.Parameters{
Description: "Write contents to a http or https URL using the POST method",
Expand Down Expand Up @@ -524,6 +533,26 @@ func SysGetenv(ctx context.Context, env []string, input string) (string, error)
return os.Getenv(params.Name), nil
}

type ErrChatFinish struct {
Message string
}

func (e *ErrChatFinish) Error() string {
return fmt.Sprintf("CHAT FINISH: %s", e.Message)
}

func SysChatFinish(ctx context.Context, env []string, input string) (string, error) {
var params struct {
Message string `json:"message,omitempty"`
}
if err := json.Unmarshal([]byte(input), &params); err != nil {
return "", err
}
return "", &ErrChatFinish{
Message: params.Message,
}
}

func SysAbort(ctx context.Context, env []string, input string) (string, error) {
var params struct {
Message string `json:"message,omitempty"`
Expand Down
83 changes: 83 additions & 0 deletions pkg/chat/chat.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
package chat

import (
"context"

"github.com/fatih/color"
"github.com/gptscript-ai/gptscript/pkg/runner"
"github.com/gptscript-ai/gptscript/pkg/types"
)

type Prompter interface {
Readline() (string, bool, error)
Printf(format string, args ...interface{}) (int, error)
SetPrompt(p string)
Close() error
}

type Chatter interface {
Chat(ctx context.Context, prevState runner.ChatState, prg types.Program, env []string, input string) (resp runner.ChatResponse, err error)
}

type GetProgram func() (types.Program, error)

func getPrompt(prg types.Program, resp runner.ChatResponse) string {
name := prg.ChatName()
if newName := prg.ToolSet[resp.ToolID].Name; newName != "" {
name = newName
}

return color.GreenString("%s> ", name)
}

func Start(ctx context.Context, prevState runner.ChatState, chatter Chatter, prg GetProgram, env []string, startInput string) error {
var (
prompter Prompter
)

prompter, err := newReadlinePrompter()
if err != nil {
return err
}
defer prompter.Close()

for {
var (
input string
ok bool
resp runner.ChatResponse
)

prg, err := prg()
if err != nil {
return err
}

prompter.SetPrompt(getPrompt(prg, resp))

if startInput != "" {
input = startInput
startInput = ""
} else if !(prevState == nil && prg.ToolSet[prg.EntryToolID].Arguments == nil) {
// The above logic will skip prompting if this is the first loop and the chat expects no args
input, ok, err = prompter.Readline()
if !ok || err != nil {
return err
}
}

resp, err = chatter.Chat(ctx, prevState, prg, env, input)
if err != nil || resp.Done {
return err
}

if resp.Content != "" {
_, err := prompter.Printf(color.RedString("< %s\n", resp.Content))
if err != nil {
return err
}
}

prevState = resp.State
}
}
66 changes: 66 additions & 0 deletions pkg/chat/readline.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
package chat

import (
"errors"
"fmt"
"io"
"strings"

"github.com/adrg/xdg"
"github.com/chzyer/readline"
"github.com/fatih/color"
"github.com/gptscript-ai/gptscript/pkg/mvl"
)

var _ Prompter = (*readlinePrompter)(nil)

type readlinePrompter struct {
readliner *readline.Instance
}

func newReadlinePrompter() (*readlinePrompter, error) {
historyFile, err := xdg.CacheFile("gptscript/chat.history")
if err != nil {
historyFile = ""
}

l, err := readline.NewEx(&readline.Config{
Prompt: color.GreenString("> "),
HistoryFile: historyFile,
InterruptPrompt: "^C",
EOFPrompt: "exit",
HistorySearchFold: true,
})
if err != nil {
return nil, err
}

l.CaptureExitSignal()
mvl.SetOutput(l.Stderr())

return &readlinePrompter{
readliner: l,
}, nil
}

func (r *readlinePrompter) Printf(format string, args ...interface{}) (int, error) {
return fmt.Fprintf(r.readliner.Stdout(), format, args...)
}

func (r *readlinePrompter) Readline() (string, bool, error) {
line, err := r.readliner.Readline()
if errors.Is(err, readline.ErrInterrupt) {
return "", false, nil
} else if errors.Is(err, io.EOF) {
return "", false, nil
}
return strings.TrimSpace(line), true, nil
}

func (r *readlinePrompter) SetPrompt(prompt string) {
r.readliner.SetPrompt(prompt)
}

func (r *readlinePrompter) Close() error {
return r.readliner.Close()
}
9 changes: 9 additions & 0 deletions pkg/cli/eval.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"strconv"
"strings"

"github.com/gptscript-ai/gptscript/pkg/chat"
"github.com/gptscript-ai/gptscript/pkg/gptscript"
"github.com/gptscript-ai/gptscript/pkg/input"
"github.com/gptscript-ai/gptscript/pkg/loader"
Expand All @@ -15,6 +16,7 @@ import (

type Eval struct {
Tools []string `usage:"Tools available to call"`
Chat bool `usage:"Enable chat"`
MaxTokens int `usage:"Maximum number of tokens to output"`
Model string `usage:"The model to use"`
JSON bool `usage:"Output JSON"`
Expand All @@ -33,6 +35,7 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error {
ModelName: e.Model,
JSONResponse: e.JSON,
InternalPrompt: e.InternalPrompt,
Chat: e.Chat,
},
Instructions: strings.Join(args, " "),
}
Expand Down Expand Up @@ -66,6 +69,12 @@ func (e *Eval) Run(cmd *cobra.Command, args []string) error {
return err
}

if e.Chat {
return chat.Start(e.gptscript.NewRunContext(cmd), nil, runner, func() (types.Program, error) {
return prg, nil
}, os.Environ(), toolInput)
}

toolOutput, err := runner.Run(e.gptscript.NewRunContext(cmd), prg, os.Environ(), toolInput)
if err != nil {
return err
Expand Down
45 changes: 42 additions & 3 deletions pkg/cli/gptscript.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package cli

import (
"context"
"encoding/json"
"fmt"
"io"
"os"
Expand All @@ -14,6 +15,7 @@ import (
"github.com/gptscript-ai/gptscript/pkg/assemble"
"github.com/gptscript-ai/gptscript/pkg/builtin"
"github.com/gptscript-ai/gptscript/pkg/cache"
"github.com/gptscript-ai/gptscript/pkg/chat"
"github.com/gptscript-ai/gptscript/pkg/confirm"
"github.com/gptscript-ai/gptscript/pkg/gptscript"
"github.com/gptscript-ai/gptscript/pkg/input"
Expand Down Expand Up @@ -57,6 +59,10 @@ type GPTScript struct {
Ports string `usage:"The port range to use for ephemeral daemon ports (ex: 11000-12000)" hidden:"true"`
CredentialContext string `usage:"Context name in which to store credentials" default:"default"`
CredentialOverride string `usage:"Credentials to override (ex: --credential-override github.com/example/cred-tool:API_TOKEN=1234)"`
ChatState string `usage:"The chat state to continue, or null to start a new chat and return the state"`
ForceChat bool `usage:"Force an interactive chat session if even the top level tool is not a chat tool"`

readData []byte
}

func New() *cobra.Command {
Expand Down Expand Up @@ -207,11 +213,17 @@ func (r *GPTScript) PersistentPre(*cobra.Command, []string) error {
r.Quiet = new(bool)
} else {
r.Quiet = &[]bool{true}[0]
if r.Color == nil {
r.Color = new(bool)
}
}
}

if r.Debug {
mvl.SetDebug()
if r.Color == nil {
r.Color = new(bool)
}
} else {
mvl.SetSimpleFormat()
if *r.Quiet {
Expand Down Expand Up @@ -245,9 +257,18 @@ func (r *GPTScript) readProgram(ctx context.Context, args []string) (prg types.P
}

if args[0] == "-" {
data, err := io.ReadAll(os.Stdin)
if err != nil {
return prg, err
var (
data []byte
err error
)
if len(r.readData) > 0 {
data = r.readData
} else {
data, err = io.ReadAll(os.Stdin)
if err != nil {
return prg, err
}
r.readData = data
}
return loader.ProgramFromSource(ctx, string(data), r.SubTool)
}
Expand Down Expand Up @@ -349,6 +370,24 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
return err
}

if r.ChatState != "" {
resp, err := gptScript.Chat(r.NewRunContext(cmd), r.ChatState, prg, os.Environ(), toolInput)
if err != nil {
return err
}
data, err := json.Marshal(resp)
if err != nil {
return err
}
return r.PrintOutput(toolInput, string(data))
}

if prg.IsChat() || r.ForceChat {
return chat.Start(r.NewRunContext(cmd), nil, gptScript, func() (types.Program, error) {
return r.readProgram(ctx, args)
}, os.Environ(), toolInput)
}

s, err := gptScript.Run(r.NewRunContext(cmd), prg, os.Environ(), toolInput)
if err != nil {
return err
Expand Down
4 changes: 2 additions & 2 deletions pkg/engine/cmd.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/gptscript-ai/gptscript/pkg/version"
)

func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, isCredential bool) (cmdOut string, cmdErr error) {
func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string, toolCategory ToolCategory) (cmdOut string, cmdErr error) {
id := fmt.Sprint(atomic.AddInt64(&completionID, 1))

defer func() {
Expand Down Expand Up @@ -65,7 +65,7 @@ func (e *Engine) runCommand(ctx context.Context, tool types.Tool, input string,
cmd.Stderr = io.MultiWriter(all, os.Stderr)
cmd.Stdout = io.MultiWriter(all, output)

if isCredential {
if toolCategory == CredentialToolCategory {
pause := context2.GetPauseFuncFromCtx(ctx)
unpause := pause()
defer unpause()
Expand Down
Loading