Skip to content

Commit 96bd5cc

Browse files
committed
feat: add credential management
Signed-off-by: Grant Linville <[email protected]>
1 parent fce5058 commit 96bd5cc

File tree

3 files changed

+125
-0
lines changed

3 files changed

+125
-0
lines changed

gptscript/credentials.py

+74
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import json
2+
from datetime import datetime, timezone
3+
from enum import Enum
4+
5+
6+
def is_timezone_aware(dt: datetime):
7+
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None
8+
9+
10+
class CredentialType(Enum):
11+
Tool = "tool",
12+
ModelProvider = "modelProvider"
13+
14+
15+
class Credential:
16+
def __init__(self,
17+
context: str = "default",
18+
toolName: str = "",
19+
type: CredentialType = CredentialType.Tool,
20+
env: dict[str, str] = None,
21+
ephemeral: bool = False,
22+
expiresAt: datetime = None,
23+
refreshToken: str = "",
24+
):
25+
self.context = context
26+
self.toolName = toolName
27+
self.type = type
28+
self.env = env
29+
self.ephemeral = ephemeral
30+
self.expiresAt = expiresAt
31+
self.refreshToken = refreshToken
32+
33+
if self.env is None:
34+
self.env = {}
35+
36+
def to_json(self):
37+
datetime_str = ""
38+
39+
if self.expiresAt is not None:
40+
system_tz = datetime.now().astimezone().tzinfo
41+
42+
if not is_timezone_aware(self.expiresAt):
43+
self.expiresAt = self.expiresAt.replace(tzinfo=system_tz)
44+
datetime_str = self.expiresAt.isoformat()
45+
46+
# For UTC only, replace the "+00:00" with "Z"
47+
if self.expiresAt.tzinfo == timezone.utc:
48+
datetime_str = datetime_str.replace("+00:00", "Z")
49+
50+
req = {
51+
"context": self.context,
52+
"toolName": self.toolName,
53+
"type": self.type.value[0],
54+
"env": self.env,
55+
"ephemeral": self.ephemeral,
56+
"refreshToken": self.refreshToken,
57+
}
58+
59+
if datetime_str != "":
60+
req["expiresAt"] = datetime_str
61+
62+
return json.dumps(req)
63+
64+
class CredentialRequest:
65+
def __init__(self,
66+
content: str = "",
67+
allContexts: bool = False,
68+
context: str = "default",
69+
name: str = "",
70+
):
71+
self.content = content
72+
self.allContexts = allContexts
73+
self.context = context
74+
self.name = name

gptscript/gptscript.py

+33
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
import requests
1010

1111
from gptscript.confirm import AuthResponse
12+
from gptscript.credentials import Credential
1213
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
1314
from gptscript.opts import GlobalOptions
1415
from gptscript.prompt import PromptResponse
@@ -183,6 +184,38 @@ async def list_models(self, providers: list[str] = None, credential_overrides: l
183184
{"providers": providers, "credentialOverrides": credential_overrides}
184185
)).split("\n")
185186

187+
async def list_credentials(self, context: str = "default", all_contexts: bool = False) -> list[Credential] | str:
188+
res = await self._run_basic_command(
189+
"credentials",
190+
{"context": context, "allContexts": all_contexts}
191+
)
192+
if res.startswith("an error occurred:"):
193+
return res
194+
195+
return [Credential(**c) for c in json.loads(res)]
196+
197+
async def create_credential(self, cred: Credential) -> str:
198+
return await self._run_basic_command(
199+
"credentials/create",
200+
{"content": cred.to_json()}
201+
)
202+
203+
async def reveal_credential(self, context: str = "default", name: str = "") -> Credential | str:
204+
res = await self._run_basic_command(
205+
"credentials/reveal",
206+
{"context": context, "name": name}
207+
)
208+
if res.startswith("an error occurred:"):
209+
return res
210+
211+
return Credential(**json.loads(res))
212+
213+
async def delete_credential(self, context: str = "default", name: str = "") -> str:
214+
return await self._run_basic_command(
215+
"credentials/delete",
216+
{"context": context, "name": name}
217+
)
218+
186219

187220
def _get_command():
188221
if os.getenv("GPTSCRIPT_BIN") is not None:

tests/test_gptscript.py

+18
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import pytest
99

1010
from gptscript.confirm import AuthResponse
11+
from gptscript.credentials import Credential
1112
from gptscript.exec_utils import get_env
1213
from gptscript.frame import RunEventType, CallFrame, RunFrame, RunState, PromptFrame
1314
from gptscript.gptscript import GPTScript
@@ -683,3 +684,20 @@ async def test_parse_with_metadata_then_run(gptscript):
683684
tools = await gptscript.parse(cwd + "/tests/fixtures/parse-with-metadata.gpt")
684685
run = gptscript.evaluate(tools[0])
685686
assert "200" == await run.text(), "Expect file to have correct output"
687+
688+
@pytest.mark.asyncio
689+
async def test_credentials(gptscript):
690+
name = "test-" + str(os.urandom(4).hex())
691+
res = await gptscript.create_credential(Credential(toolName=name, env={"TEST": "test"}))
692+
assert not res.startswith("an error occurred"), "Unexpected error creating credential: " + res
693+
694+
res = await gptscript.list_credentials()
695+
assert not str(res).startswith("an error occurred"), "Unexpected error listing credentials: " + str(res)
696+
assert len(res) > 0, "Expected at least one credential"
697+
698+
res = await gptscript.reveal_credential(name=name)
699+
assert not str(res).startswith("an error occurred"), "Unexpected error revealing credential: " + res
700+
assert res.env["TEST"] == "test", "Unexpected credential value: " + str(res)
701+
702+
res = await gptscript.delete_credential(name=name)
703+
assert not res.startswith("an error occurred"), "Unexpected error deleting credential: " + res

0 commit comments

Comments
 (0)