Skip to content

Commit bd86d3a

Browse files
committed
Batch function
1 parent c572b53 commit bd86d3a

File tree

2 files changed

+141
-0
lines changed

2 files changed

+141
-0
lines changed

pkg/database/batch.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
package database
2+
3+
import "context"
4+
5+
// Batch applies fn to successive chunks of elements, each of size at most batchSize.
6+
// A batchSize of 0 processes all the elements in one chunk.
7+
// Stops at the first error and returns it.
8+
func Batch[T any](ctx context.Context, elems []T, batchSize int, fn func(context.Context, []T) error) error {
9+
n := len(elems)
10+
11+
if n == 0 {
12+
return nil
13+
}
14+
15+
if batchSize <= 0 || batchSize > n {
16+
batchSize = n
17+
}
18+
19+
for start := 0; start < n; start += batchSize {
20+
if ctx.Err() != nil {
21+
return ctx.Err()
22+
}
23+
24+
end := min(start+batchSize, n)
25+
if err := fn(ctx, elems[start:end]); err != nil {
26+
return err
27+
}
28+
}
29+
30+
return nil
31+
}

pkg/database/batch_test.go

Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package database_test
2+
3+
import (
4+
"context"
5+
"errors"
6+
"testing"
7+
8+
"github.com/stretchr/testify/assert"
9+
"github.com/stretchr/testify/require"
10+
11+
"github.com/crowdsecurity/crowdsec/pkg/database"
12+
)
13+
14+
func TestBatch(t *testing.T) {
15+
type callRecord struct {
16+
calledBatches [][]int
17+
}
18+
19+
tests := []struct {
20+
name string
21+
elems []int
22+
batchSize int
23+
cancelCtx bool
24+
fnErrorAt int // number of batch where fn fails (0 = never)
25+
wantErr bool
26+
wantBatches [][]int
27+
}{
28+
{
29+
name: "normal batching",
30+
elems: []int{1, 2, 3, 4, 5},
31+
batchSize: 2,
32+
wantBatches: [][]int{{1, 2}, {3, 4}, {5}},
33+
},
34+
{
35+
name: "batchSize zero = all in one batch",
36+
elems: []int{1, 2, 3},
37+
batchSize: 0,
38+
wantBatches: [][]int{{1, 2, 3}},
39+
},
40+
{
41+
name: "batchSize > len(elems)",
42+
elems: []int{1, 2, 3},
43+
batchSize: 10,
44+
wantBatches: [][]int{{1, 2, 3}},
45+
},
46+
{
47+
name: "empty input",
48+
elems: []int{},
49+
batchSize: 3,
50+
wantBatches: nil,
51+
},
52+
{
53+
name: "nil input",
54+
elems: nil,
55+
batchSize: 3,
56+
wantBatches: nil,
57+
},
58+
{
59+
name: "error in fn",
60+
elems: []int{1, 2, 3, 4},
61+
batchSize: 2,
62+
fnErrorAt: 2,
63+
wantErr: true,
64+
wantBatches: [][]int{{1, 2}},
65+
},
66+
{
67+
name: "context canceled before loop",
68+
elems: []int{1, 2, 3},
69+
batchSize: 2,
70+
cancelCtx: true,
71+
wantErr: true,
72+
wantBatches: nil,
73+
},
74+
}
75+
76+
for _, tc := range tests {
77+
t.Run(tc.name, func(t *testing.T) {
78+
var rec callRecord
79+
ctx := t.Context()
80+
81+
// not testing a cancel _between_ batches, this should be enough
82+
if tc.cancelCtx {
83+
canceled, cancel := context.WithCancel(ctx)
84+
cancel()
85+
ctx = canceled
86+
}
87+
88+
err := database.Batch(ctx, tc.elems, tc.batchSize, func(_ context.Context, batch []int) error {
89+
if len(rec.calledBatches) == tc.fnErrorAt-1 {
90+
return errors.New("simulated error")
91+
}
92+
93+
rec.calledBatches = append(rec.calledBatches, batch)
94+
95+
return nil
96+
})
97+
98+
switch {
99+
case tc.wantErr && tc.cancelCtx:
100+
require.ErrorContains(t, err, "context canceled")
101+
case tc.wantErr:
102+
require.ErrorContains(t, err, "simulated error")
103+
default:
104+
require.NoError(t, err)
105+
}
106+
107+
assert.Equal(t, tc.wantBatches, rec.calledBatches)
108+
})
109+
}
110+
}

0 commit comments

Comments
 (0)