Skip to content

Commit d2219c2

Browse files
committed
Allow custom Retry/Backoff per request
This commit extends the `PerformRequestOptions` to pass a custom `Retrier` per request. This is enabled for the Scroll and Bulk API for now. See #666 and #610
1 parent 832286e commit d2219c2

File tree

6 files changed

+111
-40
lines changed

6 files changed

+111
-40
lines changed

.gitignore

+1
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ _testmain.go
2222
*.exe
2323

2424
/.vscode/
25+
/debug.test
2526
/generator
2627
/cluster-test/cluster-test
2728
/cluster-test/*.log

bulk.go

+10-1
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,8 @@ import (
2626
// See https://www.elastic.co/guide/en/elasticsearch/reference/6.0/docs-bulk.html
2727
// for more details.
2828
type BulkService struct {
29-
client *Client
29+
client *Client
30+
retrier Retrier
3031

3132
index string
3233
typ string
@@ -57,6 +58,13 @@ func (s *BulkService) reset() {
5758
s.sizeInBytesCursor = 0
5859
}
5960

61+
// Retrier allows to set specific retry logic for this BulkService.
62+
// If not specified, it will use the client's default retrier.
63+
func (s *BulkService) Retrier(retrier Retrier) *BulkService {
64+
s.retrier = retrier
65+
return s
66+
}
67+
6068
// Index specifies the index to use for all batches. You may also leave
6169
// this blank and specify the index in the individual bulk requests.
6270
func (s *BulkService) Index(index string) *BulkService {
@@ -241,6 +249,7 @@ func (s *BulkService) Do(ctx context.Context) (*BulkResponse, error) {
241249
Params: params,
242250
Body: body,
243251
ContentType: "application/x-ndjson",
252+
Retrier: s.retrier,
244253
})
245254
if err != nil {
246255
return nil, err

bulk_test.go

+24-24
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,30 @@ func TestBulkEstimateSizeInBytesLength(t *testing.T) {
482482
}
483483
}
484484

485+
func TestBulkContentType(t *testing.T) {
486+
var header http.Header
487+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
488+
header = r.Header
489+
fmt.Fprintln(w, `{}`)
490+
}))
491+
defer ts.Close()
492+
493+
client, err := NewSimpleClient(SetURL(ts.URL))
494+
if err != nil {
495+
t.Fatal(err)
496+
}
497+
indexReq := NewBulkIndexRequest().Index(testIndexName).Type("doc").Id("1").Doc(tweet{User: "olivere", Message: "Welcome to Golang and Elasticsearch."})
498+
if _, err := client.Bulk().Add(indexReq).Do(context.Background()); err != nil {
499+
t.Fatal(err)
500+
}
501+
if header == nil {
502+
t.Fatalf("expected header, got %v", header)
503+
}
504+
if want, have := "application/x-ndjson", header.Get("Content-Type"); want != have {
505+
t.Fatalf("Content-Type: want %q, have %q", want, have)
506+
}
507+
}
508+
485509
var benchmarkBulkEstimatedSizeInBytes int64
486510

487511
func BenchmarkBulkEstimatedSizeInBytesWith1Request(b *testing.B) {
@@ -516,30 +540,6 @@ func BenchmarkBulkEstimatedSizeInBytesWith100Requests(b *testing.B) {
516540
benchmarkBulkEstimatedSizeInBytes = result // ensure the compiler doesn't optimize
517541
}
518542

519-
func TestBulkContentType(t *testing.T) {
520-
var header http.Header
521-
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
522-
header = r.Header
523-
fmt.Fprintln(w, `{}`)
524-
}))
525-
defer ts.Close()
526-
527-
client, err := NewSimpleClient(SetURL(ts.URL))
528-
if err != nil {
529-
t.Fatal(err)
530-
}
531-
indexReq := NewBulkIndexRequest().Index(testIndexName).Type("doc").Id("1").Doc(tweet{User: "olivere", Message: "Welcome to Golang and Elasticsearch."})
532-
if _, err := client.Bulk().Add(indexReq).Do(context.Background()); err != nil {
533-
t.Fatal(err)
534-
}
535-
if header == nil {
536-
t.Fatalf("expected header, got %v", header)
537-
}
538-
if want, have := "application/x-ndjson", header.Get("Content-Type"); want != have {
539-
t.Fatalf("Content-Type: want %q, have %q", want, have)
540-
}
541-
}
542-
543543
func BenchmarkBulkAllocs(b *testing.B) {
544544
b.Run("1000 docs with 64 byte", func(b *testing.B) { benchmarkBulkAllocs(b, 64, 1000) })
545545
b.Run("1000 docs with 1 KiB", func(b *testing.B) { benchmarkBulkAllocs(b, 1024, 1000) })

client.go

+8-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ import (
2626

2727
const (
2828
// Version is the current version of Elastic.
29-
Version = "6.1.0"
29+
Version = "6.1.1"
3030

3131
// DefaultURL is the default endpoint of Elasticsearch on the local machine.
3232
// It is used e.g. when initializing a new Client without a specific URL.
@@ -1169,6 +1169,7 @@ type PerformRequestOptions struct {
11691169
Body interface{}
11701170
ContentType string
11711171
IgnoreErrors []int
1172+
Retrier Retrier
11721173
}
11731174

11741175
// PerformRequest does a HTTP request to Elasticsearch.
@@ -1186,6 +1187,10 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
11861187
basicAuthUsername := c.basicAuthUsername
11871188
basicAuthPassword := c.basicAuthPassword
11881189
sendGetBodyAs := c.sendGetBodyAs
1190+
retrier := c.retrier
1191+
if opt.Retrier != nil {
1192+
retrier = opt.Retrier
1193+
}
11891194
c.mu.RUnlock()
11901195

11911196
var err error
@@ -1214,7 +1219,7 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
12141219
// Force a healtcheck as all connections seem to be dead.
12151220
c.healthcheck(timeout, false)
12161221
}
1217-
wait, ok, rerr := c.retrier.Retry(ctx, n, nil, nil, err)
1222+
wait, ok, rerr := retrier.Retry(ctx, n, nil, nil, err)
12181223
if rerr != nil {
12191224
return nil, rerr
12201225
}
@@ -1270,7 +1275,7 @@ func (c *Client) PerformRequest(ctx context.Context, opt PerformRequestOptions)
12701275
}
12711276
if err != nil {
12721277
n++
1273-
wait, ok, rerr := c.retrier.Retry(ctx, n, (*http.Request)(req), res, err)
1278+
wait, ok, rerr := retrier.Retry(ctx, n, (*http.Request)(req), res, err)
12741279
if rerr != nil {
12751280
c.errorf("elastic: %s is dead", conn.URL())
12761281
conn.MarkAsDead()

retrier_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -127,3 +127,48 @@ func TestRetrierWithError(t *testing.T) {
127127
t.Errorf("expected %d Retrier calls; got: %d", 1, retrier.N)
128128
}
129129
}
130+
131+
func TestRetrierOnPerformRequest(t *testing.T) {
132+
var numFailedReqs int
133+
fail := func(r *http.Request) (*http.Response, error) {
134+
numFailedReqs += 1
135+
//return &http.Response{Request: r, StatusCode: 400}, nil
136+
return nil, errors.New("request failed")
137+
}
138+
139+
tr := &failingTransport{path: "/fail", fail: fail}
140+
httpClient := &http.Client{Transport: tr}
141+
142+
defaultRetrier := &testRetrier{
143+
Retrier: NewStopRetrier(),
144+
}
145+
requestRetrier := &testRetrier{
146+
Retrier: NewStopRetrier(),
147+
}
148+
149+
client, err := NewClient(
150+
SetHttpClient(httpClient),
151+
SetHealthcheck(false),
152+
SetRetrier(defaultRetrier))
153+
if err != nil {
154+
t.Fatal(err)
155+
}
156+
157+
res, err := client.PerformRequest(context.TODO(), PerformRequestOptions{
158+
Method: "GET",
159+
Path: "/fail",
160+
Retrier: requestRetrier,
161+
})
162+
if err == nil {
163+
t.Fatal("expected error")
164+
}
165+
if res != nil {
166+
t.Fatal("expected no response")
167+
}
168+
if want, have := int64(0), defaultRetrier.N; want != have {
169+
t.Errorf("defaultRetrier: expected %d calls; got: %d", want, have)
170+
}
171+
if want, have := int64(1), requestRetrier.N; want != have {
172+
t.Errorf("requestRetrier: expected %d calls; got: %d", want, have)
173+
}
174+
}

scroll.go

+23-12
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ const (
2323
// ScrollService iterates over pages of search results from Elasticsearch.
2424
type ScrollService struct {
2525
client *Client
26+
retrier Retrier
2627
indices []string
2728
types []string
2829
keepAlive string
@@ -50,6 +51,13 @@ func NewScrollService(client *Client) *ScrollService {
5051
return builder
5152
}
5253

54+
// Retrier allows to set specific retry logic for this ScrollService.
55+
// If not specified, it will use the client's default retrier.
56+
func (s *ScrollService) Retrier(retrier Retrier) *ScrollService {
57+
s.retrier = retrier
58+
return s
59+
}
60+
5361
// Index sets the name of one or more indices to iterate over.
5462
func (s *ScrollService) Index(indices ...string) *ScrollService {
5563
if s.indices == nil {
@@ -259,10 +267,11 @@ func (s *ScrollService) Clear(ctx context.Context) error {
259267
}
260268

261269
_, err := s.client.PerformRequest(ctx, PerformRequestOptions{
262-
Method: "DELETE",
263-
Path: path,
264-
Params: params,
265-
Body: body,
270+
Method: "DELETE",
271+
Path: path,
272+
Params: params,
273+
Body: body,
274+
Retrier: s.retrier,
266275
})
267276
if err != nil {
268277
return err
@@ -289,10 +298,11 @@ func (s *ScrollService) first(ctx context.Context) (*SearchResult, error) {
289298

290299
// Get HTTP response
291300
res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
292-
Method: "POST",
293-
Path: path,
294-
Params: params,
295-
Body: body,
301+
Method: "POST",
302+
Path: path,
303+
Params: params,
304+
Body: body,
305+
Retrier: s.retrier,
296306
})
297307
if err != nil {
298308
return nil, err
@@ -408,10 +418,11 @@ func (s *ScrollService) next(ctx context.Context) (*SearchResult, error) {
408418

409419
// Get HTTP response
410420
res, err := s.client.PerformRequest(ctx, PerformRequestOptions{
411-
Method: "POST",
412-
Path: path,
413-
Params: params,
414-
Body: body,
421+
Method: "POST",
422+
Path: path,
423+
Params: params,
424+
Body: body,
425+
Retrier: s.retrier,
415426
})
416427
if err != nil {
417428
return nil, err

0 commit comments

Comments
 (0)