Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 93 additions & 29 deletions tests/test_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,29 +34,44 @@ def TEST_CONSTANTS():
encryption_algorithm=serialization.NoEncryption(),
)

current_datetime = datetime.now(timezone.utc)
current_timestamp = str(current_datetime)

token_claims = {
"sid": "session_123",
"org_id": "organization_123",
"role": "admin",
"permissions": ["read"],
"entitlements": ["feature_1"],
"exp": int(current_datetime.timestamp()) + 3600,
"iat": int(current_datetime.timestamp()),
}

user_id = "user_123"

return {
"COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=",
"SESSION_DATA": "session_data",
"CLIENT_ID": "client_123",
"USER_ID": "user_123",
"USER_ID": user_id,
"SESSION_ID": "session_123",
"ORGANIZATION_ID": "organization_123",
"CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)),
"CURRENT_DATETIME": current_datetime,
"CURRENT_TIMESTAMP": current_timestamp,
"PRIVATE_KEY": private_pem,
"PUBLIC_KEY": public_key,
"TEST_TOKEN": jwt.encode(
{
"sid": "session_123",
"org_id": "organization_123",
"role": "admin",
"permissions": ["read"],
"entitlements": ["feature_1"],
"exp": int(datetime.now(timezone.utc).timestamp()) + 3600,
"iat": int(datetime.now(timezone.utc).timestamp()),
},
private_pem,
algorithm="RS256",
),
"TEST_TOKEN": jwt.encode(token_claims, private_pem, algorithm="RS256"),
"TEST_TOKEN_CLAIMS": token_claims,
"TEST_USER": {
"object": "user",
"id": user_id,
"email": "[email protected]",
"first_name": "Test",
"last_name": "User",
"email_verified": True,
"created_at": current_timestamp,
"updated_at": current_timestamp,
},
}


Expand Down Expand Up @@ -145,6 +160,30 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management):
assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT


@with_jwks_mock
def test_authenticate_jwt_with_aud_claim(TEST_CONSTANTS, mock_user_management):
access_token = jwt.encode(
{**TEST_CONSTANTS["TEST_TOKEN_CLAIMS"], **{"aud": TEST_CONSTANTS["CLIENT_ID"]}},
TEST_CONSTANTS["PRIVATE_KEY"],
algorithm="RS256",
)

session_data = Session.seal_data(
{"access_token": access_token, "user": TEST_CONSTANTS["TEST_USER"]},
TEST_CONSTANTS["COOKIE_PASSWORD"],
)
session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=session_data,
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

response = session.authenticate()

assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse)


@with_jwks_mock
def test_authenticate_success(TEST_CONSTANTS, mock_user_management):
session = Session(
Expand Down Expand Up @@ -229,27 +268,16 @@ def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management):

@with_jwks_mock
def test_refresh_success(TEST_CONSTANTS, mock_user_management):
test_user = {
"object": "user",
"id": TEST_CONSTANTS["USER_ID"],
"email": "[email protected]",
"first_name": "Test",
"last_name": "User",
"email_verified": True,
"created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
"updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"],
}

session_data = Session.seal_data(
{"refresh_token": "refresh_token_12345", "user": test_user},
{"refresh_token": "refresh_token_12345", "user": TEST_CONSTANTS["TEST_USER"]},
TEST_CONSTANTS["COOKIE_PASSWORD"],
)

mock_response = {
"access_token": TEST_CONSTANTS["TEST_TOKEN"],
"refresh_token": "refresh_token_123",
"sealed_session": session_data,
"user": test_user,
"user": TEST_CONSTANTS["TEST_USER"],
}

mock_user_management.authenticate_with_refresh_token.return_value = (
Expand Down Expand Up @@ -278,7 +306,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):

assert isinstance(response, RefreshWithSessionCookieSuccessResponse)
assert response.authenticated is True
assert response.user.id == test_user["id"]
assert response.user.id == TEST_CONSTANTS["TEST_USER"]["id"]

# Verify the refresh token was used correctly
mock_user_management.authenticate_with_refresh_token.assert_called_once_with(
Expand All @@ -291,6 +319,42 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management):
)


@with_jwks_mock
def test_refresh_success_with_aud_claim(TEST_CONSTANTS, mock_user_management):
session_data = Session.seal_data(
{"refresh_token": "refresh_token_12345", "user": TEST_CONSTANTS["TEST_USER"]},
TEST_CONSTANTS["COOKIE_PASSWORD"],
)

access_token = jwt.encode(
{**TEST_CONSTANTS["TEST_TOKEN_CLAIMS"], **{"aud": TEST_CONSTANTS["CLIENT_ID"]}},
TEST_CONSTANTS["PRIVATE_KEY"],
algorithm="RS256",
)

mock_response = {
"access_token": access_token,
"refresh_token": "refresh_token_123",
"sealed_session": session_data,
"user": TEST_CONSTANTS["TEST_USER"],
}

mock_user_management.authenticate_with_refresh_token.return_value = (
RefreshTokenAuthenticationResponse(**mock_response)
)

session = Session(
user_management=mock_user_management,
client_id=TEST_CONSTANTS["CLIENT_ID"],
session_data=session_data,
cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"],
)

response = session.refresh()

assert isinstance(response, RefreshWithSessionCookieSuccessResponse)


def test_seal_data(TEST_CONSTANTS):
test_data = {"test": "data"}
sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"])
Expand Down
13 changes: 11 additions & 2 deletions workos/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,10 @@ def authenticate(

signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"])
decoded = jwt.decode(
session["access_token"], signing_key.key, algorithms=self.jwk_algorithms
session["access_token"],
signing_key.key,
algorithms=self.jwk_algorithms,
options={"verify_aud": False},
)

return AuthenticateWithSessionCookieSuccessResponse(
Expand Down Expand Up @@ -141,6 +144,7 @@ def refresh(
auth_response.access_token,
signing_key.key,
algorithms=self.jwk_algorithms,
options={"verify_aud": False},
)

return RefreshWithSessionCookieSuccessResponse(
Expand Down Expand Up @@ -176,7 +180,12 @@ def get_logout_url(self, return_to: Optional[str] = None) -> str:
def _is_valid_jwt(self, token: str) -> bool:
try:
signing_key = self.jwks.get_signing_key_from_jwt(token)
jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms)
jwt.decode(
token,
signing_key.key,
algorithms=self.jwk_algorithms,
options={"verify_aud": False},
)
return True
except jwt.exceptions.InvalidTokenError:
return False
Expand Down
Loading