Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
50 changes: 41 additions & 9 deletions pkg/engine/daemon.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (
"sync"
"time"

cryptorand "crypto/rand"

"github.com/gptscript-ai/gptscript/pkg/system"
"github.com/gptscript-ai/gptscript/pkg/types"
)
Expand All @@ -19,6 +21,7 @@ var ports Ports

type Ports struct {
daemonPorts map[string]int64
daemonTokens map[string]string
daemonsRunning map[string]func()
daemonLock sync.Mutex

Expand Down Expand Up @@ -119,18 +122,46 @@ func getPath(instructions string) (string, string) {
return strings.TrimSpace(rest), strings.TrimSpace(value)
}

func (e *Engine) startDaemon(tool types.Tool) (string, error) {
func getDaemonToken(toolID string) (string, error) {
token, ok := ports.daemonTokens[toolID]
if !ok {
// Generate a new token.
tokenBytes := make([]byte, 50)
count, err := cryptorand.Read(tokenBytes)
if err != nil {
return "", fmt.Errorf("failed to generate daemon token: %w", err)
} else if count != len(tokenBytes) {
return "", fmt.Errorf("failed to generate daemon token")
}

token = fmt.Sprintf("%x", tokenBytes)

if ports.daemonTokens == nil {
ports.daemonTokens = map[string]string{}
}
ports.daemonTokens[toolID] = token
}

return token, nil
}

func (e *Engine) startDaemon(tool types.Tool) (string, string, error) {
ports.daemonLock.Lock()
defer ports.daemonLock.Unlock()

instructions := strings.TrimPrefix(tool.Instructions, types.DaemonPrefix)
instructions, path := getPath(instructions)
tool.Instructions = types.CommandPrefix + instructions

token, err := getDaemonToken(tool.ID)
if err != nil {
return "", "", err
}

port, ok := ports.daemonPorts[tool.ID]
url := fmt.Sprintf("http://127.0.0.1:%d%s", port, path)
if ok && ports.daemonsRunning[url] != nil {
return url, nil
return url, token, nil
}

if ports.daemonCtx == nil {
Expand All @@ -149,18 +180,19 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
cmd, stop, err := e.newCommand(ctx, []string{
fmt.Sprintf("PORT=%d", port),
fmt.Sprintf("GPTSCRIPT_PORT=%d", port),
fmt.Sprintf("GPTSCRIPT_DAEMON_TOKEN=%s", token),
},
tool,
"{}",
false,
)
if err != nil {
return url, err
return url, "", err
}

r, w, err := os.Pipe()
if err != nil {
return "", err
return "", "", err
}

// Loop back to gptscript to help with process supervision
Expand All @@ -178,7 +210,7 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
log.Infof("launched [%s][%s] port [%d] %v", tool.Name, tool.ID, port, cmd.Args)
if err := cmd.Start(); err != nil {
stop()
return url, err
return url, "", err
}

if ports.daemonPorts == nil {
Expand Down Expand Up @@ -217,20 +249,20 @@ func (e *Engine) startDaemon(tool types.Tool) (string, error) {
_, _ = io.ReadAll(resp.Body)
_ = resp.Body.Close()
}()
return url, nil
return url, token, nil
}
select {
case <-killedCtx.Done():
return url, fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx))
return url, "", fmt.Errorf("daemon failed to start: %w", context.Cause(killedCtx))
case <-time.After(time.Second):
}
}

return url, fmt.Errorf("timeout waiting for 200 response from GET %s", url)
return url, "", fmt.Errorf("timeout waiting for 200 response from GET %s", url)
}

func (e *Engine) runDaemon(ctx Context, tool types.Tool, input string) (cmdRet *Return, cmdErr error) {
url, err := e.startDaemon(tool)
url, _, err := e.startDaemon(tool)
if err != nil {
return nil, err
}
Expand Down
11 changes: 9 additions & 2 deletions pkg/engine/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,10 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re
return nil, err
}

var requestedEnvVars map[string]struct{}
var (
requestedEnvVars map[string]struct{}
daemonToken string
)
if strings.HasSuffix(parsed.Hostname(), DaemonURLSuffix) {
referencedToolName := strings.TrimSuffix(parsed.Hostname(), DaemonURLSuffix)
referencedToolRefs, ok := tool.ToolMapping[referencedToolName]
Expand All @@ -50,7 +53,7 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re
if !ok {
return nil, fmt.Errorf("failed to find tool [%s] for [%s]", referencedToolName, parsed.Hostname())
}
toolURL, err = e.startDaemon(referencedTool)
toolURL, daemonToken, err = e.startDaemon(referencedTool)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -85,6 +88,10 @@ func (e *Engine) runHTTP(ctx Context, tool types.Tool, input string) (cmdRet *Re
return nil, err
}

if daemonToken != "" {
req.Header.Add("X-GPTScript-Daemon-Token", daemonToken)
}

for _, k := range slices.Sorted(maps.Keys(envMap)) {
if _, ok := requestedEnvVars[k]; ok || strings.HasPrefix(k, "GPTSCRIPT_WORKSPACE_") {
req.Header.Add("X-GPTScript-Env", k+"="+envMap[k])
Expand Down