Skip to content

Commit 1c6faca

Browse files
bug: respect run level env in openai prompt
1 parent 2aafa62 commit 1c6faca

File tree

10 files changed

+64
-39
lines changed

10 files changed

+64
-39
lines changed

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ require (
1616
github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510
1717
github.com/google/uuid v1.6.0
1818
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379
19-
github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7
19+
github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a
2020
github.com/hexops/autogold/v2 v2.2.1
2121
github.com/hexops/valast v1.4.4
2222
github.com/jaytaylor/html2text v0.0.0-20230321000545-74c2419ad056

go.sum

+2-2
Original file line numberDiff line numberDiff line change
@@ -173,8 +173,8 @@ github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf037
173173
github.com/gptscript-ai/chat-completion-client v0.0.0-20240531200700-af8e7ecf0379/go.mod h1:7P/o6/IWa1KqsntVf68hSnLKuu3+xuqm6lYhch1w4jo=
174174
github.com/gptscript-ai/go-gptscript v0.0.0-20240613214812-8111c2b02d71 h1:WehkkausLuXI91ePpIVrzZ6eBmfFIU/HfNsSA1CHiwo=
175175
github.com/gptscript-ai/go-gptscript v0.0.0-20240613214812-8111c2b02d71/go.mod h1:Dh6vYRAiVcyC3ElZIGzTvNF1FxtYwA07BHfSiFKQY7s=
176-
github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7 h1:t+IuS+4JLUnwLHv+bgJQ2jHVT9ii0SLR3D7eNTZ47fg=
177-
github.com/gptscript-ai/tui v0.0.0-20240614023948-004dc1597dd7/go.mod h1:ZlyM+BRiD6mV04w+Xw2mXP1VKGEUbn8BvwrosWlplUo=
176+
github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a h1:LFsEDiIAx0Rq0V6aOMlRjXMMIfkA3uEhqqqjoggLlDQ=
177+
github.com/gptscript-ai/tui v0.0.0-20240614062633-985091711b0a/go.mod h1:ZlyM+BRiD6mV04w+Xw2mXP1VKGEUbn8BvwrosWlplUo=
178178
github.com/hashicorp/errwrap v1.0.0 h1:hLrqtEDnRye3+sgx6z4qVLNuviH3MR5aQ0ykNJa/UYA=
179179
github.com/hashicorp/errwrap v1.0.0/go.mod h1:YH+1FKiLXxHSkmPseP+kNlulaMuP3n2brvKWEqk/Jc4=
180180
github.com/hashicorp/go-multierror v1.1.1 h1:H5DkEtf6CXdFp0N0Em5UCwQpXMWke8IA0+lD48awMYo=

pkg/cli/gptscript.go

+5-5
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,11 @@ func (r *GPTScript) NewGPTScriptOpts() (gptscript.Options, error) {
132132
CredentialOverride: r.CredentialOverride,
133133
Sequential: r.ForceSequential,
134134
},
135-
Quiet: r.Quiet,
136-
Env: os.Environ(),
137-
CredentialContext: r.CredentialContext,
138-
Workspace: r.Workspace,
135+
Quiet: r.Quiet,
136+
Env: os.Environ(),
137+
CredentialContext: r.CredentialContext,
138+
Workspace: r.Workspace,
139+
DisablePromptServer: r.UI,
139140
}
140141

