@@ -5,6 +5,7 @@ namespace DotNetty.Handlers.Tls
55{
66 using System ;
77 using System . Collections . Generic ;
8+ using System . Diagnostics ;
89 using System . Diagnostics . Contracts ;
910 using System . IO ;
1011 using System . Net . Security ;
@@ -41,7 +42,7 @@ public sealed class TlsHandler : ByteToMessageDecoder
4142 Task < int > pendingSslStreamReadFuture ;
4243
4344 public TlsHandler ( TlsSettings settings )
44- : this ( stream => new SslStream ( stream , false ) , settings )
45+ : this ( stream => new SslStream ( stream , true ) , settings )
4546 {
4647 }
4748
@@ -69,8 +70,6 @@ public TlsHandler(Func<Stream, SslStream> sslStreamFactory, TlsSettings settings
6970
7071 bool IsServer => this . settings is ServerTlsSettings ;
7172
72- public void Dispose ( ) => this . sslStream ? . Dispose ( ) ;
73-
7473 public override void ChannelActive ( IChannelHandlerContext context )
7574 {
7675 base . ChannelActive ( context ) ;
@@ -344,6 +343,9 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
344343
345344 outputBuffer = this . pendingSslStreamReadBuffer ;
346345 outputBufferLength = outputBuffer . WritableBytes ;
346+
347+ this . pendingSslStreamReadFuture = null ;
348+ this . pendingSslStreamReadBuffer = null ;
347349 }
348350 else
349351 {
@@ -363,17 +365,23 @@ void Unwrap(IChannelHandlerContext ctx, IByteBuffer packet, int offset, int leng
363365 if ( ! currentReadFuture . IsCompleted )
364366 {
365367 // we did feed the whole current packet to SslStream yet it did not produce any result -> move to the next packet in input
366- Contract . Assert ( this . mediationStream . SourceReadableBytes == 0 ) ;
367368
368369 continue ;
369370 }
370371
371372 int read = currentReadFuture . Result ;
372373
374+ if ( read == 0 )
375+ {
376+ //Stream closed
377+ return ;
378+ }
379+
373380 // Now output the result of previous read and decide whether to do an extra read on the same source or move forward
374381 AddBufferToOutput ( outputBuffer , read , output ) ;
375382
376383 currentReadFuture = null ;
384+ outputBuffer = null ;
377385 if ( this . mediationStream . SourceReadableBytes == 0 )
378386 {
379387 // we just made a frame available for reading but there was already pending read so SslStream read it out to make further progress there
@@ -620,6 +628,7 @@ void HandleFailure(Exception cause)
620628 // Release all resources such as internal buffers that SSLEngine
621629 // is managing.
622630
631+ this . mediationStream . Dispose ( ) ;
623632 try
624633 {
625634 this . sslStream . Dispose ( ) ;
@@ -701,14 +710,13 @@ public void ExpandSource(int count)
701710
702711 this . inputLength += count ;
703712
704- TaskCompletionSource < int > promise = this . readCompletionSource ;
705- if ( promise == null )
713+ ArraySegment < byte > sslBuffer = this . sslOwnedBuffer ;
714+ if ( sslBuffer . Array == null )
706715 {
707716 // there is no pending read operation - keep for future
708717 return ;
709718 }
710-
711- ArraySegment < byte > sslBuffer = this . sslOwnedBuffer ;
719+ this . sslOwnedBuffer = default ( ArraySegment < byte > ) ;
712720
713721#if NETSTANDARD1_3
714722 this . readByteCount = this . ReadFromInput ( sslBuffer . Array , sslBuffer . Offset , sslBuffer . Count ) ;
@@ -718,29 +726,35 @@ public void ExpandSource(int count)
718726 {
719727 var self = ( MediationStream ) ms ;
720728 TaskCompletionSource < int > p = self . readCompletionSource ;
721- this . readCompletionSource = null ;
729+ self . readCompletionSource = null ;
722730 p . TrySetResult ( self . readByteCount ) ;
723731 } ,
724732 this )
725733 . RunSynchronously ( TaskScheduler . Default ) ;
726734#else
727735 int read = this . ReadFromInput ( sslBuffer . Array , sslBuffer . Offset , sslBuffer . Count ) ;
736+
737+ TaskCompletionSource < int > promise = this . readCompletionSource ;
728738 this . readCompletionSource = null ;
729739 promise . TrySetResult ( read ) ;
730- this . readCallback ? . Invoke ( promise . Task ) ;
740+
741+ AsyncCallback callback = this . readCallback ;
742+ this . readCallback = null ;
743+ callback ? . Invoke ( promise . Task ) ;
731744#endif
732745 }
733746
734747#if NETSTANDARD1_3
735748 public override Task < int > ReadAsync ( byte [ ] buffer , int offset , int count , CancellationToken cancellationToken )
736749 {
737- if ( this . inputLength - this . inputOffset > 0 )
750+ if ( this . SourceReadableBytes > 0 )
738751 {
739752 // we have the bytes available upfront - write out synchronously
740753 int read = this . ReadFromInput ( buffer , offset , count ) ;
741754 return Task . FromResult ( read ) ;
742755 }
743756
757+ Contract . Assert ( this . sslOwnedBuffer . Array == null ) ;
744758 // take note of buffer - we will pass bytes there once available
745759 this . sslOwnedBuffer = new ArraySegment < byte > ( buffer , offset , count ) ;
746760 this . readCompletionSource = new TaskCompletionSource < int > ( ) ;
@@ -749,13 +763,16 @@ public override Task<int> ReadAsync(byte[] buffer, int offset, int count, Cancel
749763#else
750764 public override IAsyncResult BeginRead ( byte [ ] buffer , int offset , int count , AsyncCallback callback , object state )
751765 {
752- if ( this . inputLength - this . inputOffset > 0 )
766+ if ( this . SourceReadableBytes > 0 )
753767 {
754768 // we have the bytes available upfront - write out synchronously
755769 int read = this . ReadFromInput ( buffer , offset , count ) ;
756- return this . PrepareSyncReadResult ( read , state ) ;
770+ var res = this . PrepareSyncReadResult ( read , state ) ;
771+ callback ? . Invoke ( res ) ;
772+ return res ;
757773 }
758774
775+ Contract . Assert ( this . sslOwnedBuffer . Array == null ) ;
759776 // take note of buffer - we will pass bytes there once available
760777 this . sslOwnedBuffer = new ArraySegment < byte > ( buffer , offset , count ) ;
761778 this . readCompletionSource = new TaskCompletionSource < int > ( state ) ;
@@ -771,6 +788,7 @@ public override int EndRead(IAsyncResult asyncResult)
771788 return syncResult . Result ;
772789 }
773790
791+ Debug . Assert ( this . readCompletionSource == null || this . readCompletionSource . Task == asyncResult ) ;
774792 Contract . Assert ( ! ( ( Task < int > ) asyncResult ) . IsCanceled ) ;
775793
776794 try
@@ -782,12 +800,6 @@ public override int EndRead(IAsyncResult asyncResult)
782800 ExceptionDispatchInfo . Capture ( ex . InnerException ) . Throw ( ) ;
783801 throw ; // unreachable
784802 }
785- finally
786- {
787- this . readCompletionSource = null ;
788- this . readCallback = null ;
789- this . sslOwnedBuffer = default ( ArraySegment < byte > ) ;
790- }
791803 }
792804
793805 IAsyncResult PrepareSyncReadResult ( int readBytes , object state )
@@ -817,51 +829,63 @@ public override IAsyncResult BeginWrite(byte[] buffer, int offset, int count, As
817829 // write+flush completed synchronously (and successfully)
818830 var result = new SynchronousAsyncResult < int > ( ) ;
819831 result . AsyncState = state ;
820- callback ( result ) ;
832+ callback ? . Invoke ( result ) ;
821833 return result ;
822834 default :
823- this . writeCallback = callback ;
824- var tcs = new TaskCompletionSource ( state ) ;
825- this . writeCompletion = tcs ;
826- task . ContinueWith ( WriteCompleteCallback , this , TaskContinuationOptions . ExecuteSynchronously ) ;
827- return tcs . Task ;
835+ if ( callback != null || state != task . AsyncState )
836+ {
837+ Contract . Assert ( this . writeCompletion == null ) ;
838+ this . writeCallback = callback ;
839+ var tcs = new TaskCompletionSource ( state ) ;
840+ this . writeCompletion = tcs ;
841+ task . ContinueWith ( WriteCompleteCallback , this , TaskContinuationOptions . ExecuteSynchronously ) ;
842+ return tcs . Task ;
843+ }
844+ else
845+ {
846+ return task ;
847+ }
828848 }
829849 }
830850
831851 static void HandleChannelWriteComplete ( Task writeTask , object state )
832852 {
833853 var self = ( MediationStream ) state ;
854+
855+ AsyncCallback callback = self . writeCallback ;
856+ self . writeCallback = null ;
857+
858+ var promise = self . writeCompletion ;
859+ self . writeCompletion = null ;
860+
834861 switch ( writeTask . Status )
835862 {
836863 case TaskStatus . RanToCompletion :
837- self . writeCompletion . TryComplete ( ) ;
864+ promise . TryComplete ( ) ;
838865 break ;
839866 case TaskStatus . Canceled :
840- self . writeCompletion . TrySetCanceled ( ) ;
867+ promise . TrySetCanceled ( ) ;
841868 break ;
842869 case TaskStatus . Faulted :
843- self . writeCompletion . TrySetException ( writeTask . Exception ) ;
870+ promise . TrySetException ( writeTask . Exception ) ;
844871 break ;
845872 default :
846873 throw new ArgumentOutOfRangeException ( "Unexpected task status: " + writeTask . Status ) ;
847874 }
848875
849- self . writeCallback ? . Invoke ( self . writeCompletion . Task ) ;
876+ callback ? . Invoke ( promise . Task ) ;
850877 }
851878
852879 public override void EndWrite ( IAsyncResult asyncResult )
853880 {
854- this . writeCallback = null ;
855- this . writeCompletion = null ;
856-
857881 if ( asyncResult is SynchronousAsyncResult < int > )
858882 {
859883 return ;
860884 }
861885
862886 try
863887 {
864- ( ( Task < int > ) asyncResult ) . Wait ( ) ;
888+ ( ( Task ) asyncResult ) . Wait ( ) ;
865889 }
866890 catch ( AggregateException ex )
867891 {
@@ -876,7 +900,7 @@ int ReadFromInput(byte[] destination, int destinationOffset, int destinationCapa
876900 Contract . Assert ( destination != null ) ;
877901
878902 byte [ ] source = this . input ;
879- int readableBytes = this . inputLength - this . inputOffset ;
903+ int readableBytes = this . SourceReadableBytes ;
880904 int length = Math . Min ( readableBytes , destinationCapacity ) ;
881905 Buffer . BlockCopy ( source , this . inputStartOffset + this . inputOffset , destination , destinationOffset , length ) ;
882906 this . inputOffset += length ;
@@ -894,8 +918,11 @@ protected override void Dispose(bool disposing)
894918 if ( disposing )
895919 {
896920 TaskCompletionSource < int > p = this . readCompletionSource ;
897- this . readCompletionSource = null ;
898- p ? . TrySetResult ( 0 ) ;
921+ if ( p != null )
922+ {
923+ this . readCompletionSource = null ;
924+ p . TrySetResult ( 0 ) ;
925+ }
899926 }
900927 }
901928
0 commit comments