Skip to content

Commit 0d2d323

Browse files
authored
fix race condition in dictionary operations (#436)
* fix race condition * update register/unregister * no explicit casting
1 parent ae44bad commit 0d2d323

File tree

1 file changed

+20
-21
lines changed

1 file changed

+20
-21
lines changed

src/Authentication.Abstractions/AzureSession.cs

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,8 @@ public abstract class AzureSession : IAzureSession
3030
static IAzureSession _instance;
3131
static bool _initialized = false;
3232
static ReaderWriterLockSlim sessionLock = new ReaderWriterLockSlim(LockRecursionPolicy.SupportsRecursion);
33-
private IDictionary<ComponentKey, object> _componentRegistry = new ConcurrentDictionary<ComponentKey, object>(new ComponentKeyComparer());
33+
// explicit typing for the thread-safe API calls to avoid the need for explicit casting
34+
private ConcurrentDictionary<ComponentKey, object> _componentRegistry = new ConcurrentDictionary<ComponentKey, object>(new ComponentKeyComparer());
3435
private event EventHandler<AzureSessionEventArgs> _eventHandler;
3536

3637
/// <summary>
@@ -89,7 +90,7 @@ public abstract class AzureSession : IAzureSession
8990
public string OldProfileFile { get; set; }
9091

9192
/// <summary>
92-
/// The directory contianing the ARM ContextContainer
93+
/// The directory containing the ARM ContextContainer
9394
/// </summary>
9495
public string ARMProfileDirectory { get; set; }
9596

@@ -214,13 +215,16 @@ public static void Modify(Action<IAzureSession> modifier)
214215
public bool TryGetComponent<T>(string componentName, out T component) where T : class
215216
{
216217
var key = new ComponentKey(componentName, typeof(T));
217-
component = null;
218-
if (_componentRegistry.ContainsKey(key))
218+
if (_componentRegistry.TryGetValue(key, out var componentObj) && componentObj is T componentT)
219219
{
220-
component = _componentRegistry[key] as T;
220+
component = componentT;
221+
return true;
222+
}
223+
else
224+
{
225+
component = null;
226+
return false;
221227
}
222-
223-
return component != null;
224228
}
225229

226230
public void RegisterComponent<T>(string componentName, Func<T> componentInitializer) where T : class
@@ -234,16 +238,16 @@ public void RegisterComponent<T>(string componentName, Func<T> componentInitiali
234238
() =>
235239
{
236240
var key = new ComponentKey(componentName, typeof(T));
237-
if (!_componentRegistry.ContainsKey(key) || overwrite)
241+
if (!_componentRegistry.ContainsKey(key) || overwrite) // only proceed if key not found or overwrite is true
238242
{
239-
if (_componentRegistry.ContainsKey(key) && overwrite)
243+
244+
if (overwrite
245+
&& _componentRegistry.TryGetValue(key, out var existed)
246+
&& existed is IAzureSessionListener existedListener)
240247
{
241-
var existed = _componentRegistry[key];
242-
if (existed is IAzureSessionListener existedListener)
243-
{
244-
_eventHandler -= existedListener.OnEvent;
245-
}
248+
_eventHandler -= existedListener.OnEvent;
246249
}
250+
247251
var component = componentInitializer();
248252
_componentRegistry[key] = component;
249253
if (component is IAzureSessionListener listener)
@@ -260,14 +264,9 @@ public void UnregisterComponent<T>(string componentName) where T : class
260264
() =>
261265
{
262266
var key = new ComponentKey(componentName, typeof(T));
263-
if (_componentRegistry.ContainsKey(key))
267+
if (_componentRegistry.TryRemove(key, out var component) && component is IAzureSessionListener listener)
264268
{
265-
var component = _componentRegistry[key];
266-
if (component is IAzureSessionListener listener)
267-
{
268-
_eventHandler -= listener.OnEvent;
269-
}
270-
_componentRegistry.Remove(key);
269+
_eventHandler -= listener.OnEvent;
271270
}
272271
});
273272
}

0 commit comments

Comments
 (0)