Skip to content

Commit e030247

Browse files
authored
feat: add credential management (#52)
Signed-off-by: Grant Linville <[email protected]>
1 parent fce5058 commit e030247

File tree

3 files changed

+161
-1
lines changed

3 files changed

+161
-1
lines changed

gptscript/credentials.py

+95
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
import json
2+
from datetime import datetime, timezone
3+
from enum import Enum
4+
from typing import List
5+
6+
7+
def is_timezone_aware(dt: datetime):
8+
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
9+
10+
11+
class CredentialType(Enum):
12+
tool = "tool",
13+
modelProvider = "modelProvider"
14+
15+
16+
class Credential:
17+
def __init__(self,
18+
context: str = "default",
19+
toolName: str = "",
20+
type: CredentialType = CredentialType.tool,
21+
env: dict[str, str] = None,
22+
ephemeral: bool = False,
23+
expiresAt: datetime = None,
24+
refreshToken: str = "",
25+
):
26+
self.context = context
27+
self.toolName = toolName
28+
self.type = type
29+
self.env = env
30+
self.ephemeral = ephemeral
31+
self.expiresAt = expiresAt
32+
self.refreshToken = refreshToken
33+
34+
if self.env is None:
35+
self.env = {}
36+
37+
def to_json(self):
38+
datetime_str = ""
39+
40+
if self.expiresAt is not None:
41+
system_tz = datetime.now().astimezone().tzinfo
42+
43+
if not is_timezone_aware(self.expiresAt):
44+
self.expiresAt = self.expiresAt.replace(tzinfo=system_tz)
45+
datetime_str = self.expiresAt.isoformat()
46+
47+
# For UTC only, replace the "+00:00" with "Z"
48+
if self.expiresAt.tzinfo == timezone.utc:
49+
datetime_str = datetime_str.replace("+00:00", "Z")
50+
51+
req = {
52+
"context": self.context,
53+
"toolName": self.toolName,
54+
"type": self.type.value[0],
55+
"env": self.env,
56+
"ephemeral": self.ephemeral,
57+
"refreshToken": self.refreshToken,
58+
}
59+
60+
if datetime_str != "":
61+
req["expiresAt"] = datetime_str
62+
63+
return json.dumps(req)
64+
65+
class CredentialRequest:
66+
def __init__(self,
67+
content: str = "",
68+
allContexts: bool = False,
69+
contexts: List[str] = None,
70+
name: str = "",
71+
):
72+
if contexts is None:
73+
contexts = ["default"]
74+
75+
self.content = content
76+
self.allContexts = allContexts
77+
self.contexts = contexts
78+
self.name = name
79+
80+
def to_credential(c) -> Credential:
81+
expiresAt = c["expiresAt"]
82+
if expiresAt is not None:
83+
if expiresAt.endswith("Z"):
84+
expiresAt = expiresAt.replace("Z", "+00:00")
85+
expiresAt = datetime.fromisoformat(expiresAt)
86+
87+
return Credential(
88+
context=c["context"],
89+
toolName=c["toolName"],
90+
type=CredentialType[c["type"]],
91+
env=c["env"],
92+
ephemeral=c.get("ephemeral", False),
93+
expiresAt=expiresAt,
94+
refreshToken=c["refreshToken"],
95+
)

gptscript/gptscript.py

+40-1
Original file line numberDiff line numberDiff line change
@@ -4,11 +4,12 @@
44
from subprocess import Popen, PIPE
55
from sys import executable
66
from time import sleep
7-
from typing import Any, Callable, Awaitable
7+
from typing import Any, Callable, Awaitable, List
88

99
import requests
1010

1111
from gptscript.confirm import AuthResponse
12+
from gptscript.credentials import Credential, to_credential
1213
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
1314
from gptscript.opts import GlobalOptions
1415
from gptscript.prompt import PromptResponse
@@ -183,6 +184,44 @@ async def list_models(self, providers: list[str] = None, credential_overrides: l
183184
{"providers": providers, "credentialOverrides": credential_overrides}
184185
)).split("\n")
185186

187+
async def list_credentials(self, contexts: List[str] = None, all_contexts: bool = False) -> list[Credential] | str:
188+
if contexts is None:
189+
contexts = ["default"]
190+
191+
res = await self._run_basic_command(
192+
"credentials",
193+
{"context": contexts, "allContexts": all_contexts}
194+
)
195+
if res.startswith("an error occurred:"):
196+
return res
197+
198+
return [to_credential(cred) for cred in json.loads(res)]
199+
200+
async def create_credential(self, cred: Credential) -> str:
201+
return await self._run_basic_command(
202+
"credentials/create",
203+
{"content": cred.to_json()}
204+
)
205+
206+
async def reveal_credential(self, contexts: List[str] = None, name: str = "") -> Credential | str:
207+
if contexts is None:
208+
contexts = ["default"]
209+
210+
res = await self._run_basic_command(
211+
"credentials/reveal",
212+
{"context": contexts, "name": name}
213+
)
214+
if res.startswith("an error occurred:"):
215+
return res
216+
217+
return to_credential(json.loads(res))
218+
219+
async def delete_credential(self, context: str = "default", name: str = "") -> str:
220+
return await self._run_basic_command(
221+
"credentials/delete",
222+
{"context": [context], "name": name}
223+
)
224+
186225

187226
def _get_command():
188227
if os.getenv("GPTSCRIPT_BIN") is not None:

tests/test_gptscript.py

+26
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,13 @@
44
import os
55
import platform
66
import subprocess
7+
from datetime import datetime, timedelta, timezone
8+
from time import sleep
79

810
import pytest
911

1012
from gptscript.confirm import AuthResponse
13+
from gptscript.credentials import Credential
1114
from gptscript.exec_utils import get_env
1215
from gptscript.frame import RunEventType, CallFrame, RunFrame, RunState, PromptFrame
1316
from gptscript.gptscript import GPTScript
@@ -683,3 +686,26 @@ async def test_parse_with_metadata_then_run(gptscript):
683686
tools = await gptscript.parse(cwd + "/tests/fixtures/parse-with-metadata.gpt")
684687
run = gptscript.evaluate(tools[0])
685688
assert "200" == await run.text(), "Expect file to have correct output"
689+
690+
@pytest.mark.asyncio
691+
async def test_credentials(gptscript):
692+
name = "test-" + str(os.urandom(4).hex())
693+
now = datetime.now()
694+
res = await gptscript.create_credential(Credential(toolName=name, env={"TEST": "test"}, expiresAt=now + timedelta(seconds=5)))
695+
assert not res.startswith("an error occurred"), "Unexpected error creating credential: " + res
696+
697+
sleep(5)
698+
699+
res = await gptscript.list_credentials()
700+
assert not str(res).startswith("an error occurred"), "Unexpected error listing credentials: " + str(res)
701+
assert len(res) > 0, "Expected at least one credential"
702+
for cred in res:
703+
if cred.toolName == name:
704+
assert cred.expiresAt < datetime.now(timezone.utc), "Expected credential to have expired"
705+
706+
res = await gptscript.reveal_credential(name=name)
707+
assert not str(res).startswith("an error occurred"), "Unexpected error revealing credential: " + res
708+
assert res.env["TEST"] == "test", "Unexpected credential value: " + str(res)
709+
710+
res = await gptscript.delete_credential(name=name)
711+
assert not res.startswith("an error occurred"), "Unexpected error deleting credential: " + res

0 commit comments

Comments
 (0)