Skip to content

Commit 0e7db8e

Browse files
committed
Merge pull request #57 from retailnext/master
Add ability to set a timeout on requests
2 parents a3bce49 + cf1b293 commit 0e7db8e

File tree

10 files changed

+200
-51
lines changed

10 files changed

+200
-51
lines changed

add.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,14 @@ func (l *Conn) Add(addRequest *AddRequest) error {
7777
defer l.finishMessage(messageID)
7878

7979
l.Debug.Printf("%d: waiting for response", messageID)
80-
packet = <-channel
80+
packetResponse, ok := <-channel
81+
if !ok {
82+
return NewError(ErrorNetwork, errors.New("ldap: channel closed"))
83+
}
84+
packet, err = packetResponse.ReadPacket()
8185
l.Debug.Printf("%d: got response %p", messageID, packet)
82-
if packet == nil {
83-
return NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
86+
if err != nil {
87+
return err
8488
}
8589

8690
if l.Debug {

bind.go

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,14 @@ func (l *Conn) SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResu
6060
}
6161
defer l.finishMessage(messageID)
6262

63-
packet = <-channel
64-
if packet == nil {
65-
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
63+
packetResponse, ok := <-channel
64+
if !ok {
65+
return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed"))
66+
}
67+
packet, err = packetResponse.ReadPacket()
68+
l.Debug.Printf("%d: got response %p", messageID, packet)
69+
if err != nil {
70+
return nil, err
6671
}
6772

6873
if l.Debug {
@@ -114,9 +119,14 @@ func (l *Conn) Bind(username, password string) error {
114119
}
115120
defer l.finishMessage(messageID)
116121

117-
packet = <-channel
118-
if packet == nil {
119-
return NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
122+
packetResponse, ok := <-channel
123+
if !ok {
124+
return NewError(ErrorNetwork, errors.New("ldap: channel closed"))
125+
}
126+
packet, err = packetResponse.ReadPacket()
127+
l.Debug.Printf("%d: got response %p", messageID, packet)
128+
if err != nil {
129+
return err
120130
}
121131

122132
if l.Debug {

client.go

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,16 @@
11
package ldap
22

3-
import "crypto/tls"
3+
import (
4+
"crypto/tls"
5+
"time"
6+
)
47

58
// Client knows how to interact with an LDAP server
69
type Client interface {
710
Start()
811
StartTLS(config *tls.Config) error
912
Close()
13+
SetTimeout(time.Duration)
1014

1115
Bind(username, password string) error
1216
SimpleBind(simpleBindRequest *SimpleBindRequest) (*SimpleBindResult, error)

compare.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -58,10 +58,14 @@ func (l *Conn) Compare(dn, attribute, value string) (bool, error) {
5858
defer l.finishMessage(messageID)
5959

6060
l.Debug.Printf("%d: waiting for response", messageID)
61-
packet = <-channel
61+
packetResponse, ok := <-channel
62+
if !ok {
63+
return false, NewError(ErrorNetwork, errors.New("ldap: channel closed"))
64+
}
65+
packet, err = packetResponse.ReadPacket()
6266
l.Debug.Printf("%d: got response %p", messageID, packet)
63-
if packet == nil {
64-
return false, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
67+
if err != nil {
68+
return false, err
6569
}
6670

6771
if l.Debug {

conn.go

Lines changed: 83 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -21,13 +21,26 @@ const (
2121
MessageRequest = 1
2222
MessageResponse = 2
2323
MessageFinish = 3
24+
MessageTimeout = 4
2425
)
2526

27+
type PacketResponse struct {
28+
Packet *ber.Packet
29+
Error error
30+
}
31+
32+
func (pr *PacketResponse) ReadPacket() (*ber.Packet, error) {
33+
if (pr == nil) || (pr.Packet == nil && pr.Error == nil) {
34+
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve response"))
35+
}
36+
return pr.Packet, pr.Error
37+
}
38+
2639
type messagePacket struct {
2740
Op int
2841
MessageID int64
2942
Packet *ber.Packet
30-
Channel chan *ber.Packet
43+
Channel chan *PacketResponse
3144
}
3245

3346
type sendMessageFlags uint
@@ -44,14 +57,15 @@ type Conn struct {
4457
isStartingTLS bool
4558
Debug debugging
4659
chanConfirm chan bool
47-
chanResults map[int64]chan *ber.Packet
60+
chanResults map[int64]chan *PacketResponse
4861
chanMessage chan *messagePacket
4962
chanMessageID chan int64
5063
wgSender sync.WaitGroup
5164
wgClose sync.WaitGroup
5265
once sync.Once
5366
outstandingRequests uint
5467
messageMutex sync.Mutex
68+
requestTimeout time.Duration
5569
}
5670

5771
var _ Client = &Conn{}
@@ -97,12 +111,13 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
97111
// NewConn returns a new Conn using conn for network I/O.
98112
func NewConn(conn net.Conn, isTLS bool) *Conn {
99113
return &Conn{
100-
conn: conn,
101-
chanConfirm: make(chan bool),
102-
chanMessageID: make(chan int64),
103-
chanMessage: make(chan *messagePacket, 10),
104-
chanResults: map[int64]chan *ber.Packet{},
105-
isTLS: isTLS,
114+
conn: conn,
115+
chanConfirm: make(chan bool),
116+
chanMessageID: make(chan int64),
117+
chanMessage: make(chan *messagePacket, 10),
118+
chanResults: map[int64]chan *PacketResponse{},
119+
requestTimeout: 0,
120+
isTLS: isTLS,
106121
}
107122
}
108123

@@ -133,6 +148,13 @@ func (l *Conn) Close() {
133148
l.wgClose.Wait()
134149
}
135150

151+
// Sets the time after a request is sent that a MessageTimeout triggers
152+
func (l *Conn) SetTimeout(timeout time.Duration) {
153+
if timeout > 0 {
154+
l.requestTimeout = timeout
155+
}
156+
}
157+
136158
// Returns the next available messageID
137159
func (l *Conn) nextMessageID() int64 {
138160
if l.chanMessageID != nil {
@@ -167,9 +189,16 @@ func (l *Conn) StartTLS(config *tls.Config) error {
167189
}
168190

169191
l.Debug.Printf("%d: waiting for response", messageID)
170-
packet = <-channel
192+
defer l.finishMessage(messageID)
193+
packetResponse, ok := <-channel
194+
if !ok {
195+
return NewError(ErrorNetwork, errors.New("ldap: channel closed"))
196+
}
197+
packet, err = packetResponse.ReadPacket()
171198
l.Debug.Printf("%d: got response %p", messageID, packet)
172-
l.finishMessage(messageID)
199+
if err != nil {
200+
return err
201+
}
173202

174203
if l.Debug {
175204
if err := addLDAPDescriptions(packet); err != nil {
@@ -197,11 +226,11 @@ func (l *Conn) StartTLS(config *tls.Config) error {
197226
return nil
198227
}
199228

200-
func (l *Conn) sendMessage(packet *ber.Packet) (chan *ber.Packet, error) {
229+
func (l *Conn) sendMessage(packet *ber.Packet) (chan *PacketResponse, error) {
201230
return l.sendMessageWithFlags(packet, 0)
202231
}
203232

204-
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *ber.Packet, error) {
233+
func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags) (chan *PacketResponse, error) {
205234
if l.isClosing {
206235
return nil, NewError(ErrorNetwork, errors.New("ldap: connection closed"))
207236
}
@@ -223,7 +252,7 @@ func (l *Conn) sendMessageWithFlags(packet *ber.Packet, flags sendMessageFlags)
223252

224253
l.messageMutex.Unlock()
225254

226-
out := make(chan *ber.Packet)
255+
out := make(chan *PacketResponse)
227256
message := &messagePacket{
228257
Op: MessageRequest,
229258
MessageID: packet.Children[0].Value.(int64),
@@ -283,40 +312,66 @@ func (l *Conn) processMessages() {
283312
select {
284313
case l.chanMessageID <- messageID:
285314
messageID++
286-
case messagePacket, ok := <-l.chanMessage:
315+
case message, ok := <-l.chanMessage:
287316
if !ok {
288317
l.Debug.Printf("Shutting down - message channel is closed")
289318
return
290319
}
291-
switch messagePacket.Op {
320+
switch message.Op {
292321
case MessageQuit:
293322
l.Debug.Printf("Shutting down - quit message received")
294323
return
295324
case MessageRequest:
296325
// Add to message list and write to network
297-
l.Debug.Printf("Sending message %d", messagePacket.MessageID)
298-
l.chanResults[messagePacket.MessageID] = messagePacket.Channel
299-
// go routine
300-
buf := messagePacket.Packet.Bytes()
326+
l.Debug.Printf("Sending message %d", message.MessageID)
327+
l.chanResults[message.MessageID] = message.Channel
301328

329+
buf := message.Packet.Bytes()
302330
_, err := l.conn.Write(buf)
303331
if err != nil {
304332
l.Debug.Printf("Error Sending Message: %s", err.Error())
305333
break
306334
}
335+
336+
// Add timeout if defined
337+
if l.requestTimeout > 0 {
338+
go func() {
339+
defer func() {
340+
if err := recover(); err != nil {
341+
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
342+
}
343+
}()
344+
time.Sleep(l.requestTimeout)
345+
timeoutMessage := &messagePacket{
346+
Op: MessageTimeout,
347+
MessageID: message.MessageID,
348+
}
349+
l.sendProcessMessage(timeoutMessage)
350+
}()
351+
}
307352
case MessageResponse:
308-
l.Debug.Printf("Receiving message %d", messagePacket.MessageID)
309-
if chanResult, ok := l.chanResults[messagePacket.MessageID]; ok {
310-
chanResult <- messagePacket.Packet
353+
l.Debug.Printf("Receiving message %d", message.MessageID)
354+
if chanResult, ok := l.chanResults[message.MessageID]; ok {
355+
chanResult <- &PacketResponse{message.Packet, nil}
311356
} else {
312-
log.Printf("Received unexpected message %d", messagePacket.MessageID)
313-
ber.PrintPacket(messagePacket.Packet)
357+
log.Printf("Received unexpected message %d, %v", message.MessageID, l.isClosing)
358+
ber.PrintPacket(message.Packet)
359+
}
360+
case MessageTimeout:
361+
// Handle the timeout by closing the channel
362+
// All reads will return immediately
363+
if chanResult, ok := l.chanResults[message.MessageID]; ok {
364+
chanResult <- &PacketResponse{message.Packet, errors.New("ldap: connection timed out")}
365+
l.Debug.Printf("Receiving message timeout for %d", message.MessageID)
366+
delete(l.chanResults, message.MessageID)
367+
close(chanResult)
314368
}
315369
case MessageFinish:
316-
// Remove from message list
317-
l.Debug.Printf("Finished message %d", messagePacket.MessageID)
318-
close(l.chanResults[messagePacket.MessageID])
319-
delete(l.chanResults, messagePacket.MessageID)
370+
l.Debug.Printf("Finished message %d", message.MessageID)
371+
if chanResult, ok := l.chanResults[message.MessageID]; ok {
372+
close(chanResult)
373+
delete(l.chanResults, message.MessageID)
374+
}
320375
}
321376
}
322377
}

conn_test.go

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
package ldap
2+
3+
import (
4+
"net"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
"time"
9+
10+
"gopkg.in/asn1-ber.v1"
11+
)
12+
13+
func TestUnresponsiveConnection(t *testing.T) {
14+
// The do-nothing server that accepts requests and does nothing
15+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
16+
}))
17+
defer ts.Close()
18+
c, err := net.Dial(ts.Listener.Addr().Network(), ts.Listener.Addr().String())
19+
if err != nil {
20+
t.Fatalf("error connecting to localhost tcp: %v", err)
21+
}
22+
23+
// Create an Ldap connection
24+
conn := NewConn(c, false)
25+
conn.SetTimeout(time.Millisecond)
26+
conn.Start()
27+
defer conn.Close()
28+
29+
// Mock a packet
30+
messageID := conn.nextMessageID()
31+
packet := ber.Encode(ber.ClassUniversal, ber.TypeConstructed, ber.TagSequence, nil, "LDAP Request")
32+
packet.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, messageID, "MessageID"))
33+
bindRequest := ber.Encode(ber.ClassApplication, ber.TypeConstructed, ApplicationBindRequest, nil, "Bind Request")
34+
bindRequest.AppendChild(ber.NewInteger(ber.ClassUniversal, ber.TypePrimitive, ber.TagInteger, 3, "Version"))
35+
packet.AppendChild(bindRequest)
36+
37+
// Send packet and test response
38+
channel, err := conn.sendMessage(packet)
39+
if err != nil {
40+
t.Fatalf("error sending message: %v", err)
41+
}
42+
packetResponse, ok := <-channel
43+
if !ok {
44+
t.Fatalf("no PacketResponse in response channel")
45+
}
46+
packet, err = packetResponse.ReadPacket()
47+
if err == nil {
48+
t.Fatalf("expected timeout error")
49+
}
50+
if err.Error() != "ldap: connection timed out" {
51+
t.Fatalf("unexpected error: %v", err)
52+
}
53+
}

del.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,14 @@ func (l *Conn) Del(delRequest *DelRequest) error {
5252
defer l.finishMessage(messageID)
5353

5454
l.Debug.Printf("%d: waiting for response", messageID)
55-
packet = <-channel
55+
packetResponse, ok := <-channel
56+
if !ok {
57+
return NewError(ErrorNetwork, errors.New("ldap: channel closed"))
58+
}
59+
packet, err = packetResponse.ReadPacket()
5660
l.Debug.Printf("%d: got response %p", messageID, packet)
57-
if packet == nil {
58-
return NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
61+
if err != nil {
62+
return err
5963
}
6064

6165
if l.Debug {

modify.go

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -129,10 +129,14 @@ func (l *Conn) Modify(modifyRequest *ModifyRequest) error {
129129
defer l.finishMessage(messageID)
130130

131131
l.Debug.Printf("%d: waiting for response", messageID)
132-
packet = <-channel
132+
packetResponse, ok := <-channel
133+
if !ok {
134+
return NewError(ErrorNetwork, errors.New("ldap: channel closed"))
135+
}
136+
packet, err = packetResponse.ReadPacket()
133137
l.Debug.Printf("%d: got response %p", messageID, packet)
134-
if packet == nil {
135-
return NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))
138+
if err != nil {
139+
return err
136140
}
137141

138142
if l.Debug {

passwdmodify.go

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -98,8 +98,15 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
9898
result := &PasswordModifyResult{}
9999

100100
l.Debug.Printf("%d: waiting for response", messageID)
101-
packet = <-channel
101+
packetResponse, ok := <-channel
102+
if !ok {
103+
return nil, NewError(ErrorNetwork, errors.New("ldap: channel closed"))
104+
}
105+
packet, err = packetResponse.ReadPacket()
102106
l.Debug.Printf("%d: got response %p", messageID, packet)
107+
if err != nil {
108+
return nil, err
109+
}
103110

104111
if packet == nil {
105112
return nil, NewError(ErrorNetwork, errors.New("ldap: could not retrieve message"))

0 commit comments

Comments
 (0)