141142
if r.Confirm {
@@ -452,7 +453,6 @@ func (r *GPTScript) Run(cmd *cobra.Command, args []string) (retErr error) {
452453
Workspace: r.Workspace,
453454
SaveChatStateFile: r.SaveChatStateFile,
454455
ChatState: chatState,
455-
ExtraEnv: gptScript.ExtraEnv,
456456
})
457457
}
458458
return chat.Start(cmd.Context(), chatState, gptScript, func() (types.Program, error) {

pkg/context/context.go

+11
Original file line numberDiff line numberDiff line change
@@ -46,3 +46,14 @@ func GetLogger(ctx context.Context) mvl.Logger {
4646

4747
return l
4848
}
49+
50+
type envKey struct{}
51+
52+
func WithEnv(ctx context.Context, env []string) context.Context {
53+
return context.WithValue(ctx, envKey{}, env)
54+
}
55+
56+
func GetEnv(ctx context.Context) []string {
57+
l, _ := ctx.Value(envKey{}).([]string)
58+
return l
59+
}

pkg/engine/engine.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"sync"
99

1010
"github.com/gptscript-ai/gptscript/pkg/config"
11+
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1112
"github.com/gptscript-ai/gptscript/pkg/counter"
1213
"github.com/gptscript-ai/gptscript/pkg/system"
1314
"github.com/gptscript-ai/gptscript/pkg/types"
@@ -328,7 +329,7 @@ func (e *Engine) complete(ctx context.Context, state *State) (*Return, error) {
328329
}
329330
}()
330331

331-
resp, err := e.Model.Call(ctx, state.Completion, progress)
332+
resp, err := e.Model.Call(gcontext.WithEnv(ctx, e.Env), state.Completion, progress)
332333
if err != nil {
333334
return nil, err
334335
}

pkg/gptscript/gptscript.go

+22-14
Original file line numberDiff line numberDiff line change
@@ -38,14 +38,15 @@ type GPTScript struct {
3838
}
3939

4040
type Options struct {
41-
Cache cache.Options
42-
OpenAI openai.Options
43-
Monitor monitor.Options
44-
Runner runner.Options
45-
CredentialContext string
46-
Quiet *bool
47-
Workspace string
48-
Env []string
41+
Cache cache.Options
42+
OpenAI openai.Options
43+
Monitor monitor.Options
44+
Runner runner.Options
45+
CredentialContext string
46+
Quiet *bool
47+
Workspace string
48+
DisablePromptServer bool
49+
Env []string
4950
}
5051

5152
func complete(opts ...Options) Options {
@@ -60,6 +61,7 @@ func complete(opts ...Options) Options {
6061
result.Quiet = types.FirstSet(opt.Quiet, result.Quiet)
6162
result.Workspace = types.FirstSet(opt.Workspace, result.Workspace)
6263
result.Env = append(result.Env, opt.Env...)
64+
result.DisablePromptServer = types.FirstSet(opt.DisablePromptServer, result.DisablePromptServer)
6365
}
6466

6567
if result.Quiet == nil {
@@ -123,15 +125,21 @@ func New(o ...Options) (*GPTScript, error) {
123125
return nil, err
124126
}
125127

126-
ctx, closeServer := context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
127-
extraEnv, err := prompt.NewServer(ctx, opts.Env)
128-
if err != nil {
129-
closeServer()
130-
return nil, err
128+
var (
129+
extraEnv []string
130+
closeServer = func() {}
131+
)
132+
if !opts.DisablePromptServer {
133+
var ctx context.Context
134+
ctx, closeServer = context.WithCancel(context2.AddPauseFuncToCtx(context.Background(), opts.Runner.MonitorFactory.Pause))
135+
extraEnv, err = prompt.NewServer(ctx, opts.Env)
136+
if err != nil {
137+
closeServer()
138+
return nil, err
139+
}
131140
}
132141

133142
fullEnv := append(opts.Env, extraEnv...)
134-
oaiClient.SetEnvs(fullEnv)
135143

136144
remoteClient := remote.New(runner, fullEnv, cacheClient, credStore)
137145
if err := registry.AddClient(remoteClient); err != nil {

pkg/openai/client.go

+2-6
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212

1313
openai "github.com/gptscript-ai/chat-completion-client"
1414
"github.com/gptscript-ai/gptscript/pkg/cache"
15+
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1516
"github.com/gptscript-ai/gptscript/pkg/counter"
1617
"github.com/gptscript-ai/gptscript/pkg/credentials"
1718
"github.com/gptscript-ai/gptscript/pkg/hash"
@@ -43,7 +44,6 @@ type Client struct {
4344
invalidAuth bool
4445
cacheKeyBase string
4546
setSeed bool
46-
envs []string
4747
credStore credentials.CredentialStore
4848
}
4949

@@ -136,10 +136,6 @@ func (c *Client) ValidAuth() error {
136136
return nil
137137
}
138138

139-
func (c *Client) SetEnvs(env []string) {
140-
c.envs = env
141-
}
142-
143139
func (c *Client) Supports(ctx context.Context, modelName string) (bool, error) {
144140
models, err := c.ListModels(ctx)
145141
if err != nil {
@@ -546,7 +542,7 @@ func (c *Client) call(ctx context.Context, request openai.ChatCompletionRequest,
546542
}
547543

548544
func (c *Client) RetrieveAPIKey(ctx context.Context) error {
549-
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", c.envs)
545+
k, err := prompt.GetModelProviderCredential(ctx, c.credStore, BuiltinCredName, "OPENAI_API_KEY", "Please provide your OpenAI API key:", gcontext.GetEnv(ctx))
550546
if err != nil {
551547
return err
552548
}

pkg/prompt/prompt.go

+11-5
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ func sysPromptHTTP(ctx context.Context, envs []string, url string, prompt types.
4141
defer resp.Body.Close()
4242

4343
if resp.StatusCode != 200 {
44-
return "", fmt.Errorf("invalid status code [%d], expected 200", resp.StatusCode)
44+
return "", fmt.Errorf("getting prompt response invalid status code [%d], expected 200", resp.StatusCode)
4545
}
4646

4747
data, err = io.ReadAll(resp.Body)
@@ -75,17 +75,23 @@ func SysPrompt(ctx context.Context, envs []string, input string) (_ string, err
7575
func sysPrompt(ctx context.Context, req types.Prompt) (_ string, err error) {
7676
defer context2.GetPauseFuncFromCtx(ctx)()()
7777

78-
if req.Message != "" {
78+
if req.Message != "" && len(req.Fields) != 1 {
7979
_, _ = fmt.Fprintln(os.Stderr, req.Message)
8080
}
8181

8282
results := map[string]string{}
8383
for _, f := range req.Fields {
84-
var value string
84+
var (
85+
value string
86+
msg = f
87+
)
88+
if len(req.Fields) == 1 && req.Message != "" {
89+
msg = req.Message
90+
}
8591
if req.Sensitive {
86-
err = survey.AskOne(&survey.Password{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
92+
err = survey.AskOne(&survey.Password{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
8793
} else {
88-
err = survey.AskOne(&survey.Input{Message: f}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
94+
err = survey.AskOne(&survey.Input{Message: msg}, &value, survey.WithStdio(os.Stdin, os.Stderr, os.Stderr))
8995
}
9096
if err != nil {
9197
return "", err

pkg/prompt/server.go

+6-4
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,11 @@ import (
1515

1616
func NewServer(ctx context.Context, envs []string) ([]string, error) {
1717
for _, env := range envs {
18-
v, ok := strings.CutPrefix(env, types.PromptTokenEnvVar+"=")
19-
if ok && v != "" {
20-
return nil, nil
18+
for _, k := range []string{types.PromptURLEnvVar, types.PromptTokenEnvVar} {
19+
v, ok := strings.CutPrefix(env, k+"=")
20+
if ok && v != "" {
21+
return nil, nil
22+
}
2123
}
2224
}
2325

@@ -34,7 +36,7 @@ func NewServer(ctx context.Context, envs []string) ([]string, error) {
3436
Handler: http.HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
3537
if r.Header.Get("Authorization") != "Bearer "+token {
3638
rw.WriteHeader(http.StatusUnauthorized)
37-
_, _ = rw.Write([]byte("Unauthorized"))
39+
_, _ = rw.Write([]byte("Unauthorized (invalid token)"))
3840
return
3941
}
4042

pkg/remote/remote.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"sync"
1111

1212
"github.com/gptscript-ai/gptscript/pkg/cache"
13+
gcontext "github.com/gptscript-ai/gptscript/pkg/context"
1314
"github.com/gptscript-ai/gptscript/pkg/credentials"
1415
"github.com/gptscript-ai/gptscript/pkg/engine"
1516
env2 "github.com/gptscript-ai/gptscript/pkg/env"
@@ -176,5 +177,5 @@ func (c *Client) load(ctx context.Context, toolName string) (*openai.Client, err
176177
}
177178

178179
func (c *Client) retrieveAPIKey(ctx context.Context, env, url string) (string, error) {
179-
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), c.envs)
180+
return prompt.GetModelProviderCredential(ctx, c.credStore, url, env, fmt.Sprintf("Please provide your API key for %s", url), append(gcontext.GetEnv(ctx), c.envs...))
180181
}

0 commit comments

Comments
 (0)