@@ -20,6 +20,21 @@ const (
20
20
DefaultSSHPort = 10022
21
21
)
22
22
23
+ type writeTimeoutConn struct {
24
+ net.Conn
25
+ timeout time.Duration
26
+ }
27
+
28
+ func (c * writeTimeoutConn ) Write (p []byte ) (n int , err error ) {
29
+ if err = c .Conn .SetWriteDeadline (time .Now ().Add (c .timeout )); err != nil {
30
+ return 0 , fmt .Errorf ("writeTimeoutConn: SetWriteDeadline: %w" , err )
31
+ }
32
+ if n , err = c .Conn .Write (p ); err != nil {
33
+ return n , fmt .Errorf ("writeTimeoutConn: write: %w" , err )
34
+ }
35
+ return n , nil
36
+ }
37
+
23
38
// legacyConnection is an insecure TCP connection.
24
39
type legacyConnection struct {
25
40
net.Conn
@@ -32,10 +47,16 @@ func (c *legacyConnection) Connect(addr string, timeout time.Duration) error {
32
47
return err
33
48
}
34
49
35
- c . Conn , err = net .DialTimeout ("tcp" , addr , timeout )
50
+ conn , err : = net .DialTimeout ("tcp" , addr , timeout )
36
51
if err != nil {
37
52
return fmt .Errorf ("legacy connection: dial: %w" , err )
38
53
}
54
+
55
+ c .Conn = & writeTimeoutConn {
56
+ Conn : conn ,
57
+ timeout : timeout ,
58
+ }
59
+
39
60
return nil
40
61
}
41
62
@@ -53,10 +74,16 @@ func (c *sshConnection) Connect(addr string, timeout time.Duration) error {
53
74
return err
54
75
}
55
76
56
- if c .Conn , err = net .DialTimeout ("tcp" , addr , timeout ); err != nil {
77
+ conn , err := net .DialTimeout ("tcp" , addr , timeout )
78
+ if err != nil {
57
79
return fmt .Errorf ("ssh connection: dial: %w" , err )
58
80
}
59
81
82
+ c .Conn = & writeTimeoutConn {
83
+ Conn : conn ,
84
+ timeout : timeout ,
85
+ }
86
+
60
87
clientConn , chans , reqs , err := ssh .NewClientConn (c .Conn , addr , c .config )
61
88
if err != nil {
62
89
return fmt .Errorf ("ssh connecion: ssh client conn: %w" , err )
0 commit comments