Skip to content

Commit 2b8e544

Browse files
committed
Rework some logic in TlsHandler
* Make sure TlsHandler.MediationStream works well with different style of aync calls(Still not work for Mono, see Azure#374) * Rework some logic in Azure#366, now always close TlsHandler.MediationStream in TlsHandler.HandleFailure since it's never exported.
1 parent f9b86a0 commit 2b8e544

File tree

2 files changed

+47
-27
lines changed

2 files changed

+47
-27
lines changed

src/DotNetty.Handlers/Tls/SniHandler.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ public sealed class SniHandler : ByteToMessageDecoder
2828
bool readPending;
2929

3030
public SniHandler(ServerTlsSniSettings settings)
31-
: this(stream => new SslStream(stream, false), settings)
31+
: this(stream => new SslStream(stream, true), settings)
3232
{
3333
}
3434

src/DotNetty.Handlers/Tls/TlsHandler.cs

Lines changed: 46 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ namespace DotNetty.Handlers.Tls
55
{
66
using System;
77
using System.Collections.Generic;
8+
using System.Diagnostics;
89
using System.Diagnostics.Contracts;
910
using System.IO;
1011
using System.Net.Security;
@@ -41,7 +42,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4142
Task<int> pendingSslStreamReadFuture;
4243

4344
public TlsHandler(TlsSettings settings)
44-
: this(stream => new SslStream(stream, false), settings)
45+
: this(stream => new SslStream(stream, true), settings)
4546
{
4647
}
4748

@@ -344,6 +345,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
344345

345346
outputBuffer = this.pendingSslStreamReadBuffer;
346347
outputBufferLength = outputBuffer.WritableBytes;
348+
349+
this.pendingSslStreamReadFuture = null;
350+
this.pendingSslStreamReadBuffer = null;
347351
}
348352
else
349353
{
@@ -363,17 +367,23 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
363367
if (!currentReadFuture.IsCompleted)
364368
{
365369
// we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
366-
Contract.Assert(this.mediationStream.SourceReadableBytes == 0);
367370

368371
continue;
369372
}
370373

371374
int read = currentReadFuture.Result;
372375

376+
if (read == 0)
377+
{
378+
//Stream closed
379+
return;
380+
}
381+
373382
// Now output the result of previous read and decide whether to do an extra read on the same source or move forward
374383
AddBufferToOutput(outputBuffer, read, output);
375384

376385
currentReadFuture = null;
386+
outputBuffer = null;
377387
if (this.mediationStream.SourceReadableBytes == 0)
378388
{
379389
// we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
@@ -620,6 +630,7 @@ void HandleFailure(Exception cause)
620630
// Release all resources such as internal buffers that SSLEngine
621631
// is managing.
622632

633+
this.mediationStream.Dispose();
623634
try
624635
{
625636
this.sslStream.Dispose();
@@ -701,14 +712,13 @@ public void ExpandSource(int count)
701712

702713
this.inputLength += count;
703714

704-
TaskCompletionSource<int> promise = this.readCompletionSource;
705-
if (promise == null)
715+
ArraySegment<byte> sslBuffer = this.sslOwnedBuffer;
716+
if (sslBuffer.Array == null)
706717
{
707718
// there is no pending read operation - keep for future
708719
return;
709720
}
710-
711-
ArraySegment<byte> sslBuffer = this.sslOwnedBuffer;
721+
this.sslOwnedBuffer = default(ArraySegment<byte>);
712722

713723
#if NETSTANDARD1_3
714724
this.readByteCount = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
@@ -718,29 +728,35 @@ public void ExpandSource(int count)
718728
{
719729
var self = (MediationStream)ms;
720730
TaskCompletionSource<int> p = self.readCompletionSource;
721-
this.readCompletionSource = null;
731+
self.readCompletionSource = null;
722732
p.TrySetResult(self.readByteCount);
723733
},
724734
this)
725735
.RunSynchronously(TaskScheduler.Default);
726736
#else
727737
int read = this.ReadFromInput(sslBuffer.Array, sslBuffer.Offset, sslBuffer.Count);
738+
739+
TaskCompletionSource<int> promise = this.readCompletionSource;
728740
this.readCompletionSource = null;
729741
promise.TrySetResult(read);
730-
this.readCallback?.Invoke(promise.Task);
742+
743+
AsyncCallback callback = this.readCallback;
744+
this.readCallback = null;
745+
callback?.Invoke(promise.Task);
731746
#endif
732747
}
733748

734749
#if NETSTANDARD1_3
735750
public override Task<int> ReadAsync(byte[] buffer, int offset, int count, CancellationToken cancellationToken)
736751
{
737-
if (this.inputLength - this.inputOffset > 0)
752+
if (this.SourceReadableBytes > 0)
738753
{
739754
// we have the bytes available upfront - write out synchronously
740755
int read = this.ReadFromInput(buffer, offset, count);
741756
return Task.FromResult(read);
742757
}
743758

759+
Contract.Assert(this.sslOwnedBuffer.Array == null);
744760
// take note of buffer - we will pass bytes there once available
745761
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
746762
this.readCompletionSource = new TaskCompletionSource<int>();
@@ -749,13 +765,16 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
749765
#else
750766
public override IAsyncResult BeginRead(byte[] buffer, int offset, int count, AsyncCallback callback, object state)
751767
{
752-
if (this.inputLength - this.inputOffset > 0)
768+
if (this.SourceReadableBytes > 0)
753769
{
754770
// we have the bytes available upfront - write out synchronously
755771
int read = this.ReadFromInput(buffer, offset, count);
756-
return this.PrepareSyncReadResult(read, state);
772+
var res = this.PrepareSyncReadResult(read, state);
773+
callback?.Invoke(res);
774+
return res;
757775
}
758776

777+
Contract.Assert(this.sslOwnedBuffer.Array == null);
759778
// take note of buffer - we will pass bytes there once available
760779
this.sslOwnedBuffer = new ArraySegment<byte>(buffer, offset, count);
761780
this.readCompletionSource = new TaskCompletionSource<int>(state);
@@ -771,6 +790,7 @@ public override int EndRead(IAsyncResult asyncResult)
771790
return syncResult.Result;
772791
}
773792

793+
Debug.Assert(this.readCompletionSource == null || this.readCompletionSource.Task == asyncResult);
774794
Contract.Assert(!((Task<int>)asyncResult).IsCanceled);
775795

776796
try
@@ -782,12 +802,6 @@ public override int EndRead(IAsyncResult asyncResult)
782802
ExceptionDispatchInfo.Capture(ex.InnerException).Throw();
783803
throw; // unreachable
784804
}
785-
finally
786-
{
787-
this.readCompletionSource = null;
788-
this.readCallback = null;
789-
this.sslOwnedBuffer = default(ArraySegment<byte>);
790-
}
791805
}
792806

793807
IAsyncResult PrepareSyncReadResult(int readBytes, object state)
@@ -817,10 +831,11 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
817831
// write+flush completed synchronously (and successfully)
818832
var result = new SynchronousAsyncResult<int>();
819833
result.AsyncState = state;
820-
callback(result);
834+
callback?.Invoke(result);
821835
return result;
822836
default:
823837
this.writeCallback = callback;
838+
Contract.Assert(this.writeCompletion == null);
824839
var tcs = new TaskCompletionSource(state);
825840
this.writeCompletion = tcs;
826841
task.ContinueWith(WriteCompleteCallback, this, TaskContinuationOptions.ExecuteSynchronously);
@@ -831,34 +846,39 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
831846
static void HandleChannelWriteComplete(Task writeTask, object state)
832847
{
833848
var self = (MediationStream)state;
849+
850+
AsyncCallback callback = self.writeCallback;
851+
self.writeCallback = null;
852+
853+
var promise = self.writeCompletion;
854+
self.writeCompletion = null;
855+
834856
switch (writeTask.Status)
835857
{
836858
case TaskStatus.RanToCompletion:
837-
self.writeCompletion.TryComplete();
859+
promise.TryComplete();
838860
break;
839861
case TaskStatus.Canceled:
840-
self.writeCompletion.TrySetCanceled();
862+
promise.TrySetCanceled();
841863
break;
842864
case TaskStatus.Faulted:
843-
self.writeCompletion.TrySetException(writeTask.Exception);
865+
promise.TrySetException(writeTask.Exception);
844866
break;
845867
default:
846868
throw new ArgumentOutOfRangeException("Unexpected task status: " + writeTask.Status);
847869
}
848870

849-
self.writeCallback?.Invoke(self.writeCompletion.Task);
871+
callback?.Invoke(promise.Task);
850872
}
851873

852874
public override void EndWrite(IAsyncResult asyncResult)
853875
{
854-
this.writeCallback = null;
855-
this.writeCompletion = null;
856-
857876
if (asyncResult is SynchronousAsyncResult<int>)
858877
{
859878
return;
860879
}
861880

881+
Debug.Assert(this.writeCompletion == null || this.writeCompletion.Task == asyncResult);
862882
try
863883
{
864884
((Task<int>)asyncResult).Wait();
@@ -876,7 +896,7 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa
876896
Contract.Assert(destination != null);
877897

878898
byte[] source = this.input;
879-
int readableBytes = this.inputLength - this.inputOffset;
899+
int readableBytes = this.SourceReadableBytes;
880900
int length = Math.Min(readableBytes, destinationCapacity);
881901
Buffer.BlockCopy(source, this.inputStartOffset + this.inputOffset, destination, destinationOffset, length);
882902
this.inputOffset += length;

0 commit comments

Comments
 (0)