Skip to content

Commit 253973e

Browse files
Fixed race conditions + unsafe struct assignment in SelectAsync (#7521)
* Fixed race conditions + unsafe struct assignment in `SelectAsync` close #7518 * added nullability API approvals * fixed `FlowSelectAsyncSpecs` * Update src/core/Akka.Streams/Implementation/Fusing/Ops.cs Co-authored-by: Michael Buck <[email protected]> * relaxed nullability requirements for `Holder<TOut>` `null` might be a perfectly acceptable value inside a `SelectAsync` stage * fixed `FlowAskSpec` * forgot to pass in innerException * fix race condition in `FlowAskSpec` --------- Co-authored-by: Michael Buck <[email protected]>
1 parent 10b8223 commit 253973e

File tree

5 files changed

+102
-71
lines changed

5 files changed

+102
-71
lines changed

src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveStreams.DotNet.verified.txt

+12-2
Original file line numberDiff line numberDiff line change
@@ -4430,7 +4430,12 @@ namespace Akka.Streams.Implementation.Fusing
44304430
public override string ToString() { }
44314431
}
44324432
[Akka.Annotations.InternalApiAttribute()]
4433-
public sealed class SelectAsyncUnordered<TIn, TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
4433+
[System.Runtime.CompilerServices.NullableAttribute(new byte[] {
4434+
0,
4435+
1,
4436+
1,
4437+
1})]
4438+
public sealed class SelectAsyncUnordered<[System.Runtime.CompilerServices.NullableAttribute(2)] TIn, [System.Runtime.CompilerServices.NullableAttribute(2)] TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
44344439
{
44354440
public readonly Akka.Streams.Inlet<TIn> In;
44364441
public readonly Akka.Streams.Outlet<TOut> Out;
@@ -4440,7 +4445,12 @@ namespace Akka.Streams.Implementation.Fusing
44404445
protected override Akka.Streams.Stage.GraphStageLogic CreateLogic(Akka.Streams.Attributes inheritedAttributes) { }
44414446
}
44424447
[Akka.Annotations.InternalApiAttribute()]
4443-
public sealed class SelectAsync<TIn, TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
4448+
[System.Runtime.CompilerServices.NullableAttribute(new byte[] {
4449+
0,
4450+
1,
4451+
1,
4452+
1})]
4453+
public sealed class SelectAsync<[System.Runtime.CompilerServices.NullableAttribute(2)] TIn, [System.Runtime.CompilerServices.NullableAttribute(2)] TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
44444454
{
44454455
public readonly Akka.Streams.Inlet<TIn> In;
44464456
public readonly Akka.Streams.Outlet<TOut> Out;

src/core/Akka.API.Tests/verify/CoreAPISpec.ApproveStreams.Net.verified.txt

+12-2
Original file line numberDiff line numberDiff line change
@@ -4404,7 +4404,12 @@ namespace Akka.Streams.Implementation.Fusing
44044404
public override string ToString() { }
44054405
}
44064406
[Akka.Annotations.InternalApiAttribute()]
4407-
public sealed class SelectAsyncUnordered<TIn, TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
4407+
[System.Runtime.CompilerServices.NullableAttribute(new byte[] {
4408+
0,
4409+
1,
4410+
1,
4411+
1})]
4412+
public sealed class SelectAsyncUnordered<[System.Runtime.CompilerServices.NullableAttribute(2)] TIn, [System.Runtime.CompilerServices.NullableAttribute(2)] TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
44084413
{
44094414
public readonly Akka.Streams.Inlet<TIn> In;
44104415
public readonly Akka.Streams.Outlet<TOut> Out;
@@ -4414,7 +4419,12 @@ namespace Akka.Streams.Implementation.Fusing
44144419
protected override Akka.Streams.Stage.GraphStageLogic CreateLogic(Akka.Streams.Attributes inheritedAttributes) { }
44154420
}
44164421
[Akka.Annotations.InternalApiAttribute()]
4417-
public sealed class SelectAsync<TIn, TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
4422+
[System.Runtime.CompilerServices.NullableAttribute(new byte[] {
4423+
0,
4424+
1,
4425+
1,
4426+
1})]
4427+
public sealed class SelectAsync<[System.Runtime.CompilerServices.NullableAttribute(2)] TIn, [System.Runtime.CompilerServices.NullableAttribute(2)] TOut> : Akka.Streams.Stage.GraphStage<Akka.Streams.FlowShape<TIn, TOut>>
44184428
{
44194429
public readonly Akka.Streams.Inlet<TIn> In;
44204430
public readonly Akka.Streams.Outlet<TOut> Out;

src/core/Akka.Streams.Tests/Dsl/FlowAskSpec.cs

+12-5
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,7 @@ public async Task Flow_with_ask_must_signal_ask_timeout_failure() => await this.
238238

239239
c.ExpectSubscription().Request(10);
240240
var error = c.ExpectError();
241-
error.As<AggregateException>().Flatten()
242-
.InnerException
243-
.Should().BeOfType<AskTimeoutException>();
241+
error.Should().BeOfType<AskTimeoutException>();
244242
return Task.CompletedTask;
245243
}, _materializer);
246244

