Skip to content

Commit 5b01cc1

Browse files
Fix | Throttling of token requests by calling AcquireTokenSilent (#1925) (#1995)
1 parent 86aac95 commit 5b01cc1

File tree

1 file changed

+131
-70
lines changed

1 file changed

+131
-70
lines changed

src/Microsoft.Data.SqlClient/src/Microsoft/Data/SqlClient/ActiveDirectoryAuthenticationProvider.cs

+131-70
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,11 @@
44

55
using System;
66
using System.Collections.Concurrent;
7+
using System.Linq;
8+
using System.Runtime.Caching;
79
using System.Security;
10+
using System.Security.Cryptography;
11+
using System.Text;
812
using System.Threading;
913
using System.Threading.Tasks;
1014
using Azure.Core;
@@ -24,6 +28,8 @@ public sealed class ActiveDirectoryAuthenticationProvider : SqlAuthenticationPro
2428
/// </summary>
2529
private static ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication> s_pcaMap
2630
= new ConcurrentDictionary<PublicClientAppKey, IPublicClientApplication>();
31+
private static readonly MemoryCache s_accountPwCache = new(nameof(ActiveDirectoryAuthenticationProvider));
32+
private static readonly int s_accountPwCacheTtlInHours = 2;
2733
private static readonly string s_nativeClientRedirectUri = "https://login.microsoftonline.com/common/oauth2/nativeclient";
2834
private static readonly string s_defaultScopeSuffix = "/.default";
2935
private readonly string _type = typeof(ActiveDirectoryAuthenticationProvider).Name;
@@ -171,7 +177,7 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
171177
return new SqlAuthenticationToken(accessToken.Token, accessToken.ExpiresOn);
172178
}
173179

