Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: add credential management #52

Merged
merged 3 commits into from
Sep 18, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions gptscript/credentials.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import json
from datetime import datetime, timezone
from enum import Enum
from typing import List


def is_timezone_aware(dt: datetime):
return dt.tzinfo is not None and dt.tzinfo.utcoffset(dt) is not None


class CredentialType(Enum):
tool = "tool",
modelProvider = "modelProvider"


class Credential:
def __init__(self,
context: str = "default",
toolName: str = "",
type: CredentialType = CredentialType.tool,
env: dict[str, str] = None,
ephemeral: bool = False,
expiresAt: datetime = None,
refreshToken: str = "",
):
self.context = context
self.toolName = toolName
self.type = type
self.env = env
self.ephemeral = ephemeral
self.expiresAt = expiresAt
self.refreshToken = refreshToken

if self.env is None:
self.env = {}

def to_json(self):
datetime_str = ""

if self.expiresAt is not None:
system_tz = datetime.now().astimezone().tzinfo

if not is_timezone_aware(self.expiresAt):
self.expiresAt = self.expiresAt.replace(tzinfo=system_tz)
datetime_str = self.expiresAt.isoformat()

# For UTC only, replace the "+00:00" with "Z"
if self.expiresAt.tzinfo == timezone.utc:
datetime_str = datetime_str.replace("+00:00", "Z")

req = {
"context": self.context,
"toolName": self.toolName,
"type": self.type.value[0],
"env": self.env,
"ephemeral": self.ephemeral,
"refreshToken": self.refreshToken,
}

if datetime_str != "":
req["expiresAt"] = datetime_str

return json.dumps(req)

class CredentialRequest:
def __init__(self,
content: str = "",
allContexts: bool = False,
contexts: List[str] = None,
name: str = "",
):
if contexts is None:
contexts = ["default"]

self.content = content
self.allContexts = allContexts
self.contexts = contexts
self.name = name

def to_credential(c) -> Credential:
expiresAt = c["expiresAt"]
if expiresAt is not None:
if expiresAt.endswith("Z"):
expiresAt = expiresAt.replace("Z", "+00:00")
expiresAt = datetime.fromisoformat(expiresAt)

return Credential(
context=c["context"],
toolName=c["toolName"],
type=CredentialType[c["type"]],
env=c["env"],
ephemeral=c.get("ephemeral", False),
expiresAt=expiresAt,
refreshToken=c["refreshToken"],
)
41 changes: 40 additions & 1 deletion gptscript/gptscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
from subprocess import Popen, PIPE
from sys import executable
from time import sleep
from typing import Any, Callable, Awaitable
from typing import Any, Callable, Awaitable, List

import requests

from gptscript.confirm import AuthResponse
from gptscript.credentials import Credential, to_credential
from gptscript.frame import RunFrame, CallFrame, PromptFrame, Program
from gptscript.opts import GlobalOptions
from gptscript.prompt import PromptResponse
Expand Down Expand Up @@ -183,6 +184,44 @@ async def list_models(self, providers: list[str] = None, credential_overrides: l
{"providers": providers, "credentialOverrides": credential_overrides}
)).split("\n")

async def list_credentials(self, contexts: List[str] = None, all_contexts: bool = False) -> list[Credential] | str:
if contexts is None:
contexts = ["default"]

res = await self._run_basic_command(
"credentials",
{"context": contexts, "allContexts": all_contexts}
)
if res.startswith("an error occurred:"):
return res

return [to_credential(cred) for cred in json.loads(res)]

async def create_credential(self, cred: Credential) -> str:
return await self._run_basic_command(
"credentials/create",
{"content": cred.to_json()}
)

async def reveal_credential(self, contexts: List[str] = None, name: str = "") -> Credential | str:
if contexts is None:
contexts = ["default"]

res = await self._run_basic_command(
"credentials/reveal",
{"context": contexts, "name": name}
)
if res.startswith("an error occurred:"):
return res

return to_credential(json.loads(res))

async def delete_credential(self, context: str = "default", name: str = "") -> str:
return await self._run_basic_command(
"credentials/delete",
{"context": [context], "name": name}
)


def _get_command():
if os.getenv("GPTSCRIPT_BIN") is not None:
Expand Down
26 changes: 26 additions & 0 deletions tests/test_gptscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,13 @@
import os
import platform
import subprocess
from datetime import datetime, timedelta, timezone
from time import sleep

import pytest

from gptscript.confirm import AuthResponse
from gptscript.credentials import Credential
from gptscript.exec_utils import get_env
from gptscript.frame import RunEventType, CallFrame, RunFrame, RunState, PromptFrame
from gptscript.gptscript import GPTScript
Expand Down Expand Up @@ -683,3 +686,26 @@ async def test_parse_with_metadata_then_run(gptscript):
tools = await gptscript.parse(cwd + "/tests/fixtures/parse-with-metadata.gpt")
run = gptscript.evaluate(tools[0])
assert "200" == await run.text(), "Expect file to have correct output"

@pytest.mark.asyncio
async def test_credentials(gptscript):
name = "test-" + str(os.urandom(4).hex())
now = datetime.now()
res = await gptscript.create_credential(Credential(toolName=name, env={"TEST": "test"}, expiresAt=now + timedelta(seconds=5)))
assert not res.startswith("an error occurred"), "Unexpected error creating credential: " + res

sleep(5)

res = await gptscript.list_credentials()
assert not str(res).startswith("an error occurred"), "Unexpected error listing credentials: " + str(res)
assert len(res) > 0, "Expected at least one credential"
for cred in res:
if cred.toolName == name:
assert cred.expiresAt < datetime.now(timezone.utc), "Expected credential to have expired"

res = await gptscript.reveal_credential(name=name)
assert not str(res).startswith("an error occurred"), "Unexpected error revealing credential: " + res
assert res.env["TEST"] == "test", "Unexpected credential value: " + str(res)

res = await gptscript.delete_credential(name=name)
assert not res.startswith("an error occurred"), "Unexpected error deleting credential: " + res