Skip to content

Commit d76d893

Browse files
committed
Improve write structure
1 parent 0ed9c74 commit d76d893

File tree

1 file changed

+50
-42
lines changed

1 file changed

+50
-42
lines changed

websocket.go

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -286,7 +286,7 @@ func (c *Conn) handleControl(ctx context.Context, h header) error {
286286
c.Close(StatusProtocolError, "received invalid close payload")
287287
return xerrors.Errorf("received invalid close payload: %w", err)
288288
}
289-
c.writeClose(b, ce, false)
289+
c.writeClose(b, xerrors.Errorf("received close frame: %w", ce))
290290
return c.closeErr
291291
default:
292292
panic(fmt.Sprintf("websocket: unexpected control opcode: %#v", h))
@@ -644,38 +644,54 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
644644
case c.setWriteTimeout <- ctx:
645645
}
646646

647-
writeErr := func(err error) error {
648-
select {
649-
case <-c.closed:
650-
return c.closeErr
651-
case <-ctx.Done():
652-
err = ctx.Err()
653-
default:
654-
}
655-
656-
err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err)
657-
// We need to release the lock first before closing the connection to ensure
658-
// the lock can be acquired inside close to ensure no one can access c.bw.
659-
c.releaseLock(c.writeFrameLock)
660-
c.close(err)
647+
n, err := c.realWriteFrame(ctx, h, p)
648+
if err != nil {
649+
return n, err
650+
}
661651

662-
return err
652+
// We already finished writing, no need to potentially brick the connection if
653+
// the context expires.
654+
select {
655+
case <-c.closed:
656+
return n, c.closeErr
657+
case c.setWriteTimeout <- context.Background():
663658
}
664659

660+
return n, nil
661+
}
662+
663+
func (c *Conn) realWriteFrame(ctx context.Context, h header, p []byte) (n int, err error){
664+
defer func() {
665+
if err != nil {
666+
select {
667+
case <-c.closed:
668+
err = c.closeErr
669+
case <-ctx.Done():
670+
err = ctx.Err()
671+
default:
672+
}
673+
674+
err = xerrors.Errorf("failed to write %v frame: %w", h.opcode, err)
675+
// We need to release the lock first before closing the connection to ensure
676+
// the lock can be acquired inside close to ensure no one can access c.bw.
677+
c.releaseLock(c.writeFrameLock)
678+
c.close(err)
679+
}
680+
}()
681+
665682
headerBytes := writeHeader(c.writeHeaderBuf, h)
666683
_, err = c.bw.Write(headerBytes)
667684
if err != nil {
668-
return 0, writeErr(err)
685+
return 0, err
669686
}
670687

671-
var n int
672688
if c.client {
673689
var keypos int
674690
for len(p) > 0 {
675691
if c.bw.Available() == 0 {
676692
err = c.bw.Flush()
677693
if err != nil {
678-
return n, writeErr(err)
694+
return n, err
679695
}
680696
}
681697

@@ -689,7 +705,7 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
689705

690706
n2, err := c.bw.Write(p2)
691707
if err != nil {
692-
return n, writeErr(err)
708+
return n, err
693709
}
694710

695711
keypos = fastXOR(h.maskKey, keypos, c.writeBuf[i:i+n2])
@@ -700,25 +716,17 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, opcode opcode, p []byte
700716
} else {
701717
n, err = c.bw.Write(p)
702718
if err != nil {
703-
return n, writeErr(err)
719+
return n, err
704720
}
705721
}
706722

707-
if fin {
723+
if h.fin {
708724
err = c.bw.Flush()
709725
if err != nil {
710-
return n, writeErr(err)
726+
return n, err
711727
}
712728
}
713729

714-
// We already finished writing, no need to potentially brick the connection if
715-
// the context expires.
716-
select {
717-
case <-c.closed:
718-
return n, c.closeErr
719-
case c.setWriteTimeout <- context.Background():
720-
}
721-
722730
return n, nil
723731
}
724732

@@ -767,10 +775,19 @@ func (c *Conn) exportedClose(code StatusCode, reason string) error {
767775
p, _ = ce.bytes()
768776
}
769777

770-
return c.writeClose(p, ce, true)
778+
err = c.writeClose(p, xerrors.Errorf("sent close frame: %w", ce))
779+
if err != nil {
780+
return err
781+
}
782+
783+
if !xerrors.Is(c.closeErr, ce) {
784+
return c.closeErr
785+
}
786+
787+
return nil
771788
}
772789

773-
func (c *Conn) writeClose(p []byte, cerr error, us bool) error {
790+
func (c *Conn) writeClose(p []byte, cerr error) error {
774791
ctx, cancel := context.WithTimeout(context.Background(), time.Second*5)
775792
defer cancel()
776793

@@ -780,16 +797,7 @@ func (c *Conn) writeClose(p []byte, cerr error, us bool) error {
780797
return err
781798
}
782799

783-
if us {
784-
cerr = xerrors.Errorf("sent close frame: %w", cerr)
785-
} else {
786-
cerr = xerrors.Errorf("received close frame: %w", cerr)
787-
}
788-
789800
c.close(cerr)
790-
if !xerrors.Is(c.closeErr, cerr) {
791-
return c.closeErr
792-
}
793801

794802
return nil
795803
}

0 commit comments

Comments
 (0)