@@ -253,8 +251,17 @@ public async Task Flow_with_ask_must_signal_ask_failure() => await this.AssertAl
253251
.Ask<Reply>(failsOn, _timeout, 1)
254252
.RunWith(Sink.FromSubscriber(c), _materializer);
255253

256-
var error = (AggregateException)c.ExpectSubscriptionAndError();
257-
error.InnerException.Message.Should().Be("Booming for 1!");
254+
var error = c.ExpectSubscriptionAndError();
255+
if (error is AggregateException aggregateException) // happens if we hit the fast path and don't await
256+
{
257+
aggregateException.Flatten()
258+
.InnerException!.Message.Should().Be("Booming for 1!");
259+
}
260+
else
261+
{
262+
error.Message.Should().Be("Booming for 1!");
263+
}
264+
258265
return Task.CompletedTask;
259266
}, _materializer);
260267

src/core/Akka.Streams.Tests/Dsl/FlowSelectAsyncSpec.cs

+26-24
Original file line numberDiff line numberDiff line change
@@ -171,10 +171,10 @@ await this.AssertAllStagesStoppedAsync(async() => {
171171

172172
var exception = await probe.AsyncBuilder()
173173
.Request(10)
174-
.ExpectNextN(new[]{1, 2})
174+
.ExpectNextN([1, 2])
175175
.ExpectErrorAsync()
176176
.ShouldCompleteWithin(RemainingOrDefault);
177-
exception.InnerException!.Message.Should().Be("err1");
177+
exception.Message.Should().Be("err1");
178178
}, Materializer);
179179
}
180180

@@ -232,7 +232,7 @@ await this.AssertAllStagesStoppedAsync(async () => {
232232
.RunWith(Sink.FromSubscriber(c), Materializer);
233233
var sub = await c.ExpectSubscriptionAsync();
234234
sub.Request(10);
235-
c.ExpectError().Message.Should().Be("err2");
235+
(await c.ExpectErrorAsync()).Message.Should().Be("err2");
236236
}, Materializer);
237237
}
238238

@@ -258,7 +258,7 @@ await this.AssertAllStagesStoppedAsync(async () =>
258258

259259
await probe.AsyncBuilder()
260260
.Request(10)
261-
.ExpectNextN(new[] { 1, 2 })
261+
.ExpectNextN([1, 2])
262262
.ExpectErrorAsync();
263263

264264
invoked.Should().BeTrue();
@@ -358,21 +358,21 @@ public async Task A_Flow_with_SelectAsync_must_signal_NPE_when_task_is_completed
358358
{
359359
var c = this.CreateManualSubscriberProbe<string>();
360360

361-
Source.From(new[] {"a", "b"})
362-
.SelectAsync(4, _ => Task.FromResult(null as string))
361+
Source.From(["a", "b"])
362+
.SelectAsync(4, _ => Task.FromResult<string>(null))
363363
.To(Sink.FromSubscriber(c)).Run(Materializer);
364364

365365
var sub = await c.ExpectSubscriptionAsync();
366366
sub.Request(10);
367-
c.ExpectError().Message.Should().StartWith(ReactiveStreamsCompliance.ElementMustNotBeNullMsg);
367+
(await c.ExpectErrorAsync()).Message.Should().StartWith(ReactiveStreamsCompliance.ElementMustNotBeNullMsg);
368368
}
369369

370370
[Fact]
371371
public async Task A_Flow_with_SelectAsync_must_resume_when_task_is_completed_with_null()
372372
{
373373
var c = this.CreateManualSubscriberProbe<string>();
374-
Source.From(new[] { "a", "b", "c" })
375-
.SelectAsync(4, s => s.Equals("b") ? Task.FromResult(null as string) : Task.FromResult(s))
374+
Source.From(["a", "b", "c"])
375+
.SelectAsync(4, s => s.Equals("b") ? Task.FromResult<string>(null) : Task.FromResult(s))
376376
.WithAttributes(ActorAttributes.CreateSupervisionStrategy(Deciders.ResumingDecider))
377377
.To(Sink.FromSubscriber(c)).Run(Materializer);
378378
var sub = await c.ExpectSubscriptionAsync();
@@ -438,21 +438,6 @@ await this.AssertAllStagesStoppedAsync(async() =>
438438
}, cancellation.Token);
439439
#pragma warning restore CS4014 // Because this call is not awaited, execution of the current method continues before the call is completed
440440

441-
Task<int> Deferred()
442-
{
443-
var promise = new TaskCompletionSource<int>();
444-
if (counter.IncrementAndGet() > parallelism)
445-
promise.SetException(new Exception("parallelism exceeded"));
446-
else
447-
{
448-
var wrote = queue.Writer.TryWrite((promise, DateTime.Now.Ticks));
449-
if (!wrote)
450-
promise.SetException(new Exception("Failed to write to queue"));
451-
}
452-
453-
return promise.Task;
454-
}
455-
456441
try
457442
{
458443
const int n = 10000;
@@ -467,6 +452,23 @@ Task<int> Deferred()
467452
{
468453
cancellation.Cancel(false);
469454
}
455+
456+
return;
457+
458+
Task<int> Deferred()
459+
{
460+
var promise = new TaskCompletionSource<int>();
461+
if (counter.IncrementAndGet() > parallelism)
462+
promise.SetException(new Exception("parallelism exceeded"));
463+
else
464+
{
465+
var wrote = queue.Writer.TryWrite((promise, DateTime.Now.Ticks));
466+
if (!wrote)
467+
promise.SetException(new Exception("Failed to write to queue"));
468+
}
469+
470+
return promise.Task;
471+
}
470472
}, Materializer);
471473
}
472474

