Skip to content

Commit 6386f12

Browse files
authored
Add SafeDial to block connections (#270)
* Add SafeDial to block connections Closes #255 * Add SafeTransport for convenience
1 parent af3da3f commit 6386f12

File tree

2 files changed

+94
-23
lines changed

2 files changed

+94
-23
lines changed

http/http.go

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@ package http
22

33
import (
44
"context"
5+
"errors"
6+
"fmt"
57
"net"
68
"net/http"
79
"net/http/httptrace"
@@ -85,6 +87,8 @@ func (no noLocalTransport) RoundTrip(req *http.Request) (*http.Response, error)
8587
return no.inner.RoundTrip(req)
8688
}
8789

90+
// SafeRoundtripper blocks requests to private ip ranges
91+
// Deprecated: use SafeTransport instead
8892
func SafeRoundtripper(trans http.RoundTripper, log logrus.FieldLogger, allowedBlocks ...*net.IPNet) http.RoundTripper {
8993
if trans == nil {
9094
trans = http.DefaultTransport
@@ -99,8 +103,65 @@ func SafeRoundtripper(trans http.RoundTripper, log logrus.FieldLogger, allowedBl
99103
return ret
100104
}
101105

106+
// SafeHTTPClient blocks requests to private ip ranges
107+
// Deprecated: use SafeTransport instead
102108
func SafeHTTPClient(client *http.Client, log logrus.FieldLogger, allowedBlocks ...*net.IPNet) *http.Client {
103109
client.Transport = SafeRoundtripper(client.Transport, log, allowedBlocks...)
104110

105111
return client
106112
}
113+
114+
// SafeTransport blocks requests to private ip ranges
115+
func SafeTransport(allowedBlocks ...*net.IPNet) *http.Transport {
116+
return &http.Transport{
117+
DialContext: SafeDial(&net.Dialer{}, allowedBlocks...),
118+
}
119+
}
120+
121+
type DialFunc func(ctx context.Context, network, address string) (net.Conn, error)
122+
123+
// SafeDial wraps a *net.Dialer and restricts connections to private ip ranges.
124+
func SafeDial(dialer *net.Dialer, allowedBlocks ...*net.IPNet) DialFunc {
125+
d := &safeDialer{dialer: dialer, allowedBlocks: allowedBlocks}
126+
return d.dialContext
127+
}
128+
129+
type safeDialer struct {
130+
allowedBlocks []*net.IPNet
131+
dialer *net.Dialer
132+
}
133+
134+
func (d *safeDialer) dialContext(ctx context.Context, network, address string) (net.Conn, error) {
135+
conn, err := d.dialer.DialContext(ctx, network, address)
136+
if err != nil {
137+
return nil, err
138+
}
139+
addr := conn.RemoteAddr().String()
140+
141+
host, _, err := net.SplitHostPort(addr)
142+
if err != nil {
143+
_ = conn.Close()
144+
return nil, fmt.Errorf("safe dialer: invalid address: %w", err)
145+
}
146+
147+
ip := net.ParseIP(host)
148+
if ip == nil {
149+
_ = conn.Close()
150+
return nil, fmt.Errorf("safe dialer: invalid ip: %v", host)
151+
}
152+
153+
for _, block := range d.allowedBlocks {
154+
if block.Contains(ip) {
155+
return conn, nil
156+
}
157+
}
158+
159+
for _, block := range privateIPBlocks {
160+
if block.Contains(ip) {
161+
_ = conn.Close()
162+
return nil, errors.New("safe dialer: private ip not allowed")
163+
}
164+
}
165+
166+
return conn, nil
167+
}

http/http_test.go

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,32 @@ func TestIsPrivateIP(t *testing.T) {
3434
}
3535
}
3636

37-
func TestSafeHTTPClient(t *testing.T) {
37+
func TestBlockList(t *testing.T) {
38+
t.Run("safe http client", func(t *testing.T) {
39+
client := SafeHTTPClient(&http.Client{}, logrus.New())
40+
testBlockList(t, client)
41+
})
42+
t.Run("safe dial", func(t *testing.T) {
43+
client := &http.Client{Transport: SafeTransport()}
44+
testBlockList(t, client)
45+
})
46+
}
47+
48+
func TestAllowList(t *testing.T) {
49+
_, local, err := net.ParseCIDR("127.0.0.1/8")
50+
require.NoError(t, err)
51+
52+
t.Run("safe http client", func(t *testing.T) {
53+
client := SafeHTTPClient(&http.Client{}, logrus.New(), local)
54+
testAllowList(t, client)
55+
})
56+
t.Run("safe dial", func(t *testing.T) {
57+
client := &http.Client{Transport: SafeTransport(local)}
58+
testAllowList(t, client)
59+
})
60+
}
61+
62+
func testBlockList(t *testing.T, client *http.Client) {
3863
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
3964
w.Write([]byte("Done"))
4065
}))
@@ -44,8 +69,6 @@ func TestSafeHTTPClient(t *testing.T) {
4469
t.Fatal(err)
4570
}
4671

47-
client := SafeHTTPClient(&http.Client{}, logrus.New())
48-
4972
// It allows accessing non-local addresses
5073
_, err = client.Get("https://google.com")
5174
require.Nil(t, err)
@@ -64,26 +87,13 @@ func TestSafeHTTPClient(t *testing.T) {
6487
assert.Nil(t, res)
6588
assert.NotNil(t, err)
6689
}
67-
68-
// It succeeds when the local IP range used by the testserver is removed from
69-
// the blacklist.
70-
ipNet := popMatchingBlock(net.ParseIP(tsURL.Hostname()))
71-
_, err = client.Get(ts.URL)
72-
assert.Nil(t, err)
73-
privateIPBlocks = append(privateIPBlocks, ipNet)
74-
75-
// It allows whitelisting for local development.
76-
client = SafeHTTPClient(&http.Client{}, logrus.New(), ipNet)
77-
_, err = client.Get(ts.URL)
78-
assert.Nil(t, err)
7990
}
8091

81-
func popMatchingBlock(ip net.IP) *net.IPNet {
82-
for i, ipNet := range privateIPBlocks {
83-
if ipNet.Contains(ip) {
84-
privateIPBlocks = append(privateIPBlocks[:i], privateIPBlocks[i+1:]...)
85-
return ipNet
86-
}
87-
}
88-
return nil
92+
func testAllowList(t *testing.T, client *http.Client) {
93+
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94+
w.Write([]byte("Done"))
95+
}))
96+
defer ts.Close()
97+
_, err := client.Get(ts.URL)
98+
assert.NoError(t, err)
8999
}

0 commit comments

Comments
 (0)