Skip to content

Commit bb7a9ca

Browse files
authored
Merge pull request #134 from judwhite/feature/fix-race (#141)
fix race conditions in conn.go
1 parent 37f35d7 commit bb7a9ca

File tree

11 files changed

+46
-52
lines changed

11 files changed

+46
-52
lines changed

.travis.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,17 @@
11
language: go
22
env:
33
global:
4-
- VET_VERSIONS="1.6 1.7 tip"
5-
- LINT_VERSIONS="1.6 1.7 tip"
4+
- VET_VERSIONS="1.6 1.7 1.8 1.9 tip"
5+
- LINT_VERSIONS="1.6 1.7 1.8 1.9 tip"
66
go:
77
- 1.2
88
- 1.3
99
- 1.4
1010
- 1.5
1111
- 1.6
1212
- 1.7
13+
- 1.8
14+
- 1.9
1315
- tip
1416
matrix:
1517
fast_finish: true

Makefile

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ IS_OLD_GO := $(shell test $(GO_VERSION) -le 2 && echo true)
77
ifeq ($(IS_OLD_GO),true)
88
RACE_FLAG :=
99
else
10-
RACE_FLAG := -race
10+
RACE_FLAG := -race -cpu 1,2,4
1111
endif
1212

1313
default: fmt vet lint build quicktest

conn.go