174-
AuthenticationResult result;
180+
AuthenticationResult result = null;
175181
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryServicePrincipal)
176182
{
177183
AccessToken accessToken = await new ClientSecretCredential(audience, parameters.UserId, parameters.Password, tokenCredentialOptions).GetTokenAsync(tokenRequestContext, cts.Token).ConfigureAwait(false);
@@ -207,86 +213,87 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
207213

208214
if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryIntegrated)
209215
{
210-
if (!string.IsNullOrEmpty(parameters.UserId))
211-
{
212-
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
213-
.WithCorrelationId(parameters.ConnectionId)
214-
.WithUsername(parameters.UserId)
215-
.ExecuteAsync(cancellationToken: cts.Token)
216-
.ConfigureAwait(false);
217-
}
218-
else
219-
{
220-
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
221-
.WithCorrelationId(parameters.ConnectionId)
222-
.ExecuteAsync(cancellationToken: cts.Token)
223-
.ConfigureAwait(false);
224-
}
225-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
226-
}
227-
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
228-
{
229-
SecureString password = new SecureString();
230-
foreach (char c in parameters.Password)
231-
password.AppendChar(c);
232-
password.MakeReadOnly();
233-
234-
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
235-
.WithCorrelationId(parameters.ConnectionId)
236-
.ExecuteAsync(cancellationToken: cts.Token)
237-
.ConfigureAwait(false);
238-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
239-
}
240-
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
241-
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
242-
{
243-
// Fetch available accounts from 'app' instance
244-
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
216+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
245217

246-
IAccount account = default;
247-
if (accounts.MoveNext())
218+
if (null == result)
248219
{
249220
if (!string.IsNullOrEmpty(parameters.UserId))
250221
{
251-
do
252-
{
253-
IAccount currentVal = accounts.Current;
254-
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
255-
{
256-
account = currentVal;
257-
break;
258-
}
259-
}
260-
while (accounts.MoveNext());
222+
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
223+
.WithCorrelationId(parameters.ConnectionId)
224+
.WithUsername(parameters.UserId)
225+
.ExecuteAsync(cancellationToken: cts.Token)
226+
.ConfigureAwait(false);
261227
}
262228
else
263229
{
264-
account = accounts.Current;
230+
result = await app.AcquireTokenByIntegratedWindowsAuth(scopes)
231+
.WithCorrelationId(parameters.ConnectionId)
232+
.ExecuteAsync(cancellationToken: cts.Token)
233+
.ConfigureAwait(false);
265234
}
235+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Integrated auth mode. Expiry Time: {0}", result?.ExpiresOn);
236+
}
237+
}
238+
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryPassword)
239+
{
240+
string pwCacheKey = GetAccountPwCacheKey(parameters);
241+
object previousPw = s_accountPwCache.Get(pwCacheKey);
242+
byte[] currPwHash = GetHash(parameters.Password);
243+
244+
if (null != previousPw &&
245+
previousPw is byte[] previousPwBytes &&
246+
// Only get the cached token if the current password hash matches the previously used password hash
247+
currPwHash.SequenceEqual(previousPwBytes))
248+
{
249+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
266250
}
267251

268-
if (null != account)
252+
if (null == result)
269253
{
270-
try
254+
SecureString password = new SecureString();
255+
foreach (char c in parameters.Password)
256+
password.AppendChar(c);
257+
password.MakeReadOnly();
258+
259+
result = await app.AcquireTokenByUsernamePassword(scopes, parameters.UserId, password)
260+
.WithCorrelationId(parameters.ConnectionId)
261+
.ExecuteAsync(cancellationToken: cts.Token)
262+
.ConfigureAwait(false);
263+
264+
// We cache the password hash to ensure future connection requests include a validated password
265+
// when we check for a cached MSAL account. Otherwise, a connection request with the same username
266+
// against the same tenant could succeed with an invalid password when we re-use the cached token.
267+
if (!s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours)))
271268
{
272-
// If 'account' is available in 'app', we use the same to acquire token silently.
273-
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
274-
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
275-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
276-
}
277-
catch (MsalUiRequiredException)
278-
{
279-
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
280-
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
281-
// or the user needs to perform two factor authentication.
282-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
283-
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
269+
s_accountPwCache.Remove(pwCacheKey);
270+
s_accountPwCache.Add(pwCacheKey, GetHash(parameters.Password), DateTime.UtcNow.AddHours(s_accountPwCacheTtlInHours));
284271
}
272+
273+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token for Active Directory Password auth mode. Expiry Time: {0}", result?.ExpiresOn);
285274
}
286-
else
275+
}
276+
else if (parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryInteractive ||
277+
parameters.AuthenticationMethod == SqlAuthenticationMethod.ActiveDirectoryDeviceCodeFlow)
278+
{
279+
try
280+
{
281+
result = await TryAcquireTokenSilent(app, parameters, scopes, cts).ConfigureAwait(false);
282+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
283+
}
284+
catch (MsalUiRequiredException)
285+
{
286+
// An 'MsalUiRequiredException' is thrown in the case where an interaction is required with the end user of the application,
287+
// for instance, if no refresh token was in the cache, or the user needs to consent, or re-sign-in (for instance if the password expired),
288+
// or the user needs to perform two factor authentication.
289+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
290+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
291+
}
292+
293+
if (null == result)
287294
{
288295
// If no existing 'account' is found, we request user to sign in interactively.
289-
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts).ConfigureAwait(false);
296+
result = await AcquireTokenInteractiveDeviceFlowAsync(app, scopes, parameters.ConnectionId, parameters.UserId, parameters.AuthenticationMethod, cts, _customWebUI, _deviceCodeFlowCallback).ConfigureAwait(false);
290297
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (interactive) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
291298
}
292299
}
@@ -299,8 +306,49 @@ public override async Task<SqlAuthenticationToken> AcquireTokenAsync(SqlAuthenti
299306
return new SqlAuthenticationToken(result.AccessToken, result.ExpiresOn);
300307
}
301308

