Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 27 additions & 9 deletions graphistry/arrow_uploader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import List, Optional, Dict, Any

import io, pyarrow as pa, requests, sys
import base64, io, json, pyarrow as pa, requests, sys

from graphistry.privacy import Mode, Privacy, ModeAction
from graphistry.otel import inject_trace_headers
Expand All @@ -21,6 +21,19 @@
from graphistry.models.types import ValidationParam
logger = setup_logger(__name__)


def _personal_org_from_jwt(token: str) -> Optional[str]:
"""Decode JWT payload (no verification) to extract username as personal-org slug."""
try:
parts = token.split('.')
if len(parts) < 2:
return None
segment = parts[1] + '=' * (4 - len(parts[1]) % 4)
return json.loads(base64.urlsafe_b64decode(segment)).get('username')
except Exception:
return None


class ArrowUploader:

def __init__(
Expand Down Expand Up @@ -428,15 +441,20 @@ def sso_get_token(self, state):
self.token = token_value

active_org = data.get('active_organization')
if not active_org or not active_org.get('slug'):
raise Exception(
"SSO response missing active organization; see graphistry/graphistry#2933"
)
slug = active_org.get('slug') if isinstance(active_org, dict) else None

if not slug:
# New SSO users may have no active_org yet; fall back to personal org (slug == username)
slug = _personal_org_from_jwt(token_value)
if slug:
logger.info("SSO response missing active_organization; falling back to personal org: %s", slug)
else:
logger.warning("SSO response missing active_organization and JWT has no username; proceeding without org")

slug = active_org['slug']
logger.debug("@ArrowUploader.sso_get_token, org_name: %s", slug)
self.org_name = slug
self._switch_org(slug, token_value or self.token)
if slug:
logger.debug("@ArrowUploader.sso_get_token, org_name: %s", slug)
self.org_name = slug
self._switch_org(slug, token_value)

except Exception as e:
logger.error('Unexpected SSO authentication error: %s', out, exc_info=True)
Expand Down
34 changes: 28 additions & 6 deletions graphistry/tests/test_arrow_uploader.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-

import graphistry, pandas as pd, pytest, unittest
import base64, graphistry, json, pandas as pd, pytest, unittest
try:
import mock # type: ignore
except ImportError: # pragma: no cover - fallback for stdlib-only envs
Expand Down Expand Up @@ -473,19 +473,41 @@ def test_sso_login_get_sso_token_ok(self, mock_get):
mock_switch.assert_called_once_with('mock-org', '123')

@mock.patch('requests.get')
def test_sso_get_token_missing_org_raises(self, mock_get):
def test_sso_get_token_missing_org_falls_back_to_personal(self, mock_get):
payload = base64.urlsafe_b64encode(
json.dumps({'user_id': 1, 'username': 'testuser', 'exp': 9999999999}).encode()
).rstrip(b'=').decode()
fake_token = f"eyJhbGciOiJIUzI1NiJ9.{payload}.fakesig"

mock_resp = self._mock_response(
json_data={
'status': 'OK',
'message': 'State is valid',
'data': {
'token': '123',
}
'data': {'token': fake_token},
})
mock_get.return_value = mock_resp

au = ArrowUploader()
with mock.patch.object(ArrowUploader, "_switch_org") as mock_switch:
au.sso_get_token(state='ignored-valid')

with pytest.raises(Exception):
assert au.token == fake_token
assert au.org_name == 'testuser'
mock_switch.assert_called_once_with('testuser', fake_token)

@mock.patch('requests.get')
def test_sso_get_token_missing_org_no_username_in_jwt(self, mock_get):
mock_resp = self._mock_response(
json_data={
'status': 'OK',
'message': 'State is valid',
'data': {'token': '123'},
})
mock_get.return_value = mock_resp

au = ArrowUploader()
with mock.patch.object(ArrowUploader, "_switch_org") as mock_switch:
au.sso_get_token(state='ignored-valid')

assert au.token == '123'
mock_switch.assert_not_called()
Loading