src/core/Akka.Streams/Implementation/Fusing/Ops.cs

+40-38
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
using Akka.Streams.Util;
2323
using Akka.Util;
2424
using Akka.Util.Internal;
25+
using Debug = System.Diagnostics.Debug;
2526
using Decider = Akka.Streams.Supervision.Decider;
2627
using Directive = Akka.Streams.Supervision.Directive;
2728

@@ -2512,61 +2513,47 @@ public Expand(Func<TIn, IEnumerator<TOut>> extrapolate)
25122513
/// </returns>
25132514
public override string ToString() => "Expand";
25142515
}
2516+
2517+
#nullable enable
25152518

25162519
/// <summary>
25172520
/// INTERNAL API
25182521
/// </summary>
2519-
/// <typeparam name="TIn">TBD</typeparam>
2520-
/// <typeparam name="TOut">TBD</typeparam>
25212522
[InternalApi]
25222523
public sealed class SelectAsync<TIn, TOut> : GraphStage<FlowShape<TIn, TOut>>
25232524
{
25242525
#region internal classes
25252526

25262527
private sealed class Logic : InAndOutGraphStageLogic
25272528
{
2528-
private class Holder<T>
2529+
private sealed class Holder<T>(object? message, Result<T> element)
25292530
{
2530-
private readonly Action<Holder<T>> _callback;
2531-
2532-
public Holder(object message, Result<T> element, Action<Holder<T>> callback)
2533-
{
2534-
_callback = callback;
2535-
Message = message;
2536-
Element = element;
2537-
}
2538-
2539-
public Result<T> Element { get; private set; }
2540-
public object Message { get; }
2531+
public object? Message { get; private set; } = message;
2532+
2533+
public Result<T> Element { get; private set; } = element;
25412534

25422535
public void SetElement(Result<T> result)
25432536
{
25442537
Element = result.IsSuccess && result.Value == null
25452538
? Result.Failure<T>(ReactiveStreamsCompliance.ElementMustNotBeNullException)
25462539
: result;
25472540
}
2548-
2549-
public void Invoke(Result<T> result)
2550-
{
2551-
SetElement(result);
2552-
_callback(this);
2553-
}
25542541
}
25552542

25562543
private static readonly Result<TOut> NotYetThere = Result.Failure<TOut>(new Exception());
25572544

25582545
private readonly SelectAsync<TIn, TOut> _stage;
25592546
private readonly Decider _decider;
25602547
private IBuffer<Holder<TOut>> _buffer;
2561-
private readonly Action<Holder<TOut>> _taskCallback;
2548+
private readonly Action<(Holder<TOut>, Result<TOut>)> _taskCallback;
25622549

25632550
public Logic(Attributes inheritedAttributes, SelectAsync<TIn, TOut> stage) : base(stage.Shape)
25642551
{
25652552
_stage = stage;
2566-
var attr = inheritedAttributes.GetAttribute<ActorAttributes.SupervisionStrategy>(null);
2553+
var attr = inheritedAttributes.GetAttribute<ActorAttributes.SupervisionStrategy>();
25672554
_decider = attr != null ? attr.Decider : Deciders.StoppingDecider;
25682555

2569-
_taskCallback = GetAsyncCallback<Holder<TOut>>(HolderCompleted);
2556+
_taskCallback = GetAsyncCallback<(Holder<TOut> holder, Result<TOut> result)>(t => HolderCompleted(t.holder, t.result));
25702557

25712558
SetHandlers(stage.In, stage.Out, this);
25722559
}
@@ -2577,19 +2564,33 @@ public override void OnPush()
25772564
try
25782565
{
25792566
var task = _stage._mapFunc(message);
2580-
var holder = new Holder<TOut>(message, NotYetThere, _taskCallback);
2567+
var holder = new Holder<TOut>(message, NotYetThere);
25812568
_buffer.Enqueue(holder);
25822569

25832570
// We dispatch the task if it's ready to optimize away
25842571
// scheduling it to an execution context
25852572
if (task.IsCompleted)
25862573
{
2587-
holder.SetElement(Result.FromTask(task));
2588-
HolderCompleted(holder);
2574+
HolderCompleted(holder, Result.FromTask(task));
25892575
}
25902576
else
2591-
task.ContinueWith(t => holder.Invoke(Result.FromTask(t)),
2592-
TaskContinuationOptions.ExecuteSynchronously);
2577+
{
2578+
async Task WaitForTask()
2579+
{
2580+
try
2581+
{
2582+
var result = Result.Success(await task);
2583+
_taskCallback((holder, result));
2584+
}
2585+
catch(Exception ex){
2586+
var result = Result.Failure<TOut>(ex);
2587+
_taskCallback((holder, result));
2588+
}
2589+
}
2590+
2591+
_ = WaitForTask();
2592+
}
2593+
25932594
}
25942595
catch (Exception e)
25952596
{
@@ -2606,7 +2607,7 @@ public override void OnPush()
26062607
break;
26072608

26082609
default:
2609-
throw new AggregateException($"Unknown SupervisionStrategy directive: {strategy}", e);
2610+
throw new ArgumentOutOfRangeException($"Unknown SupervisionStrategy directive: {strategy}", e);
26102611
}
26112612
}
26122613
if (Todo < _stage._parallelism && !HasBeenPulled(_stage.In))
@@ -2663,12 +2664,12 @@ private void PushOne()
26632664
break;
26642665

26652666
default:
2666-
throw new AggregateException($"Unknown SupervisionStrategy directive: {strategy}", result.Exception);
2667+
throw new ArgumentOutOfRangeException($"Unknown SupervisionStrategy directive: {strategy}", result.Exception);
26672668
}
26682669
continue;
26692670
}
26702671

