@@ -26,7 +26,7 @@ const (
26
26
connOpenInOrder
27
27
)
28
28
29
- func dial (secure , skipVerify bool , hosts []string , noDelay bool , openStrategy openStrategy , logf func (string , ... interface {})) (* connect , error ) {
29
+ func dial (secure , skipVerify bool , hosts []string , readTimeout , writeTimeout time. Duration , noDelay bool , openStrategy openStrategy , logf func (string , ... interface {})) (* connect , error ) {
30
30
var (
31
31
err error
32
32
abs = func (v int ) int {
@@ -66,10 +66,12 @@ func dial(secure, skipVerify bool, hosts []string, noDelay bool, openStrategy op
66
66
tcp .SetNoDelay (noDelay ) // Disable or enable the Nagle Algorithm for this tcp socket
67
67
}
68
68
return & connect {
69
- Conn : conn ,
70
- logf : logf ,
71
- ident : ident ,
72
- buffer : bufio .NewReaderSize (conn , 4 * 1024 * 1024 ),
69
+ Conn : conn ,
70
+ logf : logf ,
71
+ ident : ident ,
72
+ buffer : bufio .NewReaderSize (conn , 4 * 1024 * 1024 ),
73
+ readTimeout : readTimeout ,
74
+ writeTimeout : writeTimeout ,
73
75
}, nil
74
76
}
75
77
}
@@ -78,10 +80,14 @@ func dial(secure, skipVerify bool, hosts []string, noDelay bool, openStrategy op
78
80
79
81
type connect struct {
80
82
net.Conn
81
- logf func (string , ... interface {})
82
- ident int
83
- buffer * bufio.Reader
84
- closed bool
83
+ logf func (string , ... interface {})
84
+ ident int
85
+ buffer * bufio.Reader
86
+ closed bool
87
+ readTimeout time.Duration
88
+ writeTimeout time.Duration
89
+ lastReadDeadlineTime time.Time
90
+ lastWriteDeadlineTime time.Time
85
91
}
86
92
87
93
func (conn * connect ) Read (b []byte ) (int , error ) {
@@ -91,10 +97,14 @@ func (conn *connect) Read(b []byte) (int, error) {
91
97
total int
92
98
dstLen = len (b )
93
99
)
100
+ if currentTime := now (); conn .readTimeout != 0 && currentTime .Sub (conn .lastReadDeadlineTime ) > (conn .readTimeout >> 2 ) {
101
+ conn .SetReadDeadline (time .Now ().Add (conn .readTimeout ))
102
+ conn .lastReadDeadlineTime = currentTime
103
+ }
94
104
for total < dstLen {
95
105
if n , err = conn .buffer .Read (b [total :]); err != nil {
96
106
conn .logf ("[connect] read error: %v" , err )
97
- conn .closed = true
107
+ conn .Close ()
98
108
return n , driver .ErrBadConn
99
109
}
100
110
total += n
@@ -109,10 +119,14 @@ func (conn *connect) Write(b []byte) (int, error) {
109
119
total int
110
120
srcLen = len (b )
111
121
)
122
+ if currentTime := now (); conn .writeTimeout != 0 && currentTime .Sub (conn .lastWriteDeadlineTime ) > (conn .writeTimeout >> 2 ) {
123
+ conn .SetWriteDeadline (time .Now ().Add (conn .writeTimeout ))
124
+ conn .lastWriteDeadlineTime = currentTime
125
+ }
112
126
for total < srcLen {
113
127
if n , err = conn .Conn .Write (b [total :]); err != nil {
114
128
conn .logf ("[connect] write error: %v" , err )
115
- conn .closed = true
129
+ conn .Close ()
116
130
return n , driver .ErrBadConn
117
131
}
118
132
total += n
0 commit comments