Skip to content

Commit 4a4f946

Browse files
authored
Added auth.py
1 parent 8daa6d3 commit 4a4f946

File tree

1 file changed

+90
-0
lines changed

1 file changed

+90
-0
lines changed

auth.py

Lines changed: 90 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,90 @@
1+
import requests
2+
from requests.auth import AuthBase
3+
import msal
4+
from logging import getLogger
5+
import json
6+
import os
7+
8+
logger = getLogger(__name__)
9+
10+
11+
class DeviceCodeFlowTokenAuth(AuthBase):
12+
_DEFAULT_TOKEN_CACHE_FILE_NAME = 'token'
13+
_DEFAULT_TOKEN_CACHE_DIR = os.path.expanduser(
14+
os.path.join('~', '.{}'.format('myapp')))
15+
_DEFAULT_TOKEN_CACHE_FILE_PATH = os.path.join(
16+
_DEFAULT_TOKEN_CACHE_DIR, _DEFAULT_TOKEN_CACHE_FILE_NAME)
17+
18+
def __init__(self, auth_config):
19+
self.cache = self.__getTokenCache()
20+
self.app = msal.PublicClientApplication(
21+
auth_config['client_id'], authority=auth_config['authority'], token_cache=self.cache)
22+
self.config = auth_config
23+
24+
def __call__(self, r):
25+
token = self.__getTokenFromCache()
26+
if not token:
27+
token = self.__getTokenFromAD()
28+
if "access_token" in token:
29+
logger.info("Access token acquired successfully")
30+
bearer = 'Bearer {token}'.format(token=token['access_token'])
31+
r.headers['Authorization'] = bearer
32+
self.__saveTokenCache()
33+
else:
34+
logger.info("Token does not contain access_token")
35+
logger.info("Token Result: {token}".format(token=token))
36+
return r
37+
38+
def __getTokenFromCache(self):
39+
accounts = self.app.get_accounts()
40+
if accounts:
41+
logger.info(
42+
"Account(s) exists in cache, probably with token too. Let's try.")
43+
logger.info("Trying with account: {account}".format(
44+
account=accounts[0]))
45+
return self.app.acquire_token_silent(
46+
self.config["scope"], account=accounts[0])
47+
logger.info("No accounts found")
48+
return None
49+
50+
def __getTokenFromAD(self):
51+
logger.info(
52+
"No suitable token exists in cache. Let's get a new one from AAD.")
53+
flow = self.app.initiate_device_flow(scopes=self.config["scope"])
54+
if "user_code" not in flow:
55+
raise ValueError(
56+
"Fail to create device flow. Err: %s" % json.dumps(flow, indent=4))
57+
logger.warning(flow["message"])
58+
return self.app.acquire_token_by_device_flow(flow)
59+
60+
def __getTokenCache(self):
61+
cache = msal.SerializableTokenCache()
62+
if os.path.exists(self._DEFAULT_TOKEN_CACHE_FILE_PATH):
63+
logger.info(
64+
f'Looking for token cache in {self._DEFAULT_TOKEN_CACHE_FILE_PATH}')
65+
try:
66+
with open(self._DEFAULT_TOKEN_CACHE_FILE_PATH) as f:
67+
cache.deserialize(f.read())
68+
logger.info('Token cache deserialized successfully')
69+
except:
70+
logger.exception('Unable to deserialize token cache')
71+
try:
72+
os.remove(self._DEFAULT_TOKEN_CACHE_FILE_PATH)
73+
except:
74+
logger.info(
75+
f'Unable to delete cache at path {self._DEFAULT_TOKEN_CACHE_FILE_PATH}', exc_info=1)
76+
else:
77+
logger.info(
78+
f'Token cache does not exist at path {self._DEFAULT_TOKEN_CACHE_FILE_PATH}')
79+
return cache
80+
81+
def __saveTokenCache(self):
82+
try:
83+
if not os.path.exists(self._DEFAULT_TOKEN_CACHE_DIR):
84+
os.makedirs(self._DEFAULT_TOKEN_CACHE_DIR)
85+
with open(self._DEFAULT_TOKEN_CACHE_FILE_PATH, 'w') as f:
86+
f.write(self.cache.serialize())
87+
logger.info(
88+
f'Token cache successfully serialzied to {self._DEFAULT_TOKEN_CACHE_FILE_PATH}')
89+
except:
90+
logger.exception('Unable to serialize token cache')

0 commit comments

Comments
 (0)