Skip to content

Commit 60da900

Browse files
authored
enhance: update credentials framework for OAuth support (#305)
Signed-off-by: Grant Linville <[email protected]>
1 parent abcd863 commit 60da900

File tree

3 files changed

+108
-64
lines changed

3 files changed

+108
-64
lines changed

pkg/cli/credential.go

+50-30
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ import (
66
"sort"
77
"strings"
88
"text/tabwriter"
9+
"time"
910

1011
cmd2 "github.com/acorn-io/cmd"
1112
"github.com/gptscript-ai/gptscript/pkg/cache"
@@ -14,6 +15,11 @@ import (
1415
"github.com/spf13/cobra"
1516
)
1617

18+
const (
19+
expiresNever = "never"
20+
expiresExpired = "expired"
21+
)
22+
1723
type Credential struct {
1824
root *GPTScript
1925
AllContexts bool `usage:"List credentials for all contexts" local:"true"`
@@ -46,6 +52,7 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
4652
}
4753
opts.Cache = cache.Complete(opts.Cache)
4854

55+
// Initialize the credential store and get all the credentials.
4956
store, err := credentials.NewStore(cfg, ctx, opts.Cache.CacheDir)
5057
if err != nil {
5158
return fmt.Errorf("failed to get credentials store: %w", err)
@@ -56,6 +63,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
5663
return fmt.Errorf("failed to list credentials: %w", err)
5764
}
5865

66+
w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
67+
defer w.Flush()
68+
69+
// Sort credentials and print column names, depending on the options.
5970
if c.AllContexts {
6071
// Sort credentials by context
6172
sort.Slice(creds, func(i, j int) bool {
@@ -65,25 +76,10 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
6576
return creds[i].Context < creds[j].Context
6677
})
6778

68-
w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
69-
defer w.Flush()
70-
7179
if c.ShowEnvVars {
72-
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tENVIRONMENT VARIABLES\n"))
73-
74-
for _, cred := range creds {
75-
envVars := make([]string, 0, len(cred.Env))
76-
for envVar := range cred.Env {
77-
envVars = append(envVars, envVar)
78-
}
79-
sort.Strings(envVars)
80-
_, _ = fmt.Fprintf(w, "%s\t%s\t%s\n", cred.Context, cred.ToolName, strings.Join(envVars, ", "))
81-
}
80+
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\tENV\n"))
8281
} else {
83-
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\n"))
84-
for _, cred := range creds {
85-
_, _ = fmt.Fprintf(w, "%s\t%s\n", cred.Context, cred.ToolName)
86-
}
82+
_, _ = w.Write([]byte("CONTEXT\tCREDENTIAL\tEXPIRES IN\n"))
8783
}
8884
} else {
8985
// Sort credentials by tool name
@@ -92,24 +88,48 @@ func (c *Credential) Run(_ *cobra.Command, _ []string) error {
9288
})
9389

9490
if c.ShowEnvVars {
95-
w := tabwriter.NewWriter(os.Stdout, 10, 1, 3, ' ', 0)
96-
defer w.Flush()
97-
_, _ = w.Write([]byte("CREDENTIAL\tENVIRONMENT VARIABLES\n"))
98-
99-
for _, cred := range creds {
100-
envVars := make([]string, 0, len(cred.Env))
101-
for envVar := range cred.Env {
102-
envVars = append(envVars, envVar)
103-
}
104-
sort.Strings(envVars)
105-
_, _ = fmt.Fprintf(w, "%s\t%s\n", cred.ToolName, strings.Join(envVars, ", "))
91+
_, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\tENV\n"))
92+
} else {
93+
_, _ = w.Write([]byte("CREDENTIAL\tEXPIRES IN\n"))
94+
}
95+
}
96+
97+
for _, cred := range creds {
98+
expires := expiresNever
99+
if cred.ExpiresAt != nil {
100+
expires = expiresExpired
101+
if !cred.IsExpired() {
102+
expires = time.Until(*cred.ExpiresAt).Truncate(time.Second).String()
106103
}
104+
}
105+
106+
var fields []any
107+
if c.AllContexts {
108+
fields = []any{cred.Context, cred.ToolName, expires}
107109
} else {
108-
for _, cred := range creds {
109-
fmt.Println(cred.ToolName)
110+
fields = []any{cred.ToolName, expires}
111+
}
112+
113+
if c.ShowEnvVars {
114+
envVars := make([]string, 0, len(cred.Env))
115+
for envVar := range cred.Env {
116+
envVars = append(envVars, envVar)
110117
}
118+
sort.Strings(envVars)
119+
fields = append(fields, strings.Join(envVars, ", "))
111120
}
121+
122+
printFields(w, fields)
112123
}
113124

114125
return nil
115126
}
127+
128+
func printFields(w *tabwriter.Writer, fields []any) {
129+
if len(fields) == 0 {
130+
return
131+
}
132+
133+
fmtStr := strings.Repeat("%s\t", len(fields)-1) + "%s\n"
134+
_, _ = fmt.Fprintf(w, fmtStr, fields...)
135+
}

pkg/credentials/credential.go

+32-15
Original file line numberDiff line numberDiff line change
@@ -4,43 +4,58 @@ import (
44
"encoding/json"
55
"fmt"
66
"strings"
7+
"time"
78

89
"github.com/docker/cli/cli/config/types"
910
)
1011

11-
const ctxSeparator = "///"
12-
1312
type CredentialType string
1413

1514
const (
15+
ctxSeparator = "///"
1616
CredentialTypeTool CredentialType = "tool"
1717
CredentialTypeModelProvider CredentialType = "modelProvider"
18+
ExistingCredential = "GPTSCRIPT_EXISTING_CREDENTIAL"
1819
)
1920

2021
type Credential struct {
21-
Context string `json:"context"`
22-
ToolName string `json:"toolName"`
23-
Type CredentialType `json:"type"`
24-
Env map[string]string `json:"env"`
22+
Context string `json:"context"`
23+
ToolName string `json:"toolName"`
24+
Type CredentialType `json:"type"`
25+
Env map[string]string `json:"env"`
26+
ExpiresAt *time.Time `json:"expiresAt"`
27+
RefreshToken string `json:"refreshToken"`
28+
}
29+
30+
func (c Credential) IsExpired() bool {
31+
if c.ExpiresAt == nil {
32+
return false
33+
}
34+
return time.Now().After(*c.ExpiresAt)
2535
}
2636

2737
func (c Credential) toDockerAuthConfig() (types.AuthConfig, error) {
28-
env, err := json.Marshal(c.Env)
38+
cred, err := json.Marshal(c)
2939
if err != nil {
3040
return types.AuthConfig{}, err
3141
}
3242

3343
return types.AuthConfig{
3444
Username: string(c.Type),
35-
Password: string(env),
45+
Password: string(cred),
3646
ServerAddress: toolNameWithCtx(c.ToolName, c.Context),
3747
}, nil
3848
}
3949

4050
func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error) {
41-
var env map[string]string
42-
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
43-
return Credential{}, err
51+
var cred Credential
52+
if err := json.Unmarshal([]byte(authCfg.Password), &cred); err != nil || len(cred.Env) == 0 {
53+
// Legacy: try unmarshalling into just an env map
54+
var env map[string]string
55+
if err := json.Unmarshal([]byte(authCfg.Password), &env); err != nil {
56+
return Credential{}, err
57+
}
58+
cred.Env = env
4459
}
4560

4661
// We used to hardcode the username as "gptscript" before CredentialType was introduced, so
@@ -62,10 +77,12 @@ func credentialFromDockerAuthConfig(authCfg types.AuthConfig) (Credential, error
6277
}
6378

6479
return Credential{
65-
Context: ctx,
66-
ToolName: tool,
67-
Type: CredentialType(credType),
68-
Env: env,
80+
Context: ctx,
81+
ToolName: tool,
82+
Type: CredentialType(credType),
83+
Env: cred.Env,
84+
ExpiresAt: cred.ExpiresAt,
85+
RefreshToken: cred.RefreshToken,
6986
}, nil
7087
}
7188

pkg/runner/runner.go

+26-19
Original file line numberDiff line numberDiff line change
@@ -250,7 +250,7 @@ var (
250250
EventTypeRunFinish EventType = "runFinish"
251251
)
252252

253-
func getContextInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
253+
func getToolRefInput(prg *types.Program, ref types.ToolReference, input string) (string, error) {
254254
if ref.Arg == "" {
255255
return "", nil
256256
}
@@ -355,7 +355,7 @@ func (r *Runner) getContext(callCtx engine.Context, state *State, monitor Monito
355355
continue
356356
}
357357

358-
contextInput, err := getContextInput(callCtx.Program, toolRef, input)
358+
contextInput, err := getToolRefInput(callCtx.Program, toolRef, input)
359359
if err != nil {
360360
return nil, nil, err
361361
}
@@ -867,7 +867,7 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
867867
}
868868

869869
var (
870-
cred *credentials.Credential
870+
c *credentials.Credential
871871
exists bool
872872
)
873873

@@ -879,25 +879,39 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
879879
// Only try to look up the cred if the tool is on GitHub or has an alias.
880880
// If it is a GitHub tool and has an alias, the alias overrides the tool name, so we use it as the credential name.
881881
if isGitHubTool(toolName) && credentialAlias == "" {
882-
cred, exists, err = r.credStore.Get(toolName)
882+
c, exists, err = r.credStore.Get(toolName)
883883
if err != nil {
884884
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", toolName, err)
885885
}
886886
} else if credentialAlias != "" {
887-
cred, exists, err = r.credStore.Get(credentialAlias)
887+
c, exists, err = r.credStore.Get(credentialAlias)
888888
if err != nil {
889889
return nil, fmt.Errorf("failed to get credentials for tool %s: %w", credentialAlias, err)
890890
}
891891
}
892892

893+
if c == nil {
894+
c = &credentials.Credential{}
895+
}
896+
893897
// If the credential doesn't already exist in the store, run the credential tool in order to get the value,
894898
// and save it in the store.
895-
if !exists {
899+
if !exists || c.IsExpired() {
896900
credToolRefs, ok := callCtx.Tool.ToolMapping[credToolName]
897901
if !ok || len(credToolRefs) != 1 {
898902
return nil, fmt.Errorf("failed to find ID for tool %s", credToolName)
899903
}
900904

905+
// If the existing credential is expired, we need to provide it to the cred tool through the environment.
906+
if exists && c.IsExpired() {
907+
credJSON, err := json.Marshal(c)
908+
if err != nil {
909+
return nil, fmt.Errorf("failed to marshal credential: %w", err)
910+
}
911+
env = append(env, fmt.Sprintf("%s=%s", credentials.ExistingCredential, string(credJSON)))
912+
}
913+
914+
// Get the input for the credential tool, if there is any.
901915
var input string
902916
if args != nil {
903917
inputBytes, err := json.Marshal(args)
@@ -916,21 +930,14 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
916930
return nil, fmt.Errorf("invalid state: credential tool [%s] can not result in a continuation", credToolName)
917931
}
918932

919-
var envMap struct {
920-
Env map[string]string `json:"env"`
921-
}
922-
if err := json.Unmarshal([]byte(*res.Result), &envMap); err != nil {
933+
if err := json.Unmarshal([]byte(*res.Result), &c); err != nil {
923934
return nil, fmt.Errorf("failed to unmarshal credential tool %s response: %w", credToolName, err)
924935
}
925-
926-
cred = &credentials.Credential{
927-
Type: credentials.CredentialTypeTool,
928-
Env: envMap.Env,
929-
ToolName: credName,
930-
}
936+
c.ToolName = credName
937+
c.Type = credentials.CredentialTypeTool
931938

932939
isEmpty := true
933-
for _, v := range cred.Env {
940+
for _, v := range c.Env {
934941
if v != "" {
935942
isEmpty = false
936943
break
@@ -941,15 +948,15 @@ func (r *Runner) handleCredentials(callCtx engine.Context, monitor Monitor, env
941948
if (isGitHubTool(toolName) && callCtx.Program.ToolSet[credToolRefs[0].ToolID].Source.Repo != nil) || credentialAlias != "" {
942949
if isEmpty {
943950
log.Warnf("Not saving empty credential for tool %s", toolName)
944-
} else if err := r.credStore.Add(*cred); err != nil {
951+
} else if err := r.credStore.Add(*c); err != nil {
945952
return nil, fmt.Errorf("failed to add credential for tool %s: %w", toolName, err)
946953
}
947954
} else {
948955
log.Warnf("Not saving credential for tool %s - credentials will only be saved for tools from GitHub, or tools that use aliases.", toolName)
949956
}
950957
}
951958

952-
for k, v := range cred.Env {
959+
for k, v := range c.Env {
953960
env = append(env, fmt.Sprintf("%s=%s", k, v))
954961
}
955962
}

0 commit comments

Comments
 (0)