@@ -83,20 +83,18 @@ const (
8383type Conn struct {
8484 conn net.Conn
8585 isTLS bool
86- closeCount uint32
86+ closing uint32
8787 closeErr atomicValue
8888 isStartingTLS bool
8989 Debug debugging
90- chanConfirm chan bool
90+ chanConfirm chan struct {}
9191 messageContexts map [int64 ]* messageContext
9292 chanMessage chan * messagePacket
9393 chanMessageID chan int64
94- wgSender sync.WaitGroup
9594 wgClose sync.WaitGroup
96- once sync.Once
9795 outstandingRequests uint
9896 messageMutex sync.Mutex
99- requestTimeout time. Duration
97+ requestTimeout int64
10098}
10199
102100var _ Client = & Conn {}
@@ -143,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
143141func NewConn (conn net.Conn , isTLS bool ) * Conn {
144142 return & Conn {
145143 conn : conn ,
146- chanConfirm : make (chan bool ),
144+ chanConfirm : make (chan struct {} ),
147145 chanMessageID : make (chan int64 ),
148146 chanMessage : make (chan * messagePacket , 10 ),
149147 messageContexts : map [int64 ]* messageContext {},
@@ -161,48 +159,46 @@ func (l *Conn) Start() {
161159
162160// isClosing returns whether or not we're currently closing.
163161func (l * Conn ) isClosing () bool {
164- return atomic .LoadUint32 (& l .closeCount ) > 0
162+ return atomic .LoadUint32 (& l .closing ) == 1
165163}
166164
167165// setClosing sets the closing value to true
168- func (l * Conn ) setClosing () {
169- atomic .AddUint32 (& l .closeCount , 1 )
166+ func (l * Conn ) setClosing () bool {
167+ return atomic .CompareAndSwapUint32 (& l .closing , 0 , 1 )
170168}
171169
172170// Close closes the connection.
173171func (l * Conn ) Close () {
174- l .once .Do (func () {
175- l .setClosing ()
176- l .wgSender .Wait ()
172+ l .messageMutex .Lock ()
173+ defer l .messageMutex .Unlock ()
177174
175+ if l .setClosing () {
178176 l .Debug .Printf ("Sending quit message and waiting for confirmation" )
179177 l .chanMessage <- & messagePacket {Op : MessageQuit }
180178 <- l .chanConfirm
181179 close (l .chanMessage )
182180
183181 l .Debug .Printf ("Closing network connection" )
184182 if err := l .conn .Close (); err != nil {
185- log .Print (err )
183+ log .Println (err )
186184 }
187185
188186 l .wgClose .Done ()
189- })
187+ }
190188 l .wgClose .Wait ()
191189}
192190
193191// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
194192func (l * Conn ) SetTimeout (timeout time.Duration ) {
195193 if timeout > 0 {
196- l .requestTimeout = timeout
194+ atomic . StoreInt64 ( & l .requestTimeout , int64 ( timeout ))
197195 }
198196}
199197
200198// Returns the next available messageID
201199func (l * Conn ) nextMessageID () int64 {
202- if l .chanMessageID != nil {
203- if messageID , ok := <- l .chanMessageID ; ok {
204- return messageID
205- }
200+ if messageID , ok := <- l .chanMessageID ; ok {
201+ return messageID
206202 }
207203 return 0
208204}
@@ -327,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
327323}
328324
329325func (l * Conn ) sendProcessMessage (message * messagePacket ) bool {
326+ l .messageMutex .Lock ()
327+ defer l .messageMutex .Unlock ()
330328 if l .isClosing () {
331329 return false
332330 }
333- l .wgSender .Add (1 )
334331 l .chanMessage <- message
335- l .wgSender .Done ()
336332 return true
337333}
338334
@@ -352,7 +348,6 @@ func (l *Conn) processMessages() {
352348 delete (l .messageContexts , messageID )
353349 }
354350 close (l .chanMessageID )
355- l .chanConfirm <- true
356351 close (l .chanConfirm )
357352 }()
358353
@@ -361,11 +356,7 @@ func (l *Conn) processMessages() {
361356 select {
362357 case l .chanMessageID <- messageID :
363358 messageID ++
364- case message , ok := <- l .chanMessage :
365- if ! ok {
366- l .Debug .Printf ("Shutting down - message channel is closed" )
367- return
368- }
359+ case message := <- l .chanMessage :
369360 switch message .Op {
370361 case MessageQuit :
371362 l .Debug .Printf ("Shutting down - quit message received" )
@@ -388,14 +379,15 @@ func (l *Conn) processMessages() {
388379 l .messageContexts [message .MessageID ] = message .Context
389380
390381 // Add timeout if defined
391- if l .requestTimeout > 0 {
382+ requestTimeout := time .Duration (atomic .LoadInt64 (& l .requestTimeout ))
383+ if requestTimeout > 0 {
392384 go func () {
393385 defer func () {
394386 if err := recover (); err != nil {
395387 log .Printf ("ldap: recovered panic in RequestTimeout: %v" , err )
396388 }
397389 }()
398- time .Sleep (l . requestTimeout )
390+ time .Sleep (requestTimeout )
399391 timeoutMessage := & messagePacket {
400392 Op : MessageTimeout ,
401393 MessageID : message .MessageID ,
0 commit comments