Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 15 additions & 5 deletions lean/components/api/auth0_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,50 @@ def __init__(self, api_client: 'APIClient') -> None:
self._api = api_client
self._cache = {}

def read(self, brokerage_id: str) -> QCAuth0Authorization:
def read(self, brokerage_id: str, user_name: str = None) -> QCAuth0Authorization:
"""Reads the authorization data for a brokerage.

:param brokerage_id: the id of the brokerage to read the authorization data for
:param user_name: the optional login ID of the user
:return: the authorization data for the specified brokerage
"""
try:
# First check cache
if brokerage_id in self._cache.keys():
return self._cache[brokerage_id]
if user_name:
cache_key = (brokerage_id, user_name)
else:
cache_key = brokerage_id
if cache_key in self._cache:
return self._cache[cache_key]
payload = {
"brokerage": brokerage_id
}
if user_name:
payload["userId"] = user_name

data = self._api.post("live/auth0/read", payload)
# Store in cache
result = QCAuth0Authorization(**data)
self._cache[brokerage_id] = result
self._cache[cache_key] = result
return result
except RequestFailedError as e:
return QCAuth0Authorization(authorization=None)

@staticmethod
def authorize(brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False) -> None:
def authorize(brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False, user_name: str = None) -> None:
"""Starts the authorization process for a brokerage.

:param brokerage_id: the id of the brokerage to start the authorization process for
:param logger: the logger instance to use
:param project_id: The local or cloud project_id
:param user_name: the optional login ID of the user to pre-fill in the authorization page
:param no_browser: whether to disable opening the browser
"""
from webbrowser import open

full_url = f"{API_BASE_URL}live/auth0/authorize?brokerage={brokerage_id}&projectId={project_id}"
if user_name:
full_url += f"&userId={user_name}"

logger.info(f"Please open the following URL in your browser to authorize the LEAN CLI.")
logger.info(full_url)
Expand Down
8 changes: 4 additions & 4 deletions lean/components/util/auth0_helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from lean.components.util.logger import Logger


def get_authorization(auth0_client: Auth0Client, brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False) -> QCAuth0Authorization:
def get_authorization(auth0_client: Auth0Client, brokerage_id: str, logger: Logger, project_id: int, no_browser: bool = False, user_name: str = None) -> QCAuth0Authorization:
"""Gets the authorization data for a brokerage, authorizing if necessary.

:param auth0_client: An instance of Auth0Client, containing methods to interact with live/auth0/* API endpoints.
Expand All @@ -28,18 +28,18 @@ def get_authorization(auth0_client: Auth0Client, brokerage_id: str, logger: Logg
"""
from time import time, sleep

data = auth0_client.read(brokerage_id)
data = auth0_client.read(brokerage_id, user_name=user_name)
if data.authorization is not None:
return data

start_time = time()
auth0_client.authorize(brokerage_id, logger, project_id, no_browser)
auth0_client.authorize(brokerage_id, logger, project_id, no_browser, user_name=user_name)

# keep checking for new data every 5 seconds for 7 minutes
while time() - start_time < 420:
logger.debug("Will sleep 5 seconds and retry fetching authorization...")
sleep(5)
data = auth0_client.read(brokerage_id)
data = auth0_client.read(brokerage_id, user_name=user_name)
if data.authorization is None:
continue
return data
Expand Down
1 change: 1 addition & 0 deletions lean/models/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,7 @@ class AuthConfiguration(InternalInputUserInput):
def __init__(self, config_json_object):
super().__init__(config_json_object)
self.require_project_id = config_json_object.get("require-project-id", False)
self.require_user_name = config_json_object.get("require-user-name", False)

def factory(config_json_object) -> 'AuthConfiguration':
"""Creates an instance of the child classes.
Expand Down
36 changes: 35 additions & 1 deletion lean/models/json_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,30 @@ def convert_variable_to_lean_key(self, variable_key: str) -> str:
"""
return variable_key.replace('_', '-')

def get_user_name(self, lean_config: Dict[str, Any], configuration, user_provided_options: Dict[str, Any], require_user_name: bool) -> str:
"""Retrieve the user name, prompting the user if required and not already set.

:param lean_config: The Lean config dict to read defaults from.
:param configuration: The AuthConfiguration instance.
:param user_provided_options: Options passed as command-line arguments.
:param require_user_name: Flag to determine if prompting is necessary.
:return: The user name, or None if not required.
"""
if not require_user_name:
return None
from click import prompt
user_name_key = configuration._id.replace("-oauth-token", "") + "-user-name"
user_name_variable = self.convert_lean_key_to_variable(user_name_key)
if user_name_variable in user_provided_options and user_provided_options[user_name_variable]:
return user_provided_options[user_name_variable]
if lean_config and lean_config.get(user_name_key):
return lean_config[user_name_key]
user_name = prompt("Please enter your Login ID to proceed with Auth0 authentication",
show_default=False)
if lean_config is not None:
lean_config[user_name_key] = user_name
return user_name

