11
11
from .exceptions import InvalidRefreshToken
12
12
from .exceptions import InvalidScope
13
13
from .exceptions import InvalidSubjectIdentifier
14
+ from .storage import StatelessWrapper
14
15
from .util import requested_scope_is_allowed
15
16
16
17
logger = logging .getLogger (__name__ )
@@ -24,13 +25,15 @@ def rand_str():
24
25
25
26
class AuthorizationState (object ):
26
27
KEY_AUTHORIZATION_REQUEST = 'auth_req'
28
+ KEY_USER_INFO = 'user_info'
29
+ KEY_EXTRA_ID_TOKEN_CLAIMS = 'extra_id_token_claims'
27
30
28
31
def __init__ (self , subject_identifier_factory , authorization_code_db = None , access_token_db = None ,
29
32
refresh_token_db = None , subject_identifier_db = None , * ,
30
33
authorization_code_lifetime = 600 , access_token_lifetime = 3600 , refresh_token_lifetime = None ,
31
34
refresh_token_threshold = None ):
32
35
# type: (se_leg_op.token_state.SubjectIdentifierFactory, Mapping[str, Any], Mapping[str, Any],
33
- # Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
36
+ # Mapping[str, Any], Mapping[str, Any], int, int, Optional[int], Optional[int]) -> None
34
37
"""
35
38
:param subject_identifier_factory: callable to use when construction subject identifiers
36
39
:param authorization_code_db: database for storing authorization codes, defaults to in-memory
@@ -77,10 +80,28 @@ def __init__(self, subject_identifier_factory, authorization_code_db=None, acces
77
80
"""
78
81
Mapping of user id's to subject identifiers.
79
82
"""
80
- self .subject_identifiers = subject_identifier_db if subject_identifier_db is not None else {}
81
-
82
- def create_authorization_code (self , authorization_request , subject_identifier , scope = None ):
83
- # type: (AuthorizationRequest, str, Optional[List[str]]) -> str
83
+ self .stateless = (
84
+ isinstance (self .authorization_codes , StatelessWrapper )
85
+ or isinstance (self .access_tokens , StatelessWrapper )
86
+ or isinstance (self .refresh_tokens , StatelessWrapper )
87
+ )
88
+ self .subject_identifiers = (
89
+ {}
90
+ if self .stateless
91
+ else subject_identifier_db
92
+ if subject_identifier_db is not None
93
+ else {}
94
+ )
95
+
96
+ def create_authorization_code (
97
+ self ,
98
+ authorization_request ,
99
+ subject_identifier ,
100
+ scope = None ,
101
+ user_info = None ,
102
+ extra_id_token_claims = None ,
103
+ ):
104
+ # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict], Optional[Mappings[str, Union[str, List[str]]]]) -> str
84
105
"""
85
106
Creates an authorization code bound to the authorization request and the authenticated user identified
86
107
by the subject identifier.
@@ -92,21 +113,29 @@ def create_authorization_code(self, authorization_request, subject_identifier, s
92
113
scope = ' ' .join (scope or authorization_request ['scope' ])
93
114
logger .debug ('creating authz code for scope=%s' , scope )
94
115
95
- authorization_code = rand_str ()
96
116
authz_info = {
97
117
'used' : False ,
98
118
'exp' : int (time .time ()) + self .authorization_code_lifetime ,
99
119
'sub' : subject_identifier ,
100
120
'granted_scope' : scope ,
101
121
self .KEY_AUTHORIZATION_REQUEST : authorization_request .to_dict ()
102
122
}
103
- self .authorization_codes [authorization_code ] = authz_info
123
+
124
+ if self .stateless :
125
+ if user_info :
126
+ authz_info [self .KEY_USER_INFO ] = user_info
127
+ authz_info [self .KEY_EXTRA_ID_TOKEN_CLAIMS ] = extra_id_token_claims or {}
128
+ authorization_code = self .authorization_codes .pack (authz_info )
129
+ else :
130
+ authorization_code = rand_str ()
131
+ self .authorization_codes [authorization_code ] = authz_info
132
+
104
133
logger .debug ('new authz_code=%s to client_id=%s for sub=%s valid_until=%s' , authorization_code ,
105
134
authorization_request ['client_id' ], subject_identifier , authz_info ['exp' ])
106
135
return authorization_code
107
136
108
- def create_access_token (self , authorization_request , subject_identifier , scope = None ):
109
- # type: (AuthorizationRequest, str, Optional[List[str]]) -> se_leg_op.access_token.AccessToken
137
+ def create_access_token (self , authorization_request , subject_identifier , scope = None , user_info = None ):
138
+ # type: (AuthorizationRequest, str, Optional[List[str]], Optional[dict] ) -> se_leg_op.access_token.AccessToken
110
139
"""
111
140
Creates an access token bound to the authentication request and the authenticated user identified by the
112
141
subject identifier.
@@ -116,15 +145,15 @@ def create_access_token(self, authorization_request, subject_identifier, scope=N
116
145
117
146
scope = scope or authorization_request ['scope' ]
118
147
119
- return self ._create_access_token (subject_identifier , authorization_request .to_dict (), ' ' .join (scope ))
148
+ return self ._create_access_token (subject_identifier , authorization_request .to_dict (), ' ' .join (scope ),
149
+ user_info = user_info )
120
150
121
- def _create_access_token (self , subject_identifier , auth_req , granted_scope , current_scope = None ):
122
- # type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str]) -> se_leg_op.access_token.AccessToken
151
+ def _create_access_token (self , subject_identifier , auth_req , granted_scope , current_scope = None ,
152
+ user_info = None ):
153
+ # type: (str, Mapping[str, Union[str, List[str]]], str, Optional[str], Optional[dict]) -> se_leg_op.access_token.AccessToken
123
154
"""
124
155
Creates an access token bound to the subject identifier, client id and requested scope.
125
156
"""
126
- access_token = AccessToken (rand_str (), self .access_token_lifetime )
127
-
128
157
scope = current_scope or granted_scope
129
158
logger .debug ('creating access token for scope=%s' , scope )
130
159
@@ -136,13 +165,21 @@ def _create_access_token(self, subject_identifier, auth_req, granted_scope, curr
136
165
'aud' : [auth_req ['client_id' ]],
137
166
'scope' : scope ,
138
167
'granted_scope' : granted_scope ,
139
- 'token_type' : access_token .BEARER_TOKEN_TYPE ,
168
+ 'token_type' : AccessToken .BEARER_TOKEN_TYPE ,
140
169
self .KEY_AUTHORIZATION_REQUEST : auth_req
141
170
}
142
- self .access_tokens [access_token .value ] = authz_info
171
+
172
+ if self .stateless :
173
+ if user_info :
174
+ authz_info [self .KEY_USER_INFO ] = user_info
175
+ access_token_val = self .access_tokens .pack (authz_info )
176
+ else :
177
+ access_token_val = rand_str ()
178
+ self .access_tokens [access_token_val ] = authz_info
143
179
144
180
logger .debug ('new access_token=%s to client_id=%s for sub=%s valid_until=%s' ,
145
- access_token .value , auth_req ['client_id' ], subject_identifier , authz_info ['exp' ])
181
+ access_token_val , auth_req ['client_id' ], subject_identifier , authz_info ['exp' ])
182
+ access_token = AccessToken (access_token_val , self .access_token_lifetime )
146
183
return access_token
147
184
148
185
def exchange_code_for_token (self , authorization_code ):
@@ -165,7 +202,8 @@ def exchange_code_for_token(self, authorization_code):
165
202
authz_info ['used' ] = True
166
203
167
204
access_token = self ._create_access_token (authz_info ['sub' ], authz_info [self .KEY_AUTHORIZATION_REQUEST ],
168
- authz_info ['granted_scope' ])
205
+ authz_info ['granted_scope' ],
206
+ user_info = authz_info .get (self .KEY_USER_INFO ))
169
207
170
208
logger .debug ('authz_code=%s exchanged to access_token=%s' , authorization_code , access_token .value )
171
209
return access_token
@@ -199,9 +237,13 @@ def create_refresh_token(self, access_token_value):
199
237
logger .debug ('no refresh token issued for for access_token=%s' , access_token_value )
200
238
return None
201
239
202
- refresh_token = rand_str ()
203
240
authz_info = {'access_token' : access_token_value , 'exp' : int (time .time ()) + self .refresh_token_lifetime }
204
- self .refresh_tokens [refresh_token ] = authz_info
241
+
242
+ if self .stateless :
243
+ refresh_token = self .refresh_tokens .pack (authz_info )
244
+ else :
245
+ refresh_token = rand_str ()
246
+ self .refresh_tokens [refresh_token ] = authz_info
205
247
206
248
logger .debug ('issued refresh_token=%s expiring=%d for access_token=%s' , refresh_token , authz_info ['exp' ],
207
249
access_token_value )
@@ -235,7 +277,8 @@ def use_refresh_token(self, refresh_token, scope=None):
235
277
scope = authz_info ['granted_scope' ]
236
278
237
279
new_access_token = self ._create_access_token (authz_info ['sub' ], authz_info [self .KEY_AUTHORIZATION_REQUEST ],
238
- authz_info ['granted_scope' ], scope )
280
+ authz_info ['granted_scope' ], scope ,
281
+ user_info = authz_info .get (self .KEY_USER_INFO ))
239
282
240
283
new_refresh_token = None
241
284
if self .refresh_token_threshold \
@@ -293,7 +336,7 @@ def get_subject_identifier(self, subject_type, user_id, sector_identifier=None):
293
336
raise ValueError ('Unknown subject_type={}' .format (subject_type ))
294
337
295
338
def _is_valid_subject_identifier (self , sub ):
296
- # type: (str) -> str
339
+ # type: (str) -> bool
297
340
"""
298
341
Determines whether the subject identifier is known.
299
342
"""
@@ -307,13 +350,33 @@ def _is_valid_subject_identifier(self, sub):
307
350
def get_user_id_for_subject_identifier (self , subject_identifier ):
308
351
for user_id , subject_identifiers in self .subject_identifiers .items ():
309
352
is_public_sub = 'public' in subject_identifiers and subject_identifier == subject_identifiers ['public' ]
310
- is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers [
311
- 'pairwise' ]
353
+ is_pairwise_sub = 'pairwise' in subject_identifiers and subject_identifier in subject_identifiers ['pairwise' ]
312
354
if is_public_sub or is_pairwise_sub :
313
355
return user_id
314
356
315
357
raise InvalidSubjectIdentifier ('{} unknown' .format (subject_identifier ))
316
358
359
+ def get_user_info_for_code (self , authorization_code ):
360
+ # type: (str) -> dict
361
+ if authorization_code not in self .authorization_codes :
362
+ raise InvalidAuthorizationCode ('{} unknown' .format (authorization_code ))
363
+
364
+ return self .authorization_codes [authorization_code ].get (self .KEY_USER_INFO )
365
+
366
+ def get_extra_io_token_claims_for_code (self , authorization_code ):
367
+ # type: (str) -> dict
368
+ if authorization_code not in self .authorization_codes :
369
+ raise InvalidAuthorizationCode ('{} unknown' .format (authorization_code ))
370
+
371
+ return self .authorization_codes [authorization_code ].get (self .KEY_EXTRA_ID_TOKEN_CLAIMS )
372
+
373
+ def get_user_info_for_access_token (self , access_token ):
374
+ # type: (str) -> dict
375
+ if access_token not in self .access_tokens :
376
+ raise InvalidAccessToken ('{} unknown' .format (access_token ))
377
+
378
+ return self .access_tokens [access_token ].get (self .KEY_USER_INFO )
379
+
317
380
def get_authorization_request_for_code (self , authorization_code ):
318
381
# type: (str) -> AuthorizationRequest
319
382
if authorization_code not in self .authorization_codes :
@@ -323,7 +386,7 @@ def get_authorization_request_for_code(self, authorization_code):
323
386
self .authorization_codes [authorization_code ][self .KEY_AUTHORIZATION_REQUEST ])
324
387
325
388
def get_authorization_request_for_access_token (self , access_token_value ):
326
- # type: (str) ->
389
+ # type: (str) ->
327
390
if access_token_value not in self .access_tokens :
328
391
raise InvalidAccessToken ('{} unknown' .format (access_token_value ))
329
392
0 commit comments