Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add webdriver for oauth #181

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 76 additions & 5 deletions td/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import requests
import urllib.parse

from authlib.integrations.httpx_client import OAuth2Client

from typing import Any
from typing import Dict
from typing import List
Expand Down Expand Up @@ -52,7 +54,7 @@ class TDClient():
"""

def __init__(self, client_id: str, redirect_uri: str, account_number: str = None, credentials_path: str = None,
auth_flow: str = 'default', _do_init: bool = True, _multiprocessing_safe = False) -> None:
auth_flow: str = 'default', _do_init: bool = True, _multiprocessing_safe = False, webdriver_path: str = "") -> None:
"""Creates a new instance of the TDClient Object.

Initializes the session with default values and any user-provided overrides.The
Expand Down Expand Up @@ -129,7 +131,10 @@ def __init__(self, client_id: str, redirect_uri: str, account_number: str = None
self.client_id = client_id
self.redirect_uri = redirect_uri
self.account_number = account_number
self.webdriver_path = webdriver_path

self._token_endpoint = "https://api.tdameritrade.com/v1/oauth2/token"

self.credentials_path = pathlib.Path(credentials_path)
self._td_utilities = TDUtilities()

Expand Down Expand Up @@ -263,7 +268,10 @@ def login(self) -> bool:
self.authstate = True
return True
else:
self.oauth()
if self.auth_flow == 'webdriver':
self.auth_using_webdriver()
else:
self.oauth()
self.authstate = True
return True

Expand Down Expand Up @@ -298,7 +306,7 @@ def grab_access_token(self) -> dict:

# Make the request.
response = requests.post(
url="https://api.tdameritrade.com/v1/oauth2/token",
url=self._token_endpoint,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=data
)
Expand Down Expand Up @@ -335,7 +343,7 @@ def grab_refresh_token(self) -> bool:

# Make the request.
response = requests.post(
url="https://api.tdameritrade.com/v1/oauth2/token",
url=self._token_endpoint,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=data
)
Expand Down Expand Up @@ -389,6 +397,69 @@ def oauth(self) -> None:
return_refresh_token=True
)

def auth_using_webdriver(self) -> None:
"""Runs the oAuth process using webdriver for the TD Ameritrade API."""

print(f'Failed to find credentials json file \'{self.credentials_path}\'')

from selenium import webdriver
with webdriver.Chrome(executable_path=self.webdriver_path) as driver:
self.auth_from_login_flow(driver)

def _normalize_api_key(self, api_key):
api_key_suffix = '@AMER.OAUTHAP'

if not api_key.endswith(api_key_suffix):
print(f'Appending {api_key_suffix} to API key')
api_key = api_key + api_key_suffix
return api_key

class RedirectTimeoutError(Exception):
pass

def auth_from_login_flow(self, driver):
print((f'Creating new token with redirect URL \'{self.redirect_uri}\' ' +
f'and credentials path \'{self.credentials_path}\''))

self.client_id = self._normalize_api_key(self.client_id)

oauth = OAuth2Client(self.client_id, redirect_uri=self.redirect_uri)
authorization_url, state = oauth.create_authorization_url(
'https://auth.tdameritrade.com/auth')

driver.get(authorization_url)

# Tolerate redirects to HTTPS on the callback URL
if self.redirect_uri.startswith('http://'):
redirect_urls = (self.redirect_uri, 'https' + self.redirect_uri[4:])
else:
redirect_urls = (self.redirect_uri,)

# Wait until the current URL starts with the callback URL
current_url = ''
num_waits = 0
redirect_wait_time_seconds = 0.1
max_waits = 3000
while not any(current_url.startswith(r_url) for r_url in redirect_urls):
current_url = driver.current_url

if num_waits > max_waits:
raise RedirectTimeoutError('timed out waiting for redirect')
time.sleep(redirect_wait_time_seconds)
num_waits += 1

token = oauth.fetch_token(
self._token_endpoint,
authorization_response=self.redirect_uri,
access_type='offline',
client_id=self.client_id,
include_client_id=True)

print(token)

# TODO: implement refresh token mode
self._token_save(token_dict=token)

def exchange_code_for_token(self, code: str, return_refresh_token: bool) -> dict:
"""Access token handler for AuthCode Workflow.

Expand Down Expand Up @@ -422,7 +493,7 @@ def exchange_code_for_token(self, code: str, return_refresh_token: bool) -> dict

# Make the request.
response = requests.post(
url="https://api.tdameritrade.com/v1/oauth2/token",
url=self._token_endpoint,
headers={'Content-Type': 'application/x-www-form-urlencoded'},
data=data
)
Expand Down