def get_project_id(self, default_project_id: int, require_project_id: bool) -> int:
"""Retrieve the project ID, prompting the user if required and default is invalid.

Expand Down Expand Up @@ -238,8 +262,12 @@ def config_build(self,
lean_config["project-id"] = self.get_project_id(lean_config["project-id"],
configuration.require_project_id)
logger.debug(f'project_id: {lean_config["project-id"]}')
user_name = self.get_user_name(lean_config, configuration, user_provided_options,
configuration.require_user_name)
logger.debug(f'user_name: {user_name}')
auth_authorizations = get_authorization(container.api_client.auth0, self._display_name.lower(),
logger, lean_config["project-id"], no_browser=no_browser)
logger, lean_config["project-id"], no_browser=no_browser,
user_name=user_name)
logger.debug(f'auth: {auth_authorizations}')
configuration._value = auth_authorizations.get_authorization_config_without_account()
for inner_config in self._lean_configs:
Expand All @@ -255,6 +283,12 @@ def config_build(self,
for account_id in api_account_ids)):
raise ValueError(f"The provided account id '{user_provide_account_id}' is not valid, "
f"available: {api_account_ids}")
existing_account = lean_config.get(inner_config._id)
if existing_account and (existing_account not in api_account_ids
or len(api_account_ids) > 1):
# Clear stale or ambiguous account so the user is prompted
# to select from the current API choices
lean_config.pop(inner_config._id)
break
continue

Expand Down
122 changes: 122 additions & 0 deletions tests/components/api/test_auth0_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from unittest import mock
from lean.constants import API_BASE_URL
from lean.components.api.api_client import APIClient
from lean.components.api.auth0_client import Auth0Client
from lean.components.util.http_client import HTTPClient


Expand Down Expand Up @@ -49,6 +50,127 @@ def test_auth0client_trade_station() -> None:
assert len(result.get_account_ids()) > 0


def test_auth0client_authorize_with_user_name() -> None:
with mock.patch("webbrowser.open") as mock_open:
Auth0Client.authorize("charles-schwab", mock.Mock(), 123, user_name="test_login")
mock_open.assert_called_once()
called_url = mock_open.call_args[0][0]
assert "&userId=test_login" in called_url


def test_auth0client_authorize_without_user_name() -> None:
with mock.patch("webbrowser.open") as mock_open:
Auth0Client.authorize("charles-schwab", mock.Mock(), 123)
mock_open.assert_called_once()
called_url = mock_open.call_args[0][0]
assert "userId" not in called_url


@responses.activate
def test_auth0client_read_with_user_name() -> None:
api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc")

responses.add(
responses.POST,
f"{API_BASE_URL}live/auth0/read",
json={
"authorization": {
"charles-schwab-access-token": "abc123",
"accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}]
},
"success": "true"},
status=200
)

result = api_clint.auth0.read("charles-schwab", user_name="test_login")

assert result
assert result.authorization
sent_body = responses.calls[0].request.body.decode()
assert "userId" in sent_body
assert "test_login" in sent_body


@responses.activate
def test_auth0client_read_without_user_name() -> None:
api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc")

responses.add(
responses.POST,
f"{API_BASE_URL}live/auth0/read",
json={
"authorization": {
"charles-schwab-access-token": "abc123",
"accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}]
},
"success": "true"},
status=200
)

result = api_clint.auth0.read("charles-schwab")

assert result
assert result.authorization
sent_body = responses.calls[0].request.body.decode()
assert "userId" not in sent_body


@responses.activate
def test_auth0client_read_caches_without_user_name() -> None:
api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc")

responses.add(
responses.POST,
f"{API_BASE_URL}live/auth0/read",
json={
"authorization": {
"charles-schwab-access-token": "abc123",
"accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}]
},
"success": "true"},
status=200
)

api_clint.auth0.read("charles-schwab")
api_clint.auth0.read("charles-schwab")

assert len(responses.calls) == 1


@responses.activate
def test_auth0client_read_caches_per_user_name() -> None:
api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc")

responses.add(
responses.POST,
f"{API_BASE_URL}live/auth0/read",
json={
"authorization": {
"charles-schwab-access-token": "abc123",
"accounts": [{"id": "ACC001", "name": "ACC001 | Individual | USD"}]
},
"success": "true"},
status=200
)
responses.add(
responses.POST,
f"{API_BASE_URL}live/auth0/read",
json={
"authorization": {
"charles-schwab-access-token": "xyz789",
"accounts": [{"id": "ACC002", "name": "ACC002 | Individual | USD"}]
},
"success": "true"},
status=200
)

api_clint.auth0.read("charles-schwab", user_name="user_a")
api_clint.auth0.read("charles-schwab", user_name="user_a") # cache hit
api_clint.auth0.read("charles-schwab", user_name="user_b") # different user — new call

assert len(responses.calls) == 2


@responses.activate
def test_auth0client_alpaca() -> None:
api_clint = APIClient(mock.Mock(), HTTPClient(mock.Mock()), user_id="123", api_token="abc")
Expand Down
Loading
Loading