@@ -21,13 +21,26 @@ const (
2121 MessageRequest = 1
2222 MessageResponse = 2
2323 MessageFinish = 3
24+ MessageTimeout = 4
2425)
2526
27+ type PacketResponse struct {
28+ Packet * ber.Packet
29+ Error error
30+ }
31+
32+ func (pr * PacketResponse ) ReadPacket () (* ber.Packet , error ) {
33+ if (pr == nil ) || (pr .Packet == nil && pr .Error == nil ) {
34+ return nil , NewError (ErrorNetwork , errors .New ("ldap: could not retrieve response" ))
35+ }
36+ return pr .Packet , pr .Error
37+ }
38+
2639type messagePacket struct {
2740 Op int
2841 MessageID int64
2942 Packet * ber.Packet
30- Channel chan * ber. Packet
43+ Channel chan * PacketResponse
3144}
3245
3346type sendMessageFlags uint
@@ -44,14 +57,15 @@ type Conn struct {
4457 isStartingTLS bool
4558 Debug debugging
4659 chanConfirm chan bool
47- chanResults map [int64 ]chan * ber. Packet
60+ chanResults map [int64 ]chan * PacketResponse
4861 chanMessage chan * messagePacket
4962 chanMessageID chan int64
5063 wgSender sync.WaitGroup
5164 wgClose sync.WaitGroup
5265 once sync.Once
5366 outstandingRequests uint
5467 messageMutex sync.Mutex
68+ requestTimeout time.Duration
5569}
5670
5771var _ Client = & Conn {}
@@ -97,12 +111,13 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
97111// NewConn returns a new Conn using conn for network I/O.
98112func NewConn (conn net.Conn , isTLS bool ) * Conn {
99113 return & Conn {
100- conn : conn ,
101- chanConfirm : make (chan bool ),
102- chanMessageID : make (chan int64 ),
103- chanMessage : make (chan * messagePacket , 10 ),
104- chanResults : map [int64 ]chan * ber.Packet {},
105- isTLS : isTLS ,
114+ conn : conn ,
115+ chanConfirm : make (chan bool ),
116+ chanMessageID : make (chan int64 ),
117+ chanMessage : make (chan * messagePacket , 10 ),
118+ chanResults : map [int64 ]chan * PacketResponse {},
119+ requestTimeout : 0 ,
120+ isTLS : isTLS ,
106121 }
107122}
108123
@@ -133,6 +148,13 @@ func (l *Conn) Close() {
133148 l .wgClose .Wait ()
134149}
135150
151+ // Sets the time after a request is sent that a MessageTimeout triggers
152+ func (l * Conn ) SetTimeout (timeout time.Duration ) {
153+ if timeout > 0 {
154+ l .requestTimeout = timeout
155+ }
156+ }
157+
136158// Returns the next available messageID
137159func (l * Conn ) nextMessageID () int64 {
138160 if l .chanMessageID != nil {
@@ -167,9 +189,16 @@ func (l *Conn) StartTLS(config *tls.Config) error {
167189 }
168190
169191 l .Debug .Printf ("%d: waiting for response" , messageID )
170- packet = <- channel
192+ defer l .finishMessage (messageID )
193+ packetResponse , ok := <- channel
194+ if ! ok {
195+ return NewError (ErrorNetwork , errors .New ("ldap: channel closed" ))
196+ }
197+ packet , err = packetResponse .ReadPacket ()
171198 l .Debug .Printf ("%d: got response %p" , messageID , packet )
172- l .finishMessage (messageID )
199+ if err != nil {
200+ return err
201+ }
173202
174203 if l .Debug {
175204 if err := addLDAPDescriptions (packet ); err != nil {
@@ -197,11 +226,11 @@ func (l *Conn) StartTLS(config *tls.Config) error {
197226 return nil
198227}
199228
200- func (l * Conn ) sendMessage (packet * ber.Packet ) (chan * ber. Packet , error ) {
229+ func (l * Conn ) sendMessage (packet * ber.Packet ) (chan * PacketResponse , error ) {
201230 return l .sendMessageWithFlags (packet , 0 )
202231}
203232
204- func (l * Conn ) sendMessageWithFlags (packet * ber.Packet , flags sendMessageFlags ) (chan * ber. Packet , error ) {
233+ func (l * Conn ) sendMessageWithFlags (packet * ber.Packet , flags sendMessageFlags ) (chan * PacketResponse , error ) {
205234 if l .isClosing {
206235 return nil , NewError (ErrorNetwork , errors .New ("ldap: connection closed" ))
207236 }
@@ -223,7 +252,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
223252
224253 l .messageMutex .Unlock ()
225254
226- out := make (chan * ber. Packet )
255+ out := make (chan * PacketResponse )
227256 message := & messagePacket {
228257 Op : MessageRequest ,
229258 MessageID : packet .Children [0 ].Value .(int64 ),
@@ -283,40 +312,66 @@ func (l *Conn) processMessages() {
283312 select {
284313 case l .chanMessageID <- messageID :
285314 messageID ++
286- case messagePacket , ok := <- l .chanMessage :
315+ case message , ok := <- l .chanMessage :
287316 if ! ok {
288317 l .Debug .Printf ("Shutting down - message channel is closed" )
289318 return
290319 }
291- switch messagePacket .Op {
320+ switch message .Op {
292321 case MessageQuit :
293322 l .Debug .Printf ("Shutting down - quit message received" )
294323 return
295324 case MessageRequest :
296325 // Add to message list and write to network
297- l .Debug .Printf ("Sending message %d" , messagePacket .MessageID )
298- l .chanResults [messagePacket .MessageID ] = messagePacket .Channel
299- // go routine
300- buf := messagePacket .Packet .Bytes ()
326+ l .Debug .Printf ("Sending message %d" , message .MessageID )
327+ l .chanResults [message .MessageID ] = message .Channel
301328
329+ buf := message .Packet .Bytes ()
302330 _ , err := l .conn .Write (buf )
303331 if err != nil {
304332 l .Debug .Printf ("Error Sending Message: %s" , err .Error ())
305333 break
306334 }
335+
336+ // Add timeout if defined
337+ if l .requestTimeout > 0 {
338+ go func () {
339+ defer func () {
340+ if err := recover (); err != nil {
341+ log .Printf ("ldap: recovered panic in RequestTimeout: %v" , err )
342+ }
343+ }()
344+ time .Sleep (l .requestTimeout )
345+ timeoutMessage := & messagePacket {
346+ Op : MessageTimeout ,
347+ MessageID : message .MessageID ,
348+ }
349+ l .sendProcessMessage (timeoutMessage )
350+ }()
351+ }
307352 case MessageResponse :
308- l .Debug .Printf ("Receiving message %d" , messagePacket .MessageID )
309- if chanResult , ok := l .chanResults [messagePacket .MessageID ]; ok {
310- chanResult <- messagePacket .Packet
353+ l .Debug .Printf ("Receiving message %d" , message .MessageID )
354+ if chanResult , ok := l .chanResults [message .MessageID ]; ok {
355+ chanResult <- & PacketResponse { message .Packet , nil }
311356 } else {
312- log .Printf ("Received unexpected message %d" , messagePacket .MessageID )
313- ber .PrintPacket (messagePacket .Packet )
357+ log .Printf ("Received unexpected message %d, %v" , message .MessageID , l .isClosing )
358+ ber .PrintPacket (message .Packet )
359+ }
360+ case MessageTimeout :
361+ // Handle the timeout by closing the channel
362+ // All reads will return immediately
363+ if chanResult , ok := l .chanResults [message .MessageID ]; ok {
364+ chanResult <- & PacketResponse {message .Packet , errors .New ("ldap: connection timed out" )}
365+ l .Debug .Printf ("Receiving message timeout for %d" , message .MessageID )
366+ delete (l .chanResults , message .MessageID )
367+ close (chanResult )
314368 }
315369 case MessageFinish :
316- // Remove from message list
317- l .Debug .Printf ("Finished message %d" , messagePacket .MessageID )
318- close (l .chanResults [messagePacket .MessageID ])
319- delete (l .chanResults , messagePacket .MessageID )
370+ l .Debug .Printf ("Finished message %d" , message .MessageID )
371+ if chanResult , ok := l .chanResults [message .MessageID ]; ok {
372+ close (chanResult )
373+ delete (l .chanResults , message .MessageID )
374+ }
320375 }
321376 }
322377 }
0 commit comments