2671-
Push(_stage.Out, result.Value);
2672+
Push(_stage.Out!, result.Value);
26722673
if (Todo < _stage._parallelism && !HasBeenPulled(inlet))
26732674
TryPull(inlet);
26742675
}
@@ -2677,17 +2678,18 @@ private void PushOne()
26772678
}
26782679
}
26792680

2680-
private void HolderCompleted(Holder<TOut> holder)
2681+
private void HolderCompleted(Holder<TOut> holder, Result<TOut> result)
26812682
{
2682-
var element = holder.Element;
2683-
if (element.IsSuccess)
2683+
// we may not be at the front of the line right now, so save the result for later
2684+
holder.SetElement(result);
2685+
if (result.IsSuccess)
26842686
{
26852687
if (IsAvailable(_stage.Out))
26862688
PushOne();
26872689
return;
26882690
}
26892691

2690-
var exception = element.Exception;
2692+
var exception = result.Exception;
26912693
var strategy = _decider(exception);
26922694
Log.Error(exception, "An exception occured inside SelectAsync while executing Task. Supervision strategy: {0}", strategy);
26932695
switch (strategy)
@@ -2703,7 +2705,7 @@ private void HolderCompleted(Holder<TOut> holder)
27032705
break;
27042706

27052707
default:
2706-
throw new AggregateException($"Unknown SupervisionStrategy directive: {strategy}", exception);
2708+
throw new ArgumentOutOfRangeException($"Unknown SupervisionStrategy directive: {strategy}", exception);
27072709
}
27082710
}
27092711

@@ -2758,8 +2760,6 @@ protected override GraphStageLogic CreateLogic(Attributes inheritedAttributes)
27582760
/// <summary>
27592761
/// INTERNAL API
27602762
/// </summary>
2761-
/// <typeparam name="TIn">TBD</typeparam>
2762-
/// <typeparam name="TOut">TBD</typeparam>
27632763
[InternalApi]
27642764
public sealed class SelectAsyncUnordered<TIn, TOut> : GraphStage<FlowShape<TIn, TOut>>
27652765
{
@@ -2904,6 +2904,8 @@ public SelectAsyncUnordered(int parallelism, Func<TIn, Task<TOut>> mapFunc)
29042904
protected override GraphStageLogic CreateLogic(Attributes inheritedAttributes)
29052905
=> new Logic(inheritedAttributes, this);
29062906
}
2907+
2908+
#nullable disable
29072909

29082910
/// <summary>
29092911
/// INTERNAL API

0 commit comments

Comments
 (0)