302-
private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
303-
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts)
309+
private static async Task<AuthenticationResult> TryAcquireTokenSilent(IPublicClientApplication app, SqlAuthenticationParameters parameters,
310+
string[] scopes, CancellationTokenSource cts)
311+
{
312+
AuthenticationResult result = null;
313+
314+
// Fetch available accounts from 'app' instance
315+
System.Collections.Generic.IEnumerator<IAccount> accounts = (await app.GetAccountsAsync().ConfigureAwait(false)).GetEnumerator();
316+
317+
IAccount account = default;
318+
if (accounts.MoveNext())
319+
{
320+
if (!string.IsNullOrEmpty(parameters.UserId))
321+
{
322+
do
323+
{
324+
IAccount currentVal = accounts.Current;
325+
if (string.Compare(parameters.UserId, currentVal.Username, StringComparison.InvariantCultureIgnoreCase) == 0)
326+
{
327+
account = currentVal;
328+
break;
329+
}
330+
}
331+
while (accounts.MoveNext());
332+
}
333+
else
334+
{
335+
account = accounts.Current;
336+
}
337+
}
338+
339+
if (null != account)
340+
{
341+
// If 'account' is available in 'app', we use the same to acquire token silently.
342+
// Read More on API docs: https://docs.microsoft.com/dotnet/api/microsoft.identity.client.clientapplicationbase.acquiretokensilent
343+
result = await app.AcquireTokenSilent(scopes, account).ExecuteAsync(cancellationToken: cts.Token).ConfigureAwait(false);
344+
SqlClientEventSource.Log.TryTraceEvent("AcquireTokenAsync | Acquired access token (silent) for {0} auth mode. Expiry Time: {1}", parameters.AuthenticationMethod, result?.ExpiresOn);
345+
}
346+
347+
return result;
348+
}
349+
350+
private static async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(IPublicClientApplication app, string[] scopes, Guid connectionId, string userId,
351+
SqlAuthenticationMethod authenticationMethod, CancellationTokenSource cts, ICustomWebUi customWebUI, Func<DeviceCodeResult, Task> deviceCodeFlowCallback)
304352
{
305353
try
306354
{
@@ -319,11 +367,11 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
319367
*/
320368
ctsInteractive.CancelAfter(180000);
321369
#endif
322-
if (_customWebUI != null)
370+
if (customWebUI != null)
323371
{
324372
return await app.AcquireTokenInteractive(scopes)
325373
.WithCorrelationId(connectionId)
326-
.WithCustomWebUi(_customWebUI)
374+
.WithCustomWebUi(customWebUI)
327375
.WithLoginHint(userId)
328376
.ExecuteAsync(ctsInteractive.Token)
329377
.ConfigureAwait(false);
@@ -357,7 +405,7 @@ private async Task<AuthenticationResult> AcquireTokenInteractiveDeviceFlowAsync(
357405
else
358406
{
359407
AuthenticationResult result = await app.AcquireTokenWithDeviceCode(scopes,
360-
deviceCodeResult => _deviceCodeFlowCallback(deviceCodeResult))
408+
deviceCodeResult => deviceCodeFlowCallback(deviceCodeResult))
361409
.WithCorrelationId(connectionId)
362410
.ExecuteAsync(cancellationToken: cts.Token)
363411
.ConfigureAwait(false);
@@ -410,6 +458,19 @@ private IPublicClientApplication GetPublicClientAppInstance(PublicClientAppKey p
410458
return clientApplicationInstance;
411459
}
412460

461+
private static string GetAccountPwCacheKey(SqlAuthenticationParameters parameters)
462+
{
463+
return parameters.Authority + "+" + parameters.UserId;
464+
}
465+
466+
private static byte[] GetHash(string input)
467+
{
468+
byte[] unhashedBytes = Encoding.Unicode.GetBytes(input);
469+
SHA256 sha256 = SHA256.Create();
470+
byte[] hashedBytes = sha256.ComputeHash(unhashedBytes);
471+
return hashedBytes;
472+
}
473+
413474
private IPublicClientApplication CreateClientAppInstance(PublicClientAppKey publicClientAppKey)
414475
{
415476
IPublicClientApplication publicClientApplication;

0 commit comments

Comments
 (0)