Skip to content

feat: add optional audience parameter to credential exchange related methods #419

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions integration/tests/posit/connect/oauth/test_associations.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def setup_class(cls):
task = bundle.deploy()
task.wait_for()

cls.content.oauth.associations.update(cls.integration["guid"])
cls.content.oauth.associations.update([cls.integration["guid"]])

@classmethod
def teardown_class(cls):
Expand Down Expand Up @@ -102,7 +102,7 @@ def test_find_update_by_content(self):
assert associations[0]["oauth_integration_guid"] == self.integration["guid"]

# update content association to another_integration
self.content.oauth.associations.update(self.another_integration["guid"])
self.content.oauth.associations.update([self.another_integration["guid"]])
updated_associations = self.content.oauth.associations.find()
assert len(updated_associations) == 1
assert updated_associations[0]["app_guid"] == self.content["guid"]
Expand Down
15 changes: 11 additions & 4 deletions src/posit/connect/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from __future__ import annotations

from typing_extensions import TYPE_CHECKING, overload
from typing_extensions import TYPE_CHECKING, Optional, overload

from . import hooks, me
from .auth import Auth
Expand All @@ -11,7 +11,8 @@
from .context import Context, ContextManager, requires
from .groups import Groups
from .metrics.metrics import Metrics
from .oauth.oauth import OAuth, OAuthTokenType
from .oauth.oauth import OAuth
from .oauth.types import OAuthTokenType
from .resources import _PaginatedResourceSequence, _ResourceSequence
from .sessions import Session
from .system import System
Expand Down Expand Up @@ -176,7 +177,11 @@ def __init__(self, *args, **kwargs) -> None:
self._ctx = Context(self)

@requires("2025.01.0")
def with_user_session_token(self, token: str) -> Client:
def with_user_session_token(
self,
token: str,
audience: Optional[str] = None,
) -> Client:
"""Create a new Client scoped to the user specified in the user session token.

Create a new Client instance from a user session token exchange for an api key scoped to the
Expand Down Expand Up @@ -256,7 +261,9 @@ def user_profile():
raise ValueError("token must be set to non-empty string.")

visitor_credentials = self.oauth.get_credentials(
token, requested_token_type=OAuthTokenType.API_KEY
token,
requested_token_type=OAuthTokenType.API_KEY,
audience=audience,
)

visitor_api_key = visitor_credentials.get("access_token", "")
Expand Down
14 changes: 14 additions & 0 deletions src/posit/connect/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import os
import posixpath
import time

Expand Down Expand Up @@ -1000,3 +1001,16 @@ def get(self, guid: str) -> ContentItem:

response = self._ctx.client.get(f"v1/content/{guid}", params=params)
return ContentItem(self._ctx, **response.json())

@property
def current(self) -> ContentItem:
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice! Did you consider making this the default behavior for content.get()? In my personal opinion, it seems more idomatic to do content.get() vs content.current(), but I could be missing some context as to why current is better.

"""Get the content item for the current context.

Returns
-------
ContentItem
"""
guid = os.getenv("CONNECT_CONTENT_GUID")
if not guid:
raise RuntimeError("CONNECT_CONTENT_GUID environment variable is not set.")
return self.get(guid)
14 changes: 11 additions & 3 deletions src/posit/connect/external/aws.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from typing_extensions import TYPE_CHECKING, Optional, TypedDict

from ..oauth.oauth import OAuthTokenType
from ..oauth.types import OAuthTokenType

if TYPE_CHECKING:
from ..client import Client
Expand All @@ -19,7 +19,11 @@ class Credentials(TypedDict):
expiration: datetime


def get_credentials(client: Client, user_session_token: str) -> Credentials:
def get_credentials(
client: Client,
user_session_token: str,
audience: Optional[str] = None,
) -> Credentials:
"""
Get AWS credentials using OAuth token exchange for an AWS Viewer integration.

Expand Down Expand Up @@ -66,6 +70,7 @@ def get_credentials(client: Client, user_session_token: str) -> Credentials:
credentials = client.oauth.get_credentials(
user_session_token=user_session_token,
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
audience=audience,
)

# Decode base64 access token
Expand All @@ -76,7 +81,9 @@ def get_credentials(client: Client, user_session_token: str) -> Credentials:


def get_content_credentials(
client: Client, content_session_token: Optional[str] = None
client: Client,
content_session_token: Optional[str] = None,
audience: Optional[str] = None,
) -> Credentials:
"""
Get AWS credentials using OAuth token exchange for an AWS Service Account integration.
Expand Down Expand Up @@ -122,6 +129,7 @@ def get_content_credentials(
credentials = client.oauth.get_content_credentials(
content_session_token=content_session_token,
requested_token_type=OAuthTokenType.AWS_CREDENTIALS,
audience=audience,
)

# Decode base64 access token
Expand Down
33 changes: 27 additions & 6 deletions src/posit/connect/external/databricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,11 +62,16 @@ class _PositConnectContentCredentialsProvider:
* https://github.com/posit-dev/posit-sdk-py/blob/main/src/posit/connect/oauth/oauth.py
"""

def __init__(self, client: Client):
def __init__(
self,
client: Client,
audience: Optional[str] = None,
):
self._client = client
self._audience = audience

def __call__(self) -> Dict[str, str]:
credentials = self._client.oauth.get_content_credentials()
credentials = self._client.oauth.get_content_credentials(audience=self._audience)
return _new_bearer_authorization_header(credentials)


