@@ -286,7 +286,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
286
286
c .Close (StatusProtocolError , "received invalid close payload" )
287
287
return xerrors .Errorf ("received invalid close payload: %w" , err )
288
288
}
289
- c .writeClose (b , ce , false )
289
+ c .writeClose (b , xerrors . Errorf ( "received close frame: %w" , ce ) )
290
290
return c .closeErr
291
291
default :
292
292
panic (fmt .Sprintf ("websocket: unexpected control opcode: %#v" , h ))
@@ -644,38 +644,54 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
644
644
case c .setWriteTimeout <- ctx :
645
645
}
646
646
647
- writeErr := func (err error ) error {
648
- select {
649
- case <- c .closed :
650
- return c .closeErr
651
- case <- ctx .Done ():
652
- err = ctx .Err ()
653
- default :
654
- }
655
-
656
- err = xerrors .Errorf ("failed to write %v frame: %w" , h .opcode , err )
657
- // We need to release the lock first before closing the connection to ensure
658
- // the lock can be acquired inside close to ensure no one can access c.bw.
659
- c .releaseLock (c .writeFrameLock )
660
- c .close (err )
647
+ n , err := c .realWriteFrame (ctx , h , p )
648
+ if err != nil {
649
+ return n , err
650
+ }
661
651
662
- return err
652
+ // We already finished writing, no need to potentially brick the connection if
653
+ // the context expires.
654
+ select {
655
+ case <- c .closed :
656
+ return n , c .closeErr
657
+ case c .setWriteTimeout <- context .Background ():
663
658
}
664
659
660
+ return n , nil
661
+ }
662
+
663
+ func (c * Conn ) realWriteFrame (ctx context.Context , h header , p []byte ) (n int , err error ){
664
+ defer func () {
665
+ if err != nil {
666
+ select {
667
+ case <- c .closed :
668
+ err = c .closeErr
669
+ case <- ctx .Done ():
670
+ err = ctx .Err ()
671
+ default :
672
+ }
673
+
674
+ err = xerrors .Errorf ("failed to write %v frame: %w" , h .opcode , err )
675
+ // We need to release the lock first before closing the connection to ensure
676
+ // the lock can be acquired inside close to ensure no one can access c.bw.
677
+ c .releaseLock (c .writeFrameLock )
678
+ c .close (err )
679
+ }
680
+ }()
681
+
665
682
headerBytes := writeHeader (c .writeHeaderBuf , h )
666
683
_ , err = c .bw .Write (headerBytes )
667
684
if err != nil {
668
- return 0 , writeErr ( err )
685
+ return 0 , err
669
686
}
670
687
671
- var n int
672
688
if c .client {
673
689
var keypos int
674
690
for len (p ) > 0 {
675
691
if c .bw .Available () == 0 {
676
692
err = c .bw .Flush ()
677
693
if err != nil {
678
- return n , writeErr ( err )
694
+ return n , err
679
695
}
680
696
}
681
697
@@ -689,7 +705,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
689
705
690
706
n2 , err := c .bw .Write (p2 )
691
707
if err != nil {
692
- return n , writeErr ( err )
708
+ return n , err
693
709
}
694
710
695
711
keypos = fastXOR (h .maskKey , keypos , c .writeBuf [i :i + n2 ])
@@ -700,25 +716,17 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
700
716
} else {
701
717
n , err = c .bw .Write (p )
702
718
if err != nil {
703
- return n , writeErr ( err )
719
+ return n , err
704
720
}
705
721
}
706
722
707
- if fin {
723
+ if h . fin {
708
724
err = c .bw .Flush ()
709
725
if err != nil {
710
- return n , writeErr ( err )
726
+ return n , err
711
727
}
712
728
}
713
729
714
- // We already finished writing, no need to potentially brick the connection if
715
- // the context expires.
716
- select {
717
- case <- c .closed :
718
- return n , c .closeErr
719
- case c .setWriteTimeout <- context .Background ():
720
- }
721
-
722
730
return n , nil
723
731
}
724
732
@@ -767,10 +775,19 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
767
775
p , _ = ce .bytes ()
768
776
}
769
777
770
- return c .writeClose (p , ce , true )
778
+ err = c .writeClose (p , xerrors .Errorf ("sent close frame: %w" , ce ))
779
+ if err != nil {
780
+ return err
781
+ }
782
+
783
+ if ! xerrors .Is (c .closeErr , ce ) {
784
+ return c .closeErr
785
+ }
786
+
787
+ return nil
771
788
}
772
789
773
- func (c * Conn ) writeClose (p []byte , cerr error , us bool ) error {
790
+ func (c * Conn ) writeClose (p []byte , cerr error ) error {
774
791
ctx , cancel := context .WithTimeout (context .Background (), time .Second * 5 )
775
792
defer cancel ()
776
793
@@ -780,16 +797,7 @@ func (c *Conn) writeClose(p []byte, cerr error, us bool) error {
780
797
return err
781
798
}
782
799
783
- if us {
784
- cerr = xerrors .Errorf ("sent close frame: %w" , cerr )
785
- } else {
786
- cerr = xerrors .Errorf ("received close frame: %w" , cerr )
787
- }
788
-
789
800
c .close (cerr )
790
- if ! xerrors .Is (c .closeErr , cerr ) {
791
- return c .closeErr
792
- }
793
801
794
802
return nil
795
803
}
0 commit comments