Lines changed: 21 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -83,20 +83,18 @@ const (
8383
type Conn struct {
8484
conn net.Conn
8585
isTLS bool
86-
closeCount uint32
86+
closing uint32
8787
closeErr atomicValue
8888
isStartingTLS bool
8989
Debug debugging
90-
chanConfirm chan bool
90+
chanConfirm chan struct{}
9191
messageContexts map[int64]*messageContext
9292
chanMessage chan *messagePacket
9393
chanMessageID chan int64
94-
wgSender sync.WaitGroup
9594
wgClose sync.WaitGroup
96-
once sync.Once
9795
outstandingRequests uint
9896
messageMutex sync.Mutex
99-
requestTimeout time.Duration
97+
requestTimeout int64
10098
}
10199

102100
var _ Client = &Conn{}
@@ -143,7 +141,7 @@ func DialTLS(network, addr string, config *tls.Config) (*Conn, error) {
143141
func NewConn(conn net.Conn, isTLS bool) *Conn {
144142
return &Conn{
145143
conn: conn,
146-
chanConfirm: make(chan bool),
144+
chanConfirm: make(chan struct{}),
147145
chanMessageID: make(chan int64),
148146
chanMessage: make(chan *messagePacket, 10),
149147
messageContexts: map[int64]*messageContext{},
@@ -161,48 +159,46 @@ func (l *Conn) Start() {
161159

162160
// isClosing returns whether or not we're currently closing.
163161
func (l *Conn) isClosing() bool {
164-
return atomic.LoadUint32(&l.closeCount) > 0
162+
return atomic.LoadUint32(&l.closing) == 1
165163
}
166164

167165
// setClosing sets the closing value to true
168-
func (l *Conn) setClosing() {
169-
atomic.AddUint32(&l.closeCount, 1)
166+
func (l *Conn) setClosing() bool {
167+
return atomic.CompareAndSwapUint32(&l.closing, 0, 1)
170168
}
171169

172170
// Close closes the connection.
173171
func (l *Conn) Close() {
174-
l.once.Do(func() {
175-
l.setClosing()
176-
l.wgSender.Wait()
172+
l.messageMutex.Lock()
173+
defer l.messageMutex.Unlock()
177174

175+
if l.setClosing() {
178176
l.Debug.Printf("Sending quit message and waiting for confirmation")
179177
l.chanMessage <- &messagePacket{Op: MessageQuit}
180178
<-l.chanConfirm
181179
close(l.chanMessage)
182180

183181
l.Debug.Printf("Closing network connection")
184182
if err := l.conn.Close(); err != nil {
185-
log.Print(err)
183+
log.Println(err)
186184
}
187185

188186
l.wgClose.Done()
189-
})
187+
}
190188
l.wgClose.Wait()
191189
}
192190

193191
// SetTimeout sets the time after a request is sent that a MessageTimeout triggers
194192
func (l *Conn) SetTimeout(timeout time.Duration) {
195193
if timeout > 0 {
196-
l.requestTimeout = timeout
194+
atomic.StoreInt64(&l.requestTimeout, int64(timeout))
197195
}
198196
}
199197

200198
// Returns the next available messageID
201199
func (l *Conn) nextMessageID() int64 {
202-
if l.chanMessageID != nil {
203-
if messageID, ok := <-l.chanMessageID; ok {
204-
return messageID
205-
}
200+
if messageID, ok := <-l.chanMessageID; ok {
201+
return messageID
206202
}
207203
return 0
208204
}
@@ -327,12 +323,12 @@ func (l *Conn) finishMessage(msgCtx *messageContext) {
327323
}
328324

329325
func (l *Conn) sendProcessMessage(message *messagePacket) bool {
326+
l.messageMutex.Lock()
327+
defer l.messageMutex.Unlock()
330328
if l.isClosing() {
331329
return false
332330
}
333-
l.wgSender.Add(1)
334331
l.chanMessage <- message
335-
l.wgSender.Done()
336332
return true
337333
}
338334

@@ -352,7 +348,6 @@ func (l *Conn) processMessages() {
352348
delete(l.messageContexts, messageID)
353349
}
354350
close(l.chanMessageID)
355-
l.chanConfirm <- true
356351
close(l.chanConfirm)
357352
}()
358353

@@ -361,11 +356,7 @@ func (l *Conn) processMessages() {
361356
select {
362357
case l.chanMessageID <- messageID:
363358
messageID++
364-
case message, ok := <-l.chanMessage:
365-
if !ok {
366-
l.Debug.Printf("Shutting down - message channel is closed")
367-
return
368-
}
359+
case message := <-l.chanMessage:
369360
switch message.Op {
370361
case MessageQuit:
371362
l.Debug.Printf("Shutting down - quit message received")
@@ -388,14 +379,15 @@ func (l *Conn) processMessages() {
388379
l.messageContexts[message.MessageID] = message.Context
389380

390381
// Add timeout if defined
391-
if l.requestTimeout > 0 {
382+
requestTimeout := time.Duration(atomic.LoadInt64(&l.requestTimeout))
383+
if requestTimeout > 0 {
392384
go func() {
393385
defer func() {
394386
if err := recover(); err != nil {
395387
log.Printf("ldap: recovered panic in RequestTimeout: %v", err)
396388
}
397389
}()
398-
time.Sleep(l.requestTimeout)
390+
time.Sleep(requestTimeout)
399391
timeoutMessage := &messagePacket{
400392
Op: MessageTimeout,
401393
MessageID: message.MessageID,

conn_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -188,7 +188,7 @@ func runWithTimeout(t *testing.T, timeout time.Duration, f func()) {
188188
}
189189
}
190190

191-
// packetTranslatorConn is a helful type which can be used with various tests
191+
// packetTranslatorConn is a helpful type which can be used with various tests
192192
// in this package. It implements the net.Conn interface to be used as an
193193
// underlying connection for a *ldap.Conn. Most methods are no-ops but the
194194
// Read() and Write() methods are able to translate ber-encoded packets for
@@ -241,7 +241,7 @@ func (c *packetTranslatorConn) Read(b []byte) (n int, err error) {
241241
}
242242

243243
// SendResponse writes the given response packet to the response buffer for
244-
// this conection, signalling any goroutine waiting to read a response.
244+
// this connection, signalling any goroutine waiting to read a response.
245245
func (c *packetTranslatorConn) SendResponse(packet *ber.Packet) error {
246246
c.lock.Lock()
247247
defer c.lock.Unlock()

debug.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ import (
66
"gopkg.in/asn1-ber.v1"
77
)
88

9-
// debbuging type
9+
// debugging type
1010
// - has a Printf method to write the debug output
1111
type debugging bool
1212

dn.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
// Use of this source code is governed by a BSD-style
33
// license that can be found in the LICENSE file.
44
//
5-
// File contains DN parsing functionallity
5+
// File contains DN parsing functionality
66
//
77
// https://tools.ietf.org/html/rfc4514
88
//
@@ -52,7 +52,7 @@ import (
5252
"fmt"
5353
"strings"
5454

55-
ber "gopkg.in/asn1-ber.v1"
55+
"gopkg.in/asn1-ber.v1"
5656
)
5757

5858
// AttributeTypeAndValue represents an attributeTypeAndValue from https://tools.ietf.org/html/rfc4514

error_test.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ func TestConnReadErr(t *testing.T) {
4949
// Send the signal after a short amount of time.
5050
time.AfterFunc(10*time.Millisecond, func() { conn.signals <- expectedError })
5151

52-
// This should block until the underlyiny conn gets the error signal
52+
// This should block until the underlying conn gets the error signal
5353
// which should bubble up through the reader() goroutine, close the
5454
// connection, and
5555
_, err := ldapConn.Search(searchReq)
@@ -58,7 +58,7 @@ func TestConnReadErr(t *testing.T) {
5858
}
5959
}
6060

61-
// signalErrConn is a helful type used with TestConnReadErr. It implements the
61+
// signalErrConn is a helpful type used with TestConnReadErr. It implements the
6262
// net.Conn interface to be used as a connection for the test. Most methods are
6363
// no-ops but the Read() method blocks until it receives a signal which it
6464
// returns as an error.

example_test.go

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
)
1010

1111
// ExampleConn_Bind demonstrates how to bind a connection to an ldap user
12-
// allowing access to restricted attrabutes that user has access to
12+
// allowing access to restricted attributes that user has access to
1313
func ExampleConn_Bind() {
1414
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389))
1515
if err != nil {
@@ -63,10 +63,10 @@ func ExampleConn_StartTLS() {
6363
log.Fatal(err)
6464
}
6565

66-
// Opertations via l are now encrypted
66+
// Operations via l are now encrypted
6767
}
6868

69-
// ExampleConn_Compare demonstrates how to comapre an attribute with a value
69+
// ExampleConn_Compare demonstrates how to compare an attribute with a value
7070
func ExampleConn_Compare() {
7171
l, err := ldap.Dial("tcp", fmt.Sprintf("%s:%d", "ldap.example.com", 389))
7272
if err != nil {
@@ -215,7 +215,7 @@ func Example_userAuthentication() {
215215
log.Fatal(err)
216216
}
217217

218-
// Rebind as the read only user for any futher queries
218+
// Rebind as the read only user for any further queries
219219
err = l.Bind(bindusername, bindpassword)
220220
if err != nil {
221221
log.Fatal(err)
@@ -240,7 +240,7 @@ func Example_beherappolicy() {
240240
if ppolicyControl != nil {
241241
ppolicy = ppolicyControl.(*ldap.ControlBeheraPasswordPolicy)
242242
} else {
243-
log.Printf("ppolicyControl response not avaliable.\n")
243+
log.Printf("ppolicyControl response not available.\n")
244244
}
245245
if err != nil {
246246
errStr := "ERROR: Cannot bind: " + err.Error()

ldap.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ import (
99
"io/ioutil"
1010
"os"
1111

12-
ber "gopkg.in/asn1-ber.v1"
12+
"gopkg.in/asn1-ber.v1"
1313
)
1414

1515
// LDAP Application Codes

passwdmodify.go

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -135,10 +135,10 @@ func (l *Conn) PasswordModify(passwordModifyRequest *PasswordModifyRequest) (*Pa
135135
extendedResponse := packet.Children[1]
136136
for _, child := range extendedResponse.Children {
137137
if child.Tag == 11 {
138-
passwordModifyReponseValue := ber.DecodePacket(child.Data.Bytes())
139-
if len(passwordModifyReponseValue.Children) == 1 {
140-
if passwordModifyReponseValue.Children[0].Tag == 0 {
141-
result.GeneratedPassword = ber.DecodeString(passwordModifyReponseValue.Children[0].Data.Bytes())
138+
passwordModifyResponseValue := ber.DecodePacket(child.Data.Bytes())
139+
if len(passwordModifyResponseValue.Children) == 1 {
140+
if passwordModifyResponseValue.Children[0].Tag == 0 {
141+
result.GeneratedPassword = ber.DecodeString(passwordModifyResponseValue.Children[0].Data.Bytes())
142142
}
143143
}
144144
}

0 commit comments

Comments
 (0)