Skip to content

Commit c02c4cb

Browse files
chore: make credential overrides cred context aware
1 parent 3f876b2 commit c02c4cb

File tree

6 files changed

+199
-60
lines changed

6 files changed

+199
-60
lines changed

pkg/credentials/factory.go

+46-14
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package credentials
22

33
import (
44
"context"
5+
"strings"
56

67
"github.com/docker/docker-credential-helpers/client"
78
"github.com/gptscript-ai/gptscript/pkg/config"
@@ -13,12 +14,32 @@ type ProgramLoaderRunner interface {
1314
Run(ctx context.Context, prg types.Program, input string) (output string, err error)
1415
}
1516

16-
func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRunner) (StoreFactory, error) {
17+
func NewFactory(ctx context.Context, cfg *config.CLIConfig, overrides []string, plr ProgramLoaderRunner) (StoreFactory, error) {
18+
creds, err := ParseCredentialOverrides(overrides)
19+
if err != nil {
20+
return StoreFactory{}, err
21+
}
22+
23+
overrideMap := make(map[string]map[string]map[string]string)
24+
for k, v := range creds {
25+
contextName, toolName, ok := strings.Cut(k, ctxSeparator)
26+
if !ok {
27+
continue
28+
}
29+
toolMap, ok := overrideMap[contextName]
30+
if !ok {
31+
toolMap = make(map[string]map[string]string)
32+
}
33+
toolMap[toolName] = v
34+
overrideMap[contextName] = toolMap
35+
}
36+
1737
toolName := translateToolName(cfg.CredentialsStore)
1838
if toolName == config.FileCredHelper {
1939
return StoreFactory{
20-
file: true,
21-
cfg: cfg,
40+
file: true,
41+
cfg: cfg,
42+
overrides: overrideMap,
2243
}, nil
2344
}
2445

@@ -28,10 +49,11 @@ func NewFactory(ctx context.Context, cfg *config.CLIConfig, plr ProgramLoaderRun
2849
}
2950

3051
return StoreFactory{
31-
ctx: ctx,
32-
prg: prg,
33-
runner: plr,
34-
cfg: cfg,
52+
ctx: ctx,
53+
prg: prg,
54+
runner: plr,
55+
cfg: cfg,
56+
overrides: overrideMap,
3557
}, nil
3658
}
3759

@@ -41,22 +63,32 @@ type StoreFactory struct {
4163
file bool
4264
runner ProgramLoaderRunner
4365
cfg *config.CLIConfig
66+
// That's a lot of maps: context -> toolName -> key -> value
67+
overrides map[string]map[string]map[string]string
4468
}
4569

4670
func (s *StoreFactory) NewStore(credCtxs []string) (CredentialStore, error) {
4771
if err := validateCredentialCtx(credCtxs); err != nil {
4872
return nil, err
4973
}
5074
if s.file {
51-
return Store{
52-
credCtxs: credCtxs,
53-
cfg: s.cfg,
75+
return withOverride{
76+
target: Store{
77+
credCtxs: credCtxs,
78+
cfg: s.cfg,
79+
},
80+
overrides: s.overrides,
81+
credContext: credCtxs,
5482
}, nil
5583
}
56-
return Store{
57-
credCtxs: credCtxs,
58-
cfg: s.cfg,
59-
program: s.program,
84+
return withOverride{
85+
target: Store{
86+
credCtxs: credCtxs,
87+
cfg: s.cfg,
88+
program: s.program,
89+
},
90+
overrides: s.overrides,
91+
credContext: credCtxs,
6092
}, nil
6193
}
6294

pkg/credentials/overrides.go

+149
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,149 @@
1+
package credentials
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"maps"
7+
"os"
8+
"strings"
9+
)
10+
11+
// ParseCredentialOverrides parses a string of credential overrides that the user provided as a command line arg.
12+
// The format of credential overrides can be one of two things:
13+
// cred1:ENV1,ENV2 (direct mapping of environment variables)
14+
// cred1:ENV1=VALUE1,ENV2=VALUE2 (key-value pairs)
15+
//
16+
// This function turns it into a map[string]map[string]string like this:
17+
//
18+
// {
19+
// "cred1": {
20+
// "ENV1": "VALUE1",
21+
// "ENV2": "VALUE2",
22+
// }
23+
// }
24+
func ParseCredentialOverrides(overrides []string) (map[string]map[string]string, error) {
25+
credentialOverrides := make(map[string]map[string]string)
26+
27+
for _, o := range overrides {
28+
credName, envs, found := strings.Cut(o, ":")
29+
if !found {
30+
return nil, fmt.Errorf("invalid credential override: %s", o)
31+
}
32+
envMap, ok := credentialOverrides[credName]
33+
if !ok {
34+
envMap = make(map[string]string)
35+
}
36+
for _, env := range strings.Split(envs, ",") {
37+
for _, env := range strings.Split(env, "|") {
38+
key, value, found := strings.Cut(env, "=")
39+
if !found {
40+
// User just passed an env var name as the key, so look up the value.
41+
value = os.Getenv(key)
42+
}
43+
envMap[key] = value
44+
}
45+
}
46+
credentialOverrides[credName] = envMap
47+
}
48+
49+
return credentialOverrides, nil
50+
}
51+
52+
type withOverride struct {
53+
target CredentialStore
54+
credContext []string
55+
overrides map[string]map[string]map[string]string
56+
}
57+
58+
func (w withOverride) Get(ctx context.Context, toolName string) (*Credential, bool, error) {
59+
for _, credCtx := range w.credContext {
60+
overrides, ok := w.overrides[credCtx]
61+
if !ok {
62+
continue
63+
}
64+
override, ok := overrides[toolName]
65+
if !ok {
66+
continue
67+
}
68+
69+
return &Credential{
70+
Context: credCtx,
71+
ToolName: toolName,
72+
Type: CredentialTypeTool,
73+
Env: maps.Clone(override),
74+
}, true, nil
75+
}
76+
77+
return w.target.Get(ctx, toolName)
78+
}
79+
80+
func (w withOverride) Add(ctx context.Context, cred Credential) error {
81+
for _, credCtx := range w.credContext {
82+
if override, ok := w.overrides[credCtx]; ok {
83+
if _, ok := override[cred.ToolName]; ok {
84+
return fmt.Errorf("cannot add credential with context %q and tool %q because it is statically configure", cred.Context, cred.ToolName)
85+
}
86+
}
87+
}
88+
return w.target.Add(ctx, cred)
89+
}
90+
91+
func (w withOverride) Refresh(ctx context.Context, cred Credential) error {
92+
if override, ok := w.overrides[cred.Context]; ok {
93+
if _, ok := override[cred.ToolName]; ok {
94+
return nil
95+
}
96+
}
97+
return w.target.Refresh(ctx, cred)
98+
}
99+
100+
func (w withOverride) Remove(ctx context.Context, toolName string) error {
101+
for _, credCtx := range w.credContext {
102+
if override, ok := w.overrides[credCtx]; ok {
103+
if _, ok := override[toolName]; ok {
104+
return fmt.Errorf("cannot remove credential with context %q and tool %q because it is statically configure", credCtx, toolName)
105+
}
106+
}
107+
}
108+
return w.target.Remove(ctx, toolName)
109+
}
110+
111+
func (w withOverride) List(ctx context.Context) ([]Credential, error) {
112+
creds, err := w.target.List(ctx)
113+
if err != nil {
114+
return nil, err
115+
}
116+
117+
added := make(map[string]map[string]bool)
118+
for i, cred := range creds {
119+
if override, ok := w.overrides[cred.Context]; ok {
120+
if _, ok := override[cred.ToolName]; ok {
121+
creds[i].Type = CredentialTypeTool
122+
creds[i].Env = maps.Clone(override[cred.ToolName])
123+
}
124+
}
125+
tools, ok := added[cred.Context]
126+
if !ok {
127+
tools = make(map[string]bool)
128+
}
129+
tools[cred.ToolName] = true
130+
added[cred.Context] = tools
131+
}
132+
133+
for _, credCtx := range w.credContext {
134+
tools := w.overrides[credCtx]
135+
for toolName := range tools {
136+
if _, ok := added[credCtx][toolName]; ok {
137+
continue
138+
}
139+
creds = append(creds, Credential{
140+
Context: credCtx,
141+
ToolName: toolName,
142+
Type: CredentialTypeTool,
143+
Env: maps.Clone(tools[toolName]),
144+
})
145+
}
146+
}
147+
148+
return creds, nil
149+
}

pkg/gptscript/gptscript.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,7 @@ func New(ctx context.Context, o ...Options) (*GPTScript, error) {
124124
return nil, err
125125
}
126126

127-
storeFactory, err := credentials.NewFactory(ctx, cliCfg, simplerRunner)
127+
storeFactory, err := credentials.NewFactory(ctx, cliCfg, opts.Runner.CredentialOverrides, simplerRunner)
128128
if err != nil {
129129
return nil, err
130130
}

pkg/runner/credentials.go

-43
This file was deleted.

pkg/runner/credentials_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import (
44
"os"
55
"testing"
66

7+
"github.com/gptscript-ai/gptscript/pkg/credentials"
78
"github.com/stretchr/testify/require"
89
)
910

@@ -119,7 +120,7 @@ func TestParseCredentialOverrides(t *testing.T) {
119120
_ = os.Setenv(k, v)
120121
}
121122

122-
out, err := parseCredentialOverrides(tc.in)
123+
out, err := credentials.ParseCredentialOverrides(tc.in)
123124
if tc.expectErr {
124125
require.Error(t, err)
125126
return

pkg/runner/runner.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -754,7 +754,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
754754
err error
755755
)
756756
if r.credOverrides != nil {
757-
credOverrides, err = parseCredentialOverrides(r.credOverrides)
757+
credOverrides, err = credentials.ParseCredentialOverrides(r.credOverrides)
758758
if err != nil {
759759
return nil, fmt.Errorf("failed to parse credential overrides: %w", err)
760760
}

0 commit comments

Comments
 (0)