@@ -5,6 +5,7 @@ namespace DotNetty.Handlers.Tls
55{
66 using System ;
77 using System . Collections . Generic ;
8+ using System . Diagnostics ;
89 using System . Diagnostics . Contracts ;
910 using System . IO ;
1011 using System . Net . Security ;
@@ -41,7 +42,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4142 Task < int > pendingSslStreamReadFuture ;
4243
4344 public TlsHandler ( TlsSettings settings )
44- : this ( stream => new SslStream ( stream , false ) , settings )
45+ : this ( stream => new SslStream ( stream , true ) , settings )
4546 {
4647 }
4748
@@ -344,6 +345,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
344345
345346 outputBuffer = this . pendingSslStreamReadBuffer ;
346347 outputBufferLength = outputBuffer . WritableBytes ;
348+
349+ this . pendingSslStreamReadFuture = null ;
350+ this . pendingSslStreamReadBuffer = null ;
347351 }
348352 else
349353 {
@@ -363,17 +367,23 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
363367 if ( ! currentReadFuture . IsCompleted )
364368 {
365369 // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
366- Contract . Assert ( this . mediationStream . SourceReadableBytes == 0 ) ;
367370
368371 continue ;
369372 }
370373
371374 int read = currentReadFuture . Result ;
372375
376+ if ( read == 0 )
377+ {
378+ //Stream closed
379+ return ;
380+ }
381+
373382 // Now output the result of previous read and decide whether to do an extra read on the same source or move forward
374383 AddBufferToOutput ( outputBuffer , read , output ) ;
375384
376385 currentReadFuture = null ;
386+ outputBuffer = null ;
377387 if ( this . mediationStream . SourceReadableBytes == 0 )
378388 {
379389 // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
@@ -620,6 +630,7 @@ void HandleFailure(Exception cause)
620630 // Release all resources such as internal buffers that SSLEngine
621631 // is managing.
622632
633+ this . mediationStream . Dispose ( ) ;
623634 try
624635 {
625636 this . sslStream . Dispose ( ) ;
@@ -701,14 +712,13 @@ public void ExpandSource(int count)
701712
702713 this . inputLength += count ;
703714
704- TaskCompletionSource < int > promise = this . readCompletionSource ;
705- if ( promise == null )
715+ ArraySegment < byte > sslBuffer = this . sslOwnedBuffer ;
716+ if ( sslBuffer . Array == null )
706717 {
707718 // there is no pending read operation - keep for future
708719 return ;
709720 }
710-
711- ArraySegment < byte > sslBuffer = this . sslOwnedBuffer ;
721+ this . sslOwnedBuffer = default ( ArraySegment < byte > ) ;
712722
713723#if NETSTANDARD1_3
714724 this . readByteCount = this . ReadFromInput ( sslBuffer . Array , sslBuffer . Offset , sslBuffer . Count ) ;
@@ -718,29 +728,35 @@ public void ExpandSource(int count)
718728 {
719729 var self = ( MediationStream ) ms ;
720730 TaskCompletionSource < int > p = self . readCompletionSource ;
721- this . readCompletionSource = null ;
731+ self . readCompletionSource = null ;
722732 p . TrySetResult ( self . readByteCount ) ;
723733 } ,
724734 this )
725735 . RunSynchronously ( TaskScheduler . Default ) ;
726736#else
727737 int read = this . ReadFromInput ( sslBuffer . Array , sslBuffer . Offset , sslBuffer . Count ) ;
738+
739+ TaskCompletionSource < int > promise = this . readCompletionSource ;
728740 this . readCompletionSource = null ;
729741 promise . TrySetResult ( read ) ;
730- this . readCallback ? . Invoke ( promise . Task ) ;
742+
743+ AsyncCallback callback = this . readCallback ;
744+ this . readCallback = null ;
745+ callback ? . Invoke ( promise . Task ) ;
731746#endif
732747 }
733748
734749#if NETSTANDARD1_3
735750 public override Task < int > ReadAsync ( byte [ ] buffer , int offset , int count , CancellationToken cancellationToken )
736751 {
737- if ( this . inputLength - this . inputOffset > 0 )
752+ if ( this . SourceReadableBytes > 0 )
738753 {
739754 // we have the bytes available upfront - write out synchronously
740755 int read = this . ReadFromInput ( buffer , offset , count ) ;
741756 return Task . FromResult ( read ) ;
742757 }
743758
759+ Contract . Assert ( this . sslOwnedBuffer . Array == null ) ;
744760 // take note of buffer - we will pass bytes there once available
745761 this . sslOwnedBuffer = new ArraySegment < byte > ( buffer , offset , count ) ;
746762 this . readCompletionSource = new TaskCompletionSource < int > ( ) ;
@@ -749,13 +765,16 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
749765#else
750766 public override IAsyncResult BeginRead ( byte [ ] buffer , int offset , int count , AsyncCallback callback , object state )
751767 {
752- if ( this . inputLength - this . inputOffset > 0 )
768+ if ( this . SourceReadableBytes > 0 )
753769 {
754770 // we have the bytes available upfront - write out synchronously
755771 int read = this . ReadFromInput ( buffer , offset , count ) ;
756- return this . PrepareSyncReadResult ( read , state ) ;
772+ var res = this . PrepareSyncReadResult ( read , state ) ;
773+ callback ? . Invoke ( res ) ;
774+ return res ;
757775 }
758776
777+ Contract . Assert ( this . sslOwnedBuffer . Array == null ) ;
759778 // take note of buffer - we will pass bytes there once available
760779 this . sslOwnedBuffer = new ArraySegment < byte > ( buffer , offset , count ) ;
761780 this . readCompletionSource = new TaskCompletionSource < int > ( state ) ;
@@ -771,6 +790,7 @@ public override int EndRead(IAsyncResult asyncResult)
771790 return syncResult . Result ;
772791 }
773792
793+ Debug . Assert ( this . readCompletionSource == null || this . readCompletionSource . Task == asyncResult ) ;
774794 Contract . Assert ( ! ( ( Task < int > ) asyncResult ) . IsCanceled ) ;
775795
776796 try
@@ -782,12 +802,6 @@ public override int EndRead(IAsyncResult asyncResult)
782802 ExceptionDispatchInfo . Capture ( ex . InnerException ) . Throw ( ) ;
783803 throw ; // unreachable
784804 }
785- finally
786- {
787- this . readCompletionSource = null ;
788- this . readCallback = null ;
789- this . sslOwnedBuffer = default ( ArraySegment < byte > ) ;
790- }
791805 }
792806
793807 IAsyncResult PrepareSyncReadResult ( int readBytes , object state )
@@ -817,10 +831,11 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
817831 // write+flush completed synchronously (and successfully)
818832 var result = new SynchronousAsyncResult < int > ( ) ;
819833 result . AsyncState = state ;
820- callback ( result ) ;
834+ callback ? . Invoke ( result ) ;
821835 return result ;
822836 default :
823837 this . writeCallback = callback ;
838+ Contract . Assert ( this . writeCompletion == null ) ;
824839 var tcs = new TaskCompletionSource ( state ) ;
825840 this . writeCompletion = tcs ;
826841 task . ContinueWith ( WriteCompleteCallback , this , TaskContinuationOptions . ExecuteSynchronously ) ;
@@ -831,34 +846,39 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
831846 static void HandleChannelWriteComplete ( Task writeTask , object state )
832847 {
833848 var self = ( MediationStream ) state ;
849+
850+ AsyncCallback callback = self . writeCallback ;
851+ self . writeCallback = null ;
852+
853+ var promise = self . writeCompletion ;
854+ self . writeCompletion = null ;
855+
834856 switch ( writeTask . Status )
835857 {
836858 case TaskStatus . RanToCompletion :
837- self . writeCompletion . TryComplete ( ) ;
859+ promise . TryComplete ( ) ;
838860 break ;
839861 case TaskStatus . Canceled :
840- self . writeCompletion . TrySetCanceled ( ) ;
862+ promise . TrySetCanceled ( ) ;
841863 break ;
842864 case TaskStatus . Faulted :
843- self . writeCompletion . TrySetException ( writeTask . Exception ) ;
865+ promise . TrySetException ( writeTask . Exception ) ;
844866 break ;
845867 default :
846868 throw new ArgumentOutOfRangeException ( "Unexpected task status: " + writeTask . Status ) ;
847869 }
848870
849- self . writeCallback ? . Invoke ( self . writeCompletion . Task ) ;
871+ callback ? . Invoke ( promise . Task ) ;
850872 }
851873
852874 public override void EndWrite ( IAsyncResult asyncResult )
853875 {
854- this . writeCallback = null ;
855- this . writeCompletion = null ;
856-
857876 if ( asyncResult is SynchronousAsyncResult < int > )
858877 {
859878 return ;
860879 }
861880
881+ Debug . Assert ( this . writeCompletion == null || this . writeCompletion . Task == asyncResult ) ;
862882 try
863883 {
864884 ( ( Task < int > ) asyncResult ) . Wait ( ) ;
@@ -876,7 +896,7 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa
876896 Contract . Assert ( destination != null ) ;
877897
878898 byte [ ] source = this . input ;
879- int readableBytes = this . inputLength - this . inputOffset ;
899+ int readableBytes = this . SourceReadableBytes ;
880900 int length = Math . Min ( readableBytes , destinationCapacity ) ;
881901 Buffer . BlockCopy ( source , this . inputStartOffset + this . inputOffset , destination , destinationOffset , length ) ;
882902 this . inputOffset += length ;
0 commit comments