@@ -34,7 +34,32 @@ func TestIsPrivateIP(t *testing.T) {
34
34
}
35
35
}
36
36
37
- func TestSafeHTTPClient (t * testing.T ) {
37
+ func TestBlockList (t * testing.T ) {
38
+ t .Run ("safe http client" , func (t * testing.T ) {
39
+ client := SafeHTTPClient (& http.Client {}, logrus .New ())
40
+ testBlockList (t , client )
41
+ })
42
+ t .Run ("safe dial" , func (t * testing.T ) {
43
+ client := & http.Client {Transport : SafeTransport ()}
44
+ testBlockList (t , client )
45
+ })
46
+ }
47
+
48
+ func TestAllowList (t * testing.T ) {
49
+ _ , local , err := net .ParseCIDR ("127.0.0.1/8" )
50
+ require .NoError (t , err )
51
+
52
+ t .Run ("safe http client" , func (t * testing.T ) {
53
+ client := SafeHTTPClient (& http.Client {}, logrus .New (), local )
54
+ testAllowList (t , client )
55
+ })
56
+ t .Run ("safe dial" , func (t * testing.T ) {
57
+ client := & http.Client {Transport : SafeTransport (local )}
58
+ testAllowList (t , client )
59
+ })
60
+ }
61
+
62
+ func testBlockList (t * testing.T , client * http.Client ) {
38
63
ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
39
64
w .Write ([]byte ("Done" ))
40
65
}))
@@ -44,8 +69,6 @@ func TestSafeHTTPClient(t *testing.T) {
44
69
t .Fatal (err )
45
70
}
46
71
47
- client := SafeHTTPClient (& http.Client {}, logrus .New ())
48
-
49
72
// It allows accessing non-local addresses
50
73
_ , err = client .Get ("https://google.com" )
51
74
require .Nil (t , err )
@@ -64,26 +87,13 @@ func TestSafeHTTPClient(t *testing.T) {
64
87
assert .Nil (t , res )
65
88
assert .NotNil (t , err )
66
89
}
67
-
68
- // It succeeds when the local IP range used by the testserver is removed from
69
- // the blacklist.
70
- ipNet := popMatchingBlock (net .ParseIP (tsURL .Hostname ()))
71
- _ , err = client .Get (ts .URL )
72
- assert .Nil (t , err )
73
- privateIPBlocks = append (privateIPBlocks , ipNet )
74
-
75
- // It allows whitelisting for local development.
76
- client = SafeHTTPClient (& http.Client {}, logrus .New (), ipNet )
77
- _ , err = client .Get (ts .URL )
78
- assert .Nil (t , err )
79
90
}
80
91
81
- func popMatchingBlock (ip net.IP ) * net.IPNet {
82
- for i , ipNet := range privateIPBlocks {
83
- if ipNet .Contains (ip ) {
84
- privateIPBlocks = append (privateIPBlocks [:i ], privateIPBlocks [i + 1 :]... )
85
- return ipNet
86
- }
87
- }
88
- return nil
92
+ func testAllowList (t * testing.T , client * http.Client ) {
93
+ ts := httptest .NewServer (http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
94
+ w .Write ([]byte ("Done" ))
95
+ }))
96
+ defer ts .Close ()
97
+ _ , err := client .Get (ts .URL )
98
+ assert .NoError (t , err )
89
99
}
0 commit comments