@@ -61,8 +61,6 @@ type Conn struct {
61
61
}
62
62
63
63
func (c * Conn ) close (err error ) {
64
- err = xerrors .Errorf ("websocket closed: %w" , err )
65
-
66
64
c .closeOnce .Do (func () {
67
65
runtime .SetFinalizer (c , nil )
68
66
@@ -71,7 +69,7 @@ func (c *Conn) close(err error) {
71
69
cerr = err
72
70
}
73
71
74
- c .closeErr = cerr
72
+ c .closeErr = xerrors . Errorf ( "websocket closed: %w" , cerr )
75
73
76
74
close (c .closed )
77
75
})
@@ -98,7 +96,7 @@ func (c *Conn) init() {
98
96
c .readDone = make (chan int )
99
97
100
98
runtime .SetFinalizer (c , func (c * Conn ) {
101
- c .Close ( StatusInternalError , "connection garbage collected" )
99
+ c .close ( xerrors . New ( "connection garbage collected" ) )
102
100
})
103
101
104
102
go c .writeLoop ()
@@ -238,7 +236,7 @@ func (c *Conn) handleControl(h header) {
238
236
case opClose :
239
237
ce , err := parseClosePayload (b )
240
238
if err != nil {
241
- c .close (xerrors .Errorf ("read invalid close payload: %w" , err ))
239
+ c .close (xerrors .Errorf ("received invalid close payload: %w" , err ))
242
240
return
243
241
}
244
242
if ce .Code == StatusNoStatusRcvd {
@@ -302,7 +300,7 @@ func (c *Conn) readLoop() {
302
300
}
303
301
}
304
302
305
- func (c * Conn ) dataReadLoop (h header ) ( err error ) {
303
+ func (c * Conn ) dataReadLoop (h header ) error {
306
304
maskPos := 0
307
305
left := h .payloadLength
308
306
firstReadDone := false
@@ -355,7 +353,6 @@ func (c *Conn) writePong(p []byte) error {
355
353
356
354
// Close closes the WebSocket connection with the given status code and reason.
357
355
// It will write a WebSocket close frame with a timeout of 5 seconds.
358
- // Concurrent calls to Close are ok.
359
356
func (c * Conn ) Close (code StatusCode , reason string ) error {
360
357
err := c .exportedClose (code , reason )
361
358
if err != nil {
@@ -400,7 +397,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
400
397
return err
401
398
}
402
399
403
- if cerr != c .closeErr {
400
+ if ! xerrors . Is ( c .closeErr , cerr ) {
404
401
return c .closeErr
405
402
}
406
403
@@ -420,9 +417,8 @@ func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) er
420
417
payload : p ,
421
418
}:
422
419
case <- ctx .Done ():
423
- err := xerrors .Errorf ("control frame write timed out: %w" , ctx .Err ())
424
- c .close (err )
425
- return err
420
+ c .close (xerrors .Errorf ("control frame write timed out: %w" , ctx .Err ()))
421
+ return ctx .Err ()
426
422
}
427
423
428
424
select {
@@ -487,7 +483,7 @@ func (w messageWriter) write(p []byte) (int, error) {
487
483
select {
488
484
case <- w .ctx .Done ():
489
485
w .c .close (xerrors .Errorf ("data write timed out: %w" , w .ctx .Err ()))
490
- // Wait for writeLoop to complete so we know p is done.
486
+ // Wait for writeLoop to complete so we know p is done with .
491
487
<- w .c .writeDone
492
488
return 0 , w .ctx .Err ()
493
489
case _ , ok := <- w .c .writeDone :
@@ -542,25 +538,21 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
542
538
}
543
539
544
540
func (c * Conn ) reader (ctx context.Context ) (MessageType , io.Reader , error ) {
545
- for ! atomic .CompareAndSwapInt64 (& c .activeReader , 0 , 1 ) {
546
- select {
547
- case <- c .closed :
548
- return 0 , nil , c .closeErr
549
- case c .readBytes <- nil :
550
- select {
551
- case <- ctx .Done ():
552
- return 0 , nil , ctx .Err ()
553
- case _ , ok := <- c .readDone :
554
- if ! ok {
555
- return 0 , nil , c .closeErr
556
- }
557
- if atomic .LoadInt64 (& c .activeReader ) == 1 {
558
- return 0 , nil , xerrors .New ("previous message not fully read" )
559
- }
560
- }
561
- case <- ctx .Done ():
562
- return 0 , nil , ctx .Err ()
541
+ if ! atomic .CompareAndSwapInt64 (& c .activeReader , 0 , 1 ) {
542
+ // If the next read yields io.EOF we are good to go.
543
+ r := messageReader {
544
+ ctx : ctx ,
545
+ c : c ,
563
546
}
547
+ _ , err := r .Read (nil )
548
+ if err == nil {
549
+ return 0 , nil , xerrors .New ("previous message not fully read" )
550
+ }
551
+ if ! xerrors .Is (err , io .EOF ) {
552
+ return 0 , nil , xerrors .Errorf ("failed to check if last message at io.EOF: %w" , err )
553
+ }
554
+
555
+ atomic .StoreInt64 (& c .activeReader , 1 )
564
556
}
565
557
566
558
select {
@@ -586,7 +578,8 @@ type messageReader struct {
586
578
func (r messageReader ) Read (p []byte ) (int , error ) {
587
579
n , err := r .read (p )
588
580
if err != nil {
589
- // Have to return io.EOF directly for now, cannot wrap.
581
+ // Have to return io.EOF directly for now, we cannot wrap as xerrors
582
+ // isn't used in stdlib.
590
583
if err == io .EOF {
591
584
return n , io .EOF
592
585
}
0 commit comments