Skip to content

Commit b335a61

Browse files
committed
fix: write i/o timeout on ssh connection
1 parent 6cd984d commit b335a61

File tree

3 files changed

+30
-6
lines changed

3 files changed

+30
-6
lines changed

client.go

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -253,9 +253,6 @@ func (c *Client) workHandler() {
253253
}
254254

255255
func (c *Client) process(data string) {
256-
if err := c.conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
257-
c.err <- err
258-
}
259256
if _, err := c.conn.Write([]byte(data)); err != nil {
260257
c.err <- err
261258
}

client_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -128,7 +128,7 @@ func TestClientWriteFail(t *testing.T) {
128128
if !assert.NoError(t, err) {
129129
return
130130
}
131-
assert.NoError(t, c.conn.(*legacyConnection).Conn.(*net.TCPConn).CloseWrite())
131+
assert.NoError(t, c.conn.(*legacyConnection).Conn.(*writeTimeoutConn).Conn.(*net.TCPConn).CloseWrite())
132132

133133
_, err = c.Exec("version")
134134
assert.Error(t, err)

connection.go

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,21 @@ const (
2020
DefaultSSHPort = 10022
2121
)
2222

23+
type writeTimeoutConn struct {
24+
net.Conn
25+
timeout time.Duration
26+
}
27+
28+
func (c *writeTimeoutConn) Write(p []byte) (n int, err error) {
29+
if err = c.Conn.SetWriteDeadline(time.Now().Add(c.timeout)); err != nil {
30+
return 0, fmt.Errorf("writeTimeoutConn: SetWriteDeadline: %w", err)
31+
}
32+
if n, err = c.Conn.Write(p); err != nil {
33+
return n, fmt.Errorf("writeTimeoutConn: write: %w", err)
34+
}
35+
return n, nil
36+
}
37+
2338
// legacyConnection is an insecure TCP connection.
2439
type legacyConnection struct {
2540
net.Conn
@@ -32,10 +47,16 @@ func (c *legacyConnection) Connect(addr string, timeout time.Duration) error {
3247
return err
3348
}
3449

35-
c.Conn, err = net.DialTimeout("tcp", addr, timeout)
50+
conn, err := net.DialTimeout("tcp", addr, timeout)
3651
if err != nil {
3752
return fmt.Errorf("legacy connection: dial: %w", err)
3853
}
54+
55+
c.Conn = &writeTimeoutConn{
56+
Conn: conn,
57+
timeout: timeout,
58+
}
59+
3960
return nil
4061
}
4162

@@ -53,10 +74,16 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error {
5374
return err
5475
}
5576

56-
if c.Conn, err = net.DialTimeout("tcp", addr, timeout); err != nil {
77+
conn, err := net.DialTimeout("tcp", addr, timeout)
78+
if err != nil {
5779
return fmt.Errorf("ssh connection: dial: %w", err)
5880
}
5981

82+
c.Conn = &writeTimeoutConn{
83+
Conn: conn,
84+
timeout: timeout,
85+
}
86+
6087
clientConn, chans, reqs, err := ssh.NewClientConn(c.Conn, addr, c.config)
6188
if err != nil {
6289
return fmt.Errorf("ssh connecion: ssh client conn: %w", err)

0 commit comments

Comments
 (0)