Skip to content

Commit 94ad8bd

Browse files
Merge pull request #41 from smalihaider/stateless-code-flow
Add support for stateless code flow
2 parents b2ac7fe + c58cf7e commit 94ad8bd

File tree

9 files changed

+1146
-61
lines changed

9 files changed

+1146
-61
lines changed

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
description='OpenID Connect Provider (OP) library in Python.',
1313
install_requires=[
1414
'oic >= 1.2.1',
15+
'pycryptodomex',
1516
],
1617
extras_require={
1718
'mongo': 'pymongo',

src/pyop/authz_state.py

Lines changed: 88 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from .exceptions import InvalidRefreshToken
1212
from .exceptions import InvalidScope
1313
from .exceptions import InvalidSubjectIdentifier
14+
from .storage import StatelessWrapper
1415
from .util import requested_scope_is_allowed
1516

1617
logger = logging.getLogger(__name__)
@@ -24,13 +25,15 @@ def rand_str():
2425

2526
class AuthorizationState(object):
2627
KEY_AUTHORIZATION_REQUEST = 'auth_req'
28+
KEY_USER_INFO = 'user_info'
29+
KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims'
2730

2831
def __init__(self, subject_identifier_factory, authorization_code_db=None, access_token_db=None,
2932
refresh_token_db=None, subject_identifier_db=None, *,
3033
authorization_code_lifetime=600, access_token_lifetime=3600, refresh_token_lifetime=None,
3134
refresh_token_threshold=None):
3235
# type: (se_leg_op.token_state.SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any],
33-
# Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
36+
# Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
3437
"""
3538
:param subject_identifier_factory: callable to use when construction subject identifiers
3639
:param authorization_code_db: database for storing authorization codes, defaults to in-memory
@@ -77,10 +80,28 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces
7780
"""
7881
Mapping of user id's to subject identifiers.
7982
"""
80-
self.subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}
81-
82-
def create_authorization_code(self, authorization_request, subject_identifier, scope=None):
83-
# type: (AuthorizationRequest, str, Optional[List[str]]) -> str
83+
self.stateless = (
84+
isinstance(self.authorization_codes, StatelessWrapper)
85+
or isinstance(self.access_tokens, StatelessWrapper)
86+
or isinstance(self.refresh_tokens, StatelessWrapper)
87+
)
88+
self.subject_identifiers = (
89+
{}
90+
if self.stateless
91+
else subject_identifier_db
92+
if subject_identifier_db is not None
93+
else {}
94+
)
95+
96+
def create_authorization_code(
97+
self,
98+
authorization_request,
99+
subject_identifier,
100+
scope=None,
101+
user_info=None,
102+
extra_id_token_claims=None,
103+
):
104+
# type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mappings[str, Union[str, List[str]]]]) -> str
84105
"""
85106
Creates an authorization code bound to the authorization request and the authenticated user identified
86107
by the subject identifier.
@@ -92,21 +113,29 @@ def create_authorization_code(self, authorization_request, subject_identifier, s
92113
scope = ' '.join(scope or authorization_request['scope'])
93114
logger.debug('creating authz code for scope=%s', scope)
94115

95-
authorization_code = rand_str()
96116
authz_info = {
97117
'used': False,
98118
'exp': int(time.time()) + self.authorization_code_lifetime,
99119
'sub': subject_identifier,
100120
'granted_scope': scope,
101121
self.KEY_AUTHORIZATION_REQUEST: authorization_request.to_dict()
102122
}
103-
self.authorization_codes[authorization_code] = authz_info
123+
124+
if self.stateless:
125+
if user_info:
126+
authz_info[self.KEY_USER_INFO] = user_info
127+
authz_info[self.KEY_EXTRA_ID_TOKEN_CLAIMS] = extra_id_token_claims or {}
128+
authorization_code = self.authorization_codes.pack(authz_info)
129+
else:
130+
authorization_code = rand_str()
131+
self.authorization_codes[authorization_code] = authz_info
132+
104133
logger.debug('new authz_code=%s to client_id=%s for sub=%s valid_until=%s', authorization_code,
105134
authorization_request['client_id'], subject_identifier, authz_info['exp'])
106135
return authorization_code
107136

108-
def create_access_token(self, authorization_request, subject_identifier, scope=None):
109-
# type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
137+
def create_access_token(self, authorization_request, subject_identifier, scope=None, user_info=None):
138+
# type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict]) -> se_leg_op.access_token.AccessToken
110139
"""
111140
Creates an access token bound to the authentication request and the authenticated user identified by the
112141
subject identifier.
@@ -116,15 +145,15 @@ def create_access_token(self, authorization_request, subject_identifier, scope=N
116145

117146
scope = scope or authorization_request['scope']
118147

119-
return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope))
148+
return self._create_access_token(subject_identifier, authorization_request.to_dict(), ' '.join(scope),
149+
user_info=user_info)
120150

121-
def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None):
122-
# type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str]) -> se_leg_op.access_token.AccessToken
151+
def _create_access_token(self, subject_identifier, auth_req, granted_scope, current_scope=None,
152+
user_info=None):
153+
# type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> se_leg_op.access_token.AccessToken
123154
"""
124155
Creates an access token bound to the subject identifier, client id and requested scope.
125156
"""
126-
access_token = AccessToken(rand_str(), self.access_token_lifetime)
127-
128157
scope = current_scope or granted_scope
129158
logger.debug('creating access token for scope=%s', scope)
130159

@@ -136,13 +165,21 @@ def _create_access_token(self, subject_identifier, auth_req, granted_scope, curr
136165
'aud': [auth_req['client_id']],
137166
'scope': scope,
138167
'granted_scope': granted_scope,
139-
'token_type': access_token.BEARER_TOKEN_TYPE,
168+
'token_type': AccessToken.BEARER_TOKEN_TYPE,
140169
self.KEY_AUTHORIZATION_REQUEST: auth_req
141170
}
142-
self.access_tokens[access_token.value] = authz_info
171+
172+
if self.stateless:
173+
if user_info:
174+
authz_info[self.KEY_USER_INFO] = user_info
175+
access_token_val = self.access_tokens.pack(authz_info)
176+
else:
177+
access_token_val = rand_str()
178+
self.access_tokens[access_token_val] = authz_info
143179

144180
logger.debug('new access_token=%s to client_id=%s for sub=%s valid_until=%s',
145-
access_token.value, auth_req['client_id'], subject_identifier, authz_info['exp'])
181+
access_token_val, auth_req['client_id'], subject_identifier, authz_info['exp'])
182+
access_token = AccessToken(access_token_val, self.access_token_lifetime)
146183
return access_token
147184

148185
def exchange_code_for_token(self, authorization_code):
@@ -165,7 +202,8 @@ def exchange_code_for_token(self, authorization_code):
165202
authz_info['used'] = True
166203

167204
access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
168-
authz_info['granted_scope'])
205+
authz_info['granted_scope'],
206+
user_info=authz_info.get(self.KEY_USER_INFO))
169207

170208
logger.debug('authz_code=%s exchanged to access_token=%s', authorization_code, access_token.value)
171209
return access_token
@@ -199,9 +237,13 @@ def create_refresh_token(self, access_token_value):
199237
logger.debug('no refresh token issued for for access_token=%s', access_token_value)
200238
return None
201239

202-
refresh_token = rand_str()
203240
authz_info = {'access_token': access_token_value, 'exp': int(time.time()) + self.refresh_token_lifetime}
204-
self.refresh_tokens[refresh_token] = authz_info
241+
242+
if self.stateless:
243+
refresh_token = self.refresh_tokens.pack(authz_info)
244+
else:
245+
refresh_token = rand_str()
246+
self.refresh_tokens[refresh_token] = authz_info
205247

206248
logger.debug('issued refresh_token=%s expiring=%d for access_token=%s', refresh_token, authz_info['exp'],
207249
access_token_value)
@@ -235,7 +277,8 @@ def use_refresh_token(self, refresh_token, scope=None):
235277
scope = authz_info['granted_scope']
236278

237279
new_access_token = self._create_access_token(authz_info['sub'], authz_info[self.KEY_AUTHORIZATION_REQUEST],
238-
authz_info['granted_scope'], scope)
280+
authz_info['granted_scope'], scope,
281+
user_info=authz_info.get(self.KEY_USER_INFO))
239282

240283
new_refresh_token = None
241284
if self.refresh_token_threshold \
@@ -293,7 +336,7 @@ def get_subject_identifier(self, subject_type, user_id, sector_identifier=None):
293336
raise ValueError('Unknown subject_type={}'.format(subject_type))
294337

295338
def _is_valid_subject_identifier(self, sub):
296-
# type: (str) -> str
339+
# type: (str) -> bool
297340
"""
298341
Determines whether the subject identifier is known.
299342
"""
@@ -307,13 +350,33 @@ def _is_valid_subject_identifier(self, sub):
307350
def get_user_id_for_subject_identifier(self, subject_identifier):
308351
for user_id, subject_identifiers in self.subject_identifiers.items():
309352
is_public_sub = 'public' in subject_identifiers and subject_identifier == subject_identifiers['public']
310-
is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers[
311-
'pairwise']
353+
is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers['pairwise']
312354
if is_public_sub or is_pairwise_sub:
313355
return user_id
314356

315357
raise InvalidSubjectIdentifier('{} unknown'.format(subject_identifier))
316358

359+
def get_user_info_for_code(self, authorization_code):
360+
# type: (str) -> dict
361+
if authorization_code not in self.authorization_codes:
362+
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))
363+
364+
return self.authorization_codes[authorization_code].get(self.KEY_USER_INFO)
365+
366+
def get_extra_io_token_claims_for_code(self, authorization_code):
367+
# type: (str) -> dict
368+
if authorization_code not in self.authorization_codes:
369+
raise InvalidAuthorizationCode('{} unknown'.format(authorization_code))
370+
371+
return self.authorization_codes[authorization_code].get(self.KEY_EXTRA_ID_TOKEN_CLAIMS)
372+
373+
def get_user_info_for_access_token(self, access_token):
374+
# type: (str) -> dict
375+
if access_token not in self.access_tokens:
376+
raise InvalidAccessToken('{} unknown'.format(access_token))
377+
378+
return self.access_tokens[access_token].get(self.KEY_USER_INFO)
379+
317380
def get_authorization_request_for_code(self, authorization_code):
318381
# type: (str) -> AuthorizationRequest
319382
if authorization_code not in self.authorization_codes:
@@ -323,7 +386,7 @@ def get_authorization_request_for_code(self, authorization_code):
323386
self.authorization_codes[authorization_code][self.KEY_AUTHORIZATION_REQUEST])
324387

325388
def get_authorization_request_for_access_token(self, access_token_value):
326-
# type: (str) ->
389+
# type: (str) ->
327390
if access_token_value not in self.access_tokens:
328391
raise InvalidAccessToken('{} unknown'.format(access_token_value))
329392

src/pyop/crypto.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,74 @@
1+
import base64
2+
import hashlib
3+
4+
from Cryptodome import Random
5+
from Cryptodome.Cipher import AES
6+
7+
8+
class _AESCipher(object):
9+
"""
10+
This class will perform AES encryption/decryption with a keylength of 256.
11+
12+
@see: http://stackoverflow.com/questions/12524994/encrypt-decrypt-using-pycrypto-aes-256
13+
"""
14+
15+
def __init__(self, key):
16+
"""
17+
Constructor
18+
19+
:type key: str
20+
21+
:param key: The key used for encryption and decryption. The longer key the better.
22+
"""
23+
self.bs = 32
24+
self.key = hashlib.sha256(key.encode()).digest()
25+
26+
def encrypt(self, raw):
27+
"""
28+
Encryptes the parameter raw.
29+
30+
:type raw: bytes
31+
:rtype: str
32+
33+
:param: bytes to be encrypted.
34+
35+
:return: A base 64 encoded string.
36+
"""
37+
raw = self._pad(raw)
38+
iv = Random.new().read(AES.block_size)
39+
cipher = AES.new(self.key, AES.MODE_CBC, iv)
40+
return base64.urlsafe_b64encode(iv + cipher.encrypt(raw))
41+
42+
def decrypt(self, enc):
43+
"""
44+
Decryptes the parameter enc.
45+
46+
:type enc: bytes
47+
:rtype: bytes
48+
49+
:param: The value to be decrypted.
50+
:return: The decrypted value.
51+
"""
52+
enc = base64.urlsafe_b64decode(enc)
53+
iv = enc[:AES.block_size]
54+
cipher = AES.new(self.key, AES.MODE_CBC, iv)
55+
return self._unpad(cipher.decrypt(enc[AES.block_size:]))
56+
57+
def _pad(self, b):
58+
"""
59+
Will padd the param to be of the correct length for the encryption alg.
60+
61+
:type b: bytes
62+
:rtype: bytes
63+
"""
64+
return b + (self.bs - len(b) % self.bs) * chr(self.bs - len(b) % self.bs).encode("UTF-8")
65+
66+
@staticmethod
67+
def _unpad(b):
68+
"""
69+
Removes the padding performed by the method _pad.
70+
71+
:type b: bytes
72+
:rtype: bytes
73+
"""
74+
return b[:-ord(b[len(b) - 1:])]

0 commit comments

Comments
 (0)