Skip to content

Commit 4aa8fd7

Browse files
committed
Cleanup websocket.go
1 parent 6b2e258 commit 4aa8fd7

File tree

1 file changed

+24
-31
lines changed

1 file changed

+24
-31
lines changed

websocket.go

Lines changed: 24 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,6 @@ type Conn struct {
6161
}
6262

6363
func (c *Conn) close(err error) {
64-
err = xerrors.Errorf("websocket closed: %w", err)
65-
6664
c.closeOnce.Do(func() {
6765
runtime.SetFinalizer(c, nil)
6866

@@ -71,7 +69,7 @@ func (c *Conn) close(err error) {
7169
cerr = err
7270
}
7371

74-
c.closeErr = cerr
72+
c.closeErr = xerrors.Errorf("websocket closed: %w", cerr)
7573

7674
close(c.closed)
7775
})
@@ -98,7 +96,7 @@ func (c *Conn) init() {
9896
c.readDone = make(chan int)
9997

10098
runtime.SetFinalizer(c, func(c *Conn) {
101-
c.Close(StatusInternalError, "connection garbage collected")
99+
c.close(xerrors.New("connection garbage collected"))
102100
})
103101

104102
go c.writeLoop()
@@ -238,7 +236,7 @@ func (c *Conn) handleControl(h header) {
238236
case opClose:
239237
ce, err := parseClosePayload(b)
240238
if err != nil {
241-
c.close(xerrors.Errorf("read invalid close payload: %w", err))
239+
c.close(xerrors.Errorf("received invalid close payload: %w", err))
242240
return
243241
}
244242
if ce.Code == StatusNoStatusRcvd {
@@ -302,7 +300,7 @@ func (c *Conn) readLoop() {
302300
}
303301
}
304302

305-
func (c *Conn) dataReadLoop(h header) (err error) {
303+
func (c *Conn) dataReadLoop(h header) error {
306304
maskPos := 0
307305
left := h.payloadLength
308306
firstReadDone := false
@@ -355,7 +353,6 @@ func (c *Conn) writePong(p []byte) error {
355353

356354
// Close closes the WebSocket connection with the given status code and reason.
357355
// It will write a WebSocket close frame with a timeout of 5 seconds.
358-
// Concurrent calls to Close are ok.
359356
func (c *Conn) Close(code StatusCode, reason string) error {
360357
err := c.exportedClose(code, reason)
361358
if err != nil {
@@ -400,7 +397,7 @@ func (c *Conn) writeClose(p []byte, cerr CloseError) error {
400397
return err
401398
}
402399

403-
if cerr != c.closeErr {
400+
if !xerrors.Is(c.closeErr, cerr) {
404401
return c.closeErr
405402
}
406403

@@ -420,9 +417,8 @@ func (c *Conn) writeSingleFrame(ctx context.Context, opcode opcode, p []byte) er
420417
payload: p,
421418
}:
422419
case <-ctx.Done():
423-
err := xerrors.Errorf("control frame write timed out: %w", ctx.Err())
424-
c.close(err)
425-
return err
420+
c.close(xerrors.Errorf("control frame write timed out: %w", ctx.Err()))
421+
return ctx.Err()
426422
}
427423

428424
select {
@@ -487,7 +483,7 @@ func (w messageWriter) write(p []byte) (int, error) {
487483
select {
488484
case <-w.ctx.Done():
489485
w.c.close(xerrors.Errorf("data write timed out: %w", w.ctx.Err()))
490-
// Wait for writeLoop to complete so we know p is done.
486+
// Wait for writeLoop to complete so we know p is done with.
491487
<-w.c.writeDone
492488
return 0, w.ctx.Err()
493489
case _, ok := <-w.c.writeDone:
@@ -542,25 +538,21 @@ func (c *Conn) Reader(ctx context.Context) (MessageType, io.Reader, error) {
542538
}
543539

544540
func (c *Conn) reader(ctx context.Context) (MessageType, io.Reader, error) {
545-
for !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
546-
select {
547-
case <-c.closed:
548-
return 0, nil, c.closeErr
549-
case c.readBytes <- nil:
550-
select {
551-
case <-ctx.Done():
552-
return 0, nil, ctx.Err()
553-
case _, ok := <-c.readDone:
554-
if !ok {
555-
return 0, nil, c.closeErr
556-
}
557-
if atomic.LoadInt64(&c.activeReader) == 1 {
558-
return 0, nil, xerrors.New("previous message not fully read")
559-
}
560-
}
561-
case <-ctx.Done():
562-
return 0, nil, ctx.Err()
541+
if !atomic.CompareAndSwapInt64(&c.activeReader, 0, 1) {
542+
// If the next read yields io.EOF we are good to go.
543+
r := messageReader{
544+
ctx: ctx,
545+
c: c,
563546
}
547+
_, err := r.Read(nil)
548+
if err == nil {
549+
return 0, nil, xerrors.New("previous message not fully read")
550+
}
551+
if !xerrors.Is(err, io.EOF) {
552+
return 0, nil, xerrors.Errorf("failed to check if last message at io.EOF: %w", err)
553+
}
554+
555+
atomic.StoreInt64(&c.activeReader, 1)
564556
}
565557

566558
select {
@@ -586,7 +578,8 @@ type messageReader struct {
586578
func (r messageReader) Read(p []byte) (int, error) {
587579
n, err := r.read(p)
588580
if err != nil {
589-
// Have to return io.EOF directly for now, cannot wrap.
581+
// Have to return io.EOF directly for now, we cannot wrap as xerrors
582+
// isn't used in stdlib.
590583
if err == io.EOF {
591584
return n, io.EOF
592585
}

0 commit comments

Comments
 (0)