4
4
5
5
using System ;
6
6
using System . Collections . Concurrent ;
7
+ using System . Linq ;
8
+ using System . Runtime . Caching ;
7
9
using System . Security ;
10
+ using System . Security . Cryptography ;
11
+ using System . Text ;
8
12
using System . Threading ;
9
13
using System . Threading . Tasks ;
10
14
using Azure . Core ;
@@ -24,6 +28,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
24
28
/// </summary>
25
29
private static ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > s_pcaMap
26
30
= new ConcurrentDictionary < PublicClientAppKey , IPublicClientApplication > ( ) ;
31
+ private static readonly MemoryCache s_accountPwCache = new ( nameof ( ActiveDirectoryAuthenticationProvider ) ) ;
32
+ private static readonly int s_accountPwCacheTtlInHours = 2 ;
27
33
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient" ;
28
34
private static readonly string s_defaultScopeSuffix = "/.default" ;
29
35
private readonly string _type = typeof ( ActiveDirectoryAuthenticationProvider ) . Name ;
@@ -171,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
171
177
return new SqlAuthenticationToken ( accessToken . Token , accessToken . ExpiresOn ) ;
172
178
}
173
179
174
- AuthenticationResult result ;
180
+ AuthenticationResult result = null ;
175
181
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryServicePrincipal )
176
182
{
177
183
AccessToken accessToken = await new ClientSecretCredential ( audience , parameters . UserId , parameters . Password , tokenCredentialOptions ) . GetTokenAsync ( tokenRequestContext , cts . Token ) . ConfigureAwait ( false ) ;
@@ -207,86 +213,87 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
207
213
208
214
if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryIntegrated )
209
215
{
210
- if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
211
- {
212
- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
213
- . WithCorrelationId ( parameters . ConnectionId )
214
- . WithUsername ( parameters . UserId )
215
- . ExecuteAsync ( cancellationToken : cts . Token )
216
- . ConfigureAwait ( false ) ;
217
- }
218
- else
219
- {
220
- result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
221
- . WithCorrelationId ( parameters . ConnectionId )
222
- . ExecuteAsync ( cancellationToken : cts . Token )
223
- . ConfigureAwait ( false ) ;
224
- }
225
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
226
- }
227
- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
228
- {
229
- SecureString password = new SecureString ( ) ;
230
- foreach ( char c in parameters . Password )
231
- password . AppendChar ( c ) ;
232
- password . MakeReadOnly ( ) ;
233
-
234
- result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
235
- . WithCorrelationId ( parameters . ConnectionId )
236
- . ExecuteAsync ( cancellationToken : cts . Token )
237
- . ConfigureAwait ( false ) ;
238
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
239
- }
240
- else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
241
- parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
242
- {
243
- // Fetch available accounts from 'app' instance
244
- System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
216
+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
245
217
246
- IAccount account = default ;
247
- if ( accounts . MoveNext ( ) )
218
+ if ( null == result )
248
219
{
249
220
if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
250
221
{
251
- do
252
- {
253
- IAccount currentVal = accounts . Current ;
254
- if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
255
- {
256
- account = currentVal ;
257
- break ;
258
- }
259
- }
260
- while ( accounts . MoveNext ( ) ) ;
222
+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
223
+ . WithCorrelationId ( parameters . ConnectionId )
224
+ . WithUsername ( parameters . UserId )
225
+ . ExecuteAsync ( cancellationToken : cts . Token )
226
+ . ConfigureAwait ( false ) ;
261
227
}
262
228
else
263
229
{
264
- account = accounts . Current ;
230
+ result = await app . AcquireTokenByIntegratedWindowsAuth ( scopes )
231
+ . WithCorrelationId ( parameters . ConnectionId )
232
+ . ExecuteAsync ( cancellationToken : cts . Token )
233
+ . ConfigureAwait ( false ) ;
265
234
}
235
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
236
+ }
237
+ }
238
+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryPassword )
239
+ {
240
+ string pwCacheKey = GetAccountPwCacheKey ( parameters ) ;
241
+ object previousPw = s_accountPwCache . Get ( pwCacheKey ) ;
242
+ byte [ ] currPwHash = GetHash ( parameters . Password ) ;
243
+
244
+ if ( null != previousPw &&
245
+ previousPw is byte [ ] previousPwBytes &&
246
+ // Only get the cached token if the current password hash matches the previously used password hash
247
+ currPwHash . SequenceEqual ( previousPwBytes ) )
248
+ {
249
+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
266
250
}
267
251
268
- if ( null != account )
252
+ if ( null == result )
269
253
{
270
- try
254
+ SecureString password = new SecureString ( ) ;
255
+ foreach ( char c in parameters . Password )
256
+ password . AppendChar ( c ) ;
257
+ password . MakeReadOnly ( ) ;
258
+
259
+ result = await app . AcquireTokenByUsernamePassword ( scopes , parameters . UserId , password )
260
+ . WithCorrelationId ( parameters . ConnectionId )
261
+ . ExecuteAsync ( cancellationToken : cts . Token )
262
+ . ConfigureAwait ( false ) ;
263
+
264
+ // We cache the password hash to ensure future connection requests include a validated password
265
+ // when we check for a cached MSAL account. Otherwise, a connection request with the same username
266
+ // against the same tenant could succeed with an invalid password when we re-use the cached token.
267
+ if ( ! s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) )
271
268
{
272
- // If 'account' is available in 'app', we use the same to acquire token silently.
273
- // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
274
- result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
275
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
276
- }
277
- catch ( MsalUiRequiredException )
278
- {
279
- // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
280
- // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
281
- // or the user needs to perform two factor authentication.
282
- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
283
- SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
269
+ s_accountPwCache . Remove ( pwCacheKey ) ;
270
+ s_accountPwCache . Add ( pwCacheKey , GetHash ( parameters . Password ) , DateTime . UtcNow . AddHours ( s_accountPwCacheTtlInHours ) ) ;
284
271
}
272
+
273
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}" , result ? . ExpiresOn ) ;
285
274
}
286
- else
275
+ }
276
+ else if ( parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryInteractive ||
277
+ parameters . AuthenticationMethod == SqlAuthenticationMethod . ActiveDirectoryDeviceCodeFlow )
278
+ {
279
+ try
280
+ {
281
+ result = await TryAcquireTokenSilent ( app , parameters , scopes , cts ) . ConfigureAwait ( false ) ;
282
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
283
+ }
284
+ catch ( MsalUiRequiredException )
285
+ {
286
+ // An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
287
+ // for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
288
+ // or the user needs to perform two factor authentication.
289
+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
290
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
291
+ }
292
+
293
+ if ( null == result )
287
294
{
288
295
// If no existing 'account' is found, we request user to sign in interactively.
289
- result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts ) . ConfigureAwait ( false ) ;
296
+ result = await AcquireTokenInteractiveDeviceFlowAsync ( app , scopes , parameters . ConnectionId , parameters . UserId , parameters . AuthenticationMethod , cts , _customWebUI , _deviceCodeFlowCallback ) . ConfigureAwait ( false ) ;
290
297
SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
291
298
}
292
299
}
@@ -299,8 +306,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
299
306
return new SqlAuthenticationToken ( result . AccessToken , result . ExpiresOn ) ;
300
307
}
301
308
302
- private async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
303
- SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts )
309
+ private static async Task < AuthenticationResult > TryAcquireTokenSilent ( IPublicClientApplication app , SqlAuthenticationParameters parameters ,
310
+ string [ ] scopes , CancellationTokenSource cts )
311
+ {
312
+ AuthenticationResult result = null ;
313
+
314
+ // Fetch available accounts from 'app' instance
315
+ System . Collections . Generic . IEnumerator < IAccount > accounts = ( await app . GetAccountsAsync ( ) . ConfigureAwait ( false ) ) . GetEnumerator ( ) ;
316
+
317
+ IAccount account = default ;
318
+ if ( accounts . MoveNext ( ) )
319
+ {
320
+ if ( ! string . IsNullOrEmpty ( parameters . UserId ) )
321
+ {
322
+ do
323
+ {
324
+ IAccount currentVal = accounts . Current ;
325
+ if ( string . Compare ( parameters . UserId , currentVal . Username , StringComparison . InvariantCultureIgnoreCase ) == 0 )
326
+ {
327
+ account = currentVal ;
328
+ break ;
329
+ }
330
+ }
331
+ while ( accounts . MoveNext ( ) ) ;
332
+ }
333
+ else
334
+ {
335
+ account = accounts . Current ;
336
+ }
337
+ }
338
+
339
+ if ( null != account )
340
+ {
341
+ // If 'account' is available in 'app', we use the same to acquire token silently.
342
+ // Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
343
+ result = await app . AcquireTokenSilent ( scopes , account ) . ExecuteAsync ( cancellationToken : cts . Token ) . ConfigureAwait ( false ) ;
344
+ SqlClientEventSource . Log . TryTraceEvent ( "AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}" , parameters . AuthenticationMethod , result ? . ExpiresOn ) ;
345
+ }
346
+
347
+ return result ;
348
+ }
349
+
350
+ private static async Task < AuthenticationResult > AcquireTokenInteractiveDeviceFlowAsync ( IPublicClientApplication app , string [ ] scopes , Guid connectionId , string userId ,
351
+ SqlAuthenticationMethod authenticationMethod , CancellationTokenSource cts , ICustomWebUi customWebUI , Func < DeviceCodeResult , Task > deviceCodeFlowCallback )
304
352
{
305
353
try
306
354
{
@@ -319,11 +367,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
319
367
*/
320
368
ctsInteractive . CancelAfter ( 180000 ) ;
321
369
#endif
322
- if ( _customWebUI != null )
370
+ if ( customWebUI != null )
323
371
{
324
372
return await app . AcquireTokenInteractive ( scopes )
325
373
. WithCorrelationId ( connectionId )
326
- . WithCustomWebUi ( _customWebUI )
374
+ . WithCustomWebUi ( customWebUI )
327
375
. WithLoginHint ( userId )
328
376
. ExecuteAsync ( ctsInteractive . Token )
329
377
. ConfigureAwait ( false ) ;
@@ -357,7 +405,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
357
405
else
358
406
{
359
407
AuthenticationResult result = await app . AcquireTokenWithDeviceCode ( scopes ,
360
- deviceCodeResult => _deviceCodeFlowCallback ( deviceCodeResult ) )
408
+ deviceCodeResult => deviceCodeFlowCallback ( deviceCodeResult ) )
361
409
. WithCorrelationId ( connectionId )
362
410
. ExecuteAsync ( cancellationToken : cts . Token )
363
411
. ConfigureAwait ( false ) ;
@@ -410,6 +458,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
410
458
return clientApplicationInstance ;
411
459
}
412
460
461
+ private static string GetAccountPwCacheKey ( SqlAuthenticationParameters parameters )
462
+ {
463
+ return parameters . Authority + "+" + parameters . UserId ;
464
+ }
465
+
466
+ private static byte [ ] GetHash ( string input )
467
+ {
468
+ byte [ ] unhashedBytes = Encoding . Unicode . GetBytes ( input ) ;
469
+ SHA256 sha256 = SHA256 . Create ( ) ;
470
+ byte [ ] hashedBytes = sha256 . ComputeHash ( unhashedBytes ) ;
471
+ return hashedBytes ;
472
+ }
473
+
413
474
private IPublicClientApplication CreateClientAppInstance ( PublicClientAppKey publicClientAppKey )
414
475
{
415
476
IPublicClientApplication publicClientApplication ;
0 commit comments