Expand All @@ -81,12 +86,21 @@ class _PositConnectViewerCredentialsProvider:
* https://github.com/posit-dev/posit-sdk-py/blob/main/src/posit/connect/oauth/oauth.py
"""

def __init__(self, client: Client, user_session_token: str):
def __init__(
self,
client: Client,
user_session_token: str,
audience: Optional[str] = None,
):
self._client = client
self._user_session_token = user_session_token
self._audience = audience

def __call__(self) -> Dict[str, str]:
credentials = self._client.oauth.get_credentials(self._user_session_token)
credentials = self._client.oauth.get_credentials(
self._user_session_token,
audience=self._audience,
)
return _new_bearer_authorization_header(credentials)


Expand Down Expand Up @@ -174,10 +188,12 @@ def __init__(
self,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
audience: Optional[str] = None,
):
self._cp: Optional[CredentialsProvider] = None
self._client = client
self._user_session_token = user_session_token
self._audience = audience

def auth_type(self) -> str:
return POSIT_OAUTH_INTEGRATION_AUTH_TYPE
Expand All @@ -194,13 +210,18 @@ def __call__(self, *args, **kwargs) -> CredentialsProvider: # noqa: ARG002
if self._cp is None:
if self._user_session_token:
self._cp = _PositConnectViewerCredentialsProvider(
self._client, self._user_session_token
self._client,
self._user_session_token,
audience=self._audience,
)
else:
logger.info(
"ConnectStrategy will attempt to use OAuth Service Account credentials because user_session_token is not set"
)
self._cp = _PositConnectContentCredentialsProvider(self._client)
self._cp = _PositConnectContentCredentialsProvider(
self._client,
audience=self._audience,
)
return self._cp


Expand Down
7 changes: 6 additions & 1 deletion src/posit/connect/external/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,10 +69,12 @@ def __init__(
local_authenticator: Optional[str] = None,
client: Optional[Client] = None,
user_session_token: Optional[str] = None,
audience: Optional[str] = None,
):
self._local_authenticator = local_authenticator
self._client = client
self._user_session_token = user_session_token
self._audience = audience

@property
def authenticator(self) -> Optional[str]:
Expand All @@ -93,5 +95,8 @@ def token(self) -> Optional[str]:
if self._client is None:
self._client = Client()

credentials = self._client.oauth.get_credentials(self._user_session_token)
credentials = self._client.oauth.get_credentials(
self._user_session_token,
audience=self._audience,
)
return credentials.get("access_token")
78 changes: 74 additions & 4 deletions src/posit/connect/oauth/associations.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
"""OAuth association resources."""

from typing_extensions import List
from __future__ import annotations

from ..context import Context
import re

from typing_extensions import TYPE_CHECKING, List, Optional

# from ..context import requires
from ..resources import BaseResource, Resources

if TYPE_CHECKING:
from ..context import Context
from ..oauth import types


class Association(BaseResource):
pass
Expand Down Expand Up @@ -59,16 +67,78 @@ def find(self) -> List[Association]:
for result in response.json()
]

# TODO turn this on before merging
# @requires("2025.07.0")
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You can do @requires("2025.06.0-dev"). It's effectively the same thing.

def find_by(
self,
integration_type: Optional[types.OAuthIntegrationType | str] = None,
auth_type: Optional[types.OAuthIntegrationAuthType | str] = None,
name: Optional[str] = None,
description: Optional[str] = None,
guid: Optional[str] = None,
) -> Association | None:
"""Find an OAuth integration associated with content by various criteria.

Parameters
----------
integration_type : Optional[types.OAuthIntegrationType | str]
The type of the integration (e.g., "aws", "azure").
auth_type : Optional[types.OAuthIntegrationAuthType | str]
The authentication type of the integration (e.g., "Viewer", "Service Account").
name : Optional[str]
A regex pattern to match the integration name. For exact matches, use `^` and `$`. For example,
`^My Integration$` will match only "My Integration".
description : Optional[str]
A regex pattern to match the integration description. For exact matches, use `^` and `$`. For example,
`^My Integration Description$` will match only "My Integration Description".
guid : Optional[str]
The unique identifier of the integration.

Returns
-------
Association | None
The first matching association, or None if no match is found.
"""
for integration in self.find():
if (
integration_type is not None
and integration.get("oauth_integration_template") != integration_type
):
continue

if (
auth_type is not None
and integration.get("oauth_integration_auth_type") != auth_type
):
continue

if name is not None:
integration_name = integration.get("oauth_integration_name", "")
if not re.search(name, integration_name):
continue

if description is not None:
integration_description = integration.get("oauth_integration_description", "")
if not re.search(description, integration_description):
continue

if guid is not None and integration.get("oauth_integration_guid") != guid:
continue

return integration

return None

def delete(self) -> None:
"""Delete integration associations."""
data = []

path = f"v1/content/{self.content_guid}/oauth/integrations/associations"
self._ctx.client.put(path, json=data)

def update(self, integration_guid: str) -> None:
def update(self, integration_guids: list[str]) -> None:
"""Set integration associations."""
data = [{"oauth_integration_guid": integration_guid}]
data = [{"oauth_integration_guid": guid} for guid in integration_guids]

path = f"v1/content/{self.content_guid}/oauth/integrations/associations"
self._ctx.client.put(path, json=data)
Loading