Skip to content

Commit 59659b1

Browse files
committed
Add workerpool package
1 parent 2bfd780 commit 59659b1

File tree

2 files changed

+267
-0
lines changed

2 files changed

+267
-0
lines changed

workerpool/workerpool.go

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
//
2+
// Worker pool with concurrency control
3+
//
4+
5+
package workerpool
6+
7+
import (
8+
"sync"
9+
)
10+
11+
type WorkerPool struct {
12+
queue chan struct{}
13+
errors []error
14+
closed bool
15+
16+
mu sync.RWMutex
17+
wg sync.WaitGroup
18+
closer sync.Once
19+
}
20+
21+
//
22+
// Initialize new pool
23+
//
24+
25+
func New(concurrency int) *WorkerPool {
26+
return &WorkerPool{
27+
queue: make(chan struct{}, concurrency),
28+
errors: make([]error, 0),
29+
}
30+
}
31+
32+
//
33+
// Add work function to the queue
34+
//
35+
36+
func (p *WorkerPool) Submit(fn func(done func(err error)) error) {
37+
// Wait for open slot in queue
38+
p.wg.Add(1)
39+
p.queue <- struct{}{}
40+
41+
// Callback function that signals a job is done
42+
done := func(err error) {
43+
p.wg.Done()
44+
45+
p.mu.RLock()
46+
closed := p.closed
47+
p.mu.RUnlock()
48+
49+
if !closed {
50+
<-p.queue
51+
if err != nil {
52+
p.mu.Lock()
53+
p.errors = append(p.errors, err)
54+
p.mu.Unlock()
55+
}
56+
}
57+
58+
}
59+
60+
// Run dispatcher function
61+
err := fn(done)
62+
if err != nil {
63+
done(err)
64+
}
65+
}
66+
67+
//
68+
// Close queue channel
69+
//
70+
71+
func (p *WorkerPool) Close() {
72+
p.closer.Do(func() {
73+
p.mu.Lock()
74+
p.closed = true
75+
close(p.queue)
76+
p.mu.Unlock()
77+
})
78+
}
79+
80+
//
81+
// Wait for workers to finish and close
82+
//
83+
84+
func (p *WorkerPool) Wait() {
85+
p.wg.Wait()
86+
p.Close()
87+
}
88+
89+
//
90+
// Produce list of errors from finished workers
91+
//
92+
93+
func (p *WorkerPool) Errors() []error {
94+
p.Wait()
95+
return p.errors
96+
}

workerpool/workerpool_test.go

Lines changed: 171 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,171 @@
1+
package workerpool
2+
3+
import (
4+
"fmt"
5+
"sync"
6+
"testing"
7+
"time"
8+
9+
"github.com/stretchr/testify/assert"
10+
)
11+
12+
func BenchmarkPool(b *testing.B) {
13+
pool := New(10)
14+
count := 0
15+
var mu sync.Mutex
16+
17+
for i := 0; i < b.N; i++ {
18+
pool.Submit(func(done func(err error)) error {
19+
go func() {
20+
mu.Lock()
21+
count++
22+
mu.Unlock()
23+
done(nil)
24+
}()
25+
26+
return nil
27+
})
28+
}
29+
30+
pool.Wait()
31+
assert.Equal(b, count, b.N)
32+
}
33+
34+
func TestPool(t *testing.T) {
35+
for _, n := range []int{5, 15, 30} {
36+
pool := New(3)
37+
count := 0
38+
var mu sync.Mutex
39+
40+
for i := 0; i < n; i++ {
41+
pool.Submit(func(done func(err error)) error {
42+
go func() {
43+
mu.Lock()
44+
count++
45+
mu.Unlock()
46+
done(nil)
47+
}()
48+
49+
return nil
50+
})
51+
}
52+
53+
pool.Wait()
54+
assert.Equal(t, count, n)
55+
assert.Empty(t, pool.Errors())
56+
}
57+
}
58+
59+
func TestPoolErrorReturn(t *testing.T) {
60+
pool := New(3)
61+
62+
for i := 0; i < 10; i++ {
63+
pool.Submit(func(done func(err error)) error {
64+
return fmt.Errorf("big error")
65+
})
66+
}
67+
68+
errors := pool.Errors()
69+
assert.Equal(t, len(errors), 10)
70+
for _, err := range errors {
71+
assert.Equal(t, err.Error(), "big error")
72+
}
73+
}
74+
75+
func TestPoolErrorAsync(t *testing.T) {
76+
pool := New(3)
77+
78+
for i := 0; i < 10; i++ {
79+
pool.Submit(func(done func(err error)) error {
80+
go func() {
81+
done(fmt.Errorf("big error"))
82+
}()
83+
84+
return nil
85+
})
86+
}
87+
88+
errors := pool.Errors()
89+
assert.Equal(t, len(errors), 10)
90+
for _, err := range errors {
91+
assert.Equal(t, err.Error(), "big error")
92+
}
93+
}
94+
95+
func TestPoolSingle(t *testing.T) {
96+
for _, n := range []int{1, 2, 3} {
97+
pool := New(1)
98+
start := time.Now()
99+
100+
for i := 0; i < n; i++ {
101+
pool.Submit(func(done func(err error)) error {
102+
go func() {
103+
time.Sleep(100 * time.Millisecond)
104+
done(nil)
105+
}()
106+
107+
return nil
108+
})
109+
}
110+
111+
pool.Wait()
112+
delta := time.Since(start)
113+
floor := time.Duration(n*100) * time.Millisecond
114+
ceil := time.Duration((n+1)*100) * time.Millisecond
115+
116+
assert.GreaterOrEqual(t, delta, floor)
117+
assert.LessOrEqual(t, delta, ceil)
118+
}
119+
}
120+
121+
func TestPoolDouble(t *testing.T) {
122+
for _, n := range []int{2, 4, 6} {
123+
pool := New(2)
124+
start := time.Now()
125+
126+
for i := 0; i < n; i++ {
127+
pool.Submit(func(done func(err error)) error {
128+
go func() {
129+
time.Sleep(100 * time.Millisecond)
130+
done(nil)
131+
}()
132+
133+
return nil
134+
})
135+
}
136+
137+
pool.Wait()
138+
delta := time.Since(start)
139+
floor := time.Duration(n*50) * time.Millisecond
140+
ceil := time.Duration((n+2)*50) * time.Millisecond
141+
142+
assert.GreaterOrEqual(t, delta, floor)
143+
assert.LessOrEqual(t, delta, ceil)
144+
}
145+
}
146+
147+
func TestPoolTriple(t *testing.T) {
148+
for _, n := range []int{3, 6, 9} {
149+
pool := New(3)
150+
start := time.Now()
151+
152+
for i := 0; i < n; i++ {
153+
pool.Submit(func(done func(err error)) error {
154+
go func() {
155+
time.Sleep(100 * time.Millisecond)
156+
done(nil)
157+
}()
158+
159+
return nil
160+
})
161+
}
162+
163+
pool.Wait()
164+
delta := time.Since(start)
165+
floor := time.Duration(n*33) * time.Millisecond
166+
ceil := time.Duration((n+3)*33) * time.Millisecond
167+
168+
assert.GreaterOrEqual(t, delta, floor)
169+
assert.LessOrEqual(t, delta, ceil)
170+
}
171+
}

0 commit comments

Comments
 (0)