Skip to content

Commit 32d41ee

Browse files
committed
Allow use of custom http.Client
1 parent db7766e commit 32d41ee

File tree

2 files changed

+51
-6
lines changed

2 files changed

+51
-6
lines changed

rpc.go

+6-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,11 @@ import (
1111
"time"
1212
)
1313

14+
var HttpClient = &http.Client{
15+
Transport: transport,
16+
Timeout: 30 * time.Second,
17+
}
18+
1419
var transport = &http.Transport{
1520
Dial: (&net.Dialer{
1621
Timeout: 5 * time.Second,
@@ -94,15 +99,10 @@ func netReqTyped(req *http.Request, isJson bool) ([]byte, int, error) {
9499
}
95100

96101
func netReq(req *http.Request) ([]byte, int, error) {
97-
// Send the request via a client
98-
client := &http.Client{
99-
Transport: transport,
100-
Timeout: 30 * time.Second,
101-
}
102102
var resp *http.Response
103103
var err error
104104
for i := 0; i < 3; i++ {
105-
resp, err = client.Do(req)
105+
resp, err = HttpClient.Do(req)
106106
if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
107107
// it's a transient network error so we sleep for a bit and try
108108
// again in case it's a short-lived issue

rpc_test.go

+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
package fargo
2+
3+
import (
4+
"fmt"
5+
"net/http"
6+
"net/http/httptest"
7+
"testing"
8+
9+
. "github.com/smartystreets/goconvey/convey"
10+
)
11+
12+
type roundtripper struct {
13+
TripCount int
14+
}
15+
16+
func (r *roundtripper) RoundTrip(req *http.Request) (*http.Response, error) {
17+
r.TripCount++
18+
return http.DefaultTransport.RoundTrip(req)
19+
}
20+
21+
func TestHttpClient(t *testing.T) {
22+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
23+
fmt.Fprint(w, "Hello World")
24+
}))
25+
defer server.Close()
26+
27+
Convey("Given fargo.HttpClient is set to a custom client", t, func() {
28+
rt := new(roundtripper)
29+
HttpClient = &http.Client{
30+
Transport: rt,
31+
}
32+
33+
Convey("netReq uses that client to handle requests", func() {
34+
req, err := http.NewRequest("GET", server.URL, nil)
35+
So(err, ShouldBeNil)
36+
37+
respBody, respCode, err := netReq(req)
38+
So(err, ShouldBeNil)
39+
So(respCode, ShouldEqual, 200)
40+
So(string(respBody), ShouldEqual, "Hello World")
41+
42+
So(rt.TripCount, ShouldEqual, 1)
43+
})
44+
})
45+
}

0 commit comments

Comments
 (0)