|
| 1 | +import requests |
| 2 | +from requests.auth import AuthBase |
| 3 | +import msal |
| 4 | +from logging import getLogger |
| 5 | +import json |
| 6 | +import os |
| 7 | + |
| 8 | +logger = getLogger(__name__) |
| 9 | + |
| 10 | + |
| 11 | +class DeviceCodeFlowTokenAuth(AuthBase): |
| 12 | + _DEFAULT_TOKEN_CACHE_FILE_NAME = 'token' |
| 13 | + _DEFAULT_TOKEN_CACHE_DIR = os.path.expanduser( |
| 14 | + os.path.join('~', '.{}'.format('myapp'))) |
| 15 | + _DEFAULT_TOKEN_CACHE_FILE_PATH = os.path.join( |
| 16 | + _DEFAULT_TOKEN_CACHE_DIR, _DEFAULT_TOKEN_CACHE_FILE_NAME) |
| 17 | + |
| 18 | + def __init__(self, auth_config): |
| 19 | + self.cache = self.__getTokenCache() |
| 20 | + self.app = msal.PublicClientApplication( |
| 21 | + auth_config['client_id'], authority=auth_config['authority'], token_cache=self.cache) |
| 22 | + self.config = auth_config |
| 23 | + |
| 24 | + def __call__(self, r): |
| 25 | + token = self.__getTokenFromCache() |
| 26 | + if not token: |
| 27 | + token = self.__getTokenFromAD() |
| 28 | + if "access_token" in token: |
| 29 | + logger.info("Access token acquired successfully") |
| 30 | + bearer = 'Bearer {token}'.format(token=token['access_token']) |
| 31 | + r.headers['Authorization'] = bearer |
| 32 | + self.__saveTokenCache() |
| 33 | + else: |
| 34 | + logger.info("Token does not contain access_token") |
| 35 | + logger.info("Token Result: {token}".format(token=token)) |
| 36 | + return r |
| 37 | + |
| 38 | + def __getTokenFromCache(self): |
| 39 | + accounts = self.app.get_accounts() |
| 40 | + if accounts: |
| 41 | + logger.info( |
| 42 | + "Account(s) exists in cache, probably with token too. Let's try.") |
| 43 | + logger.info("Trying with account: {account}".format( |
| 44 | + account=accounts[0])) |
| 45 | + return self.app.acquire_token_silent( |
| 46 | + self.config["scope"], account=accounts[0]) |
| 47 | + logger.info("No accounts found") |
| 48 | + return None |
| 49 | + |
| 50 | + def __getTokenFromAD(self): |
| 51 | + logger.info( |
| 52 | + "No suitable token exists in cache. Let's get a new one from AAD.") |
| 53 | + flow = self.app.initiate_device_flow(scopes=self.config["scope"]) |
| 54 | + if "user_code" not in flow: |
| 55 | + raise ValueError( |
| 56 | + "Fail to create device flow. Err: %s" % json.dumps(flow, indent=4)) |
| 57 | + logger.warning(flow["message"]) |
| 58 | + return self.app.acquire_token_by_device_flow(flow) |
| 59 | + |
| 60 | + def __getTokenCache(self): |
| 61 | + cache = msal.SerializableTokenCache() |
| 62 | + if os.path.exists(self._DEFAULT_TOKEN_CACHE_FILE_PATH): |
| 63 | + logger.info( |
| 64 | + f'Looking for token cache in {self._DEFAULT_TOKEN_CACHE_FILE_PATH}') |
| 65 | + try: |
| 66 | + with open(self._DEFAULT_TOKEN_CACHE_FILE_PATH) as f: |
| 67 | + cache.deserialize(f.read()) |
| 68 | + logger.info('Token cache deserialized successfully') |
| 69 | + except: |
| 70 | + logger.exception('Unable to deserialize token cache') |
| 71 | + try: |
| 72 | + os.remove(self._DEFAULT_TOKEN_CACHE_FILE_PATH) |
| 73 | + except: |
| 74 | + logger.info( |
| 75 | + f'Unable to delete cache at path {self._DEFAULT_TOKEN_CACHE_FILE_PATH}', exc_info=1) |
| 76 | + else: |
| 77 | + logger.info( |
| 78 | + f'Token cache does not exist at path {self._DEFAULT_TOKEN_CACHE_FILE_PATH}') |
| 79 | + return cache |
| 80 | + |
| 81 | + def __saveTokenCache(self): |
| 82 | + try: |
| 83 | + if not os.path.exists(self._DEFAULT_TOKEN_CACHE_DIR): |
| 84 | + os.makedirs(self._DEFAULT_TOKEN_CACHE_DIR) |
| 85 | + with open(self._DEFAULT_TOKEN_CACHE_FILE_PATH, 'w') as f: |
| 86 | + f.write(self.cache.serialize()) |
| 87 | + logger.info( |
| 88 | + f'Token cache successfully serialzied to {self._DEFAULT_TOKEN_CACHE_FILE_PATH}') |
| 89 | + except: |
| 90 | + logger.exception('Unable to serialize token cache') |
0 commit comments