Skip to content

Commit 2690483

Browse files
committed
feat: improve behavior of http redirects
This commit modifies the Go core so that it will include "safe" headers when performing a cross-site redirect where both the original and redirected hosts are within IBM's "cloud.ibm.com" domain. Signed-off-by: Phil Adams <[email protected]>
1 parent 0b8c2b2 commit 2690483

5 files changed

+230
-11
lines changed

.travis.yml

+7
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,13 @@ go:
1010
notifications:
1111
email: false
1212

13+
addons:
14+
hosts:
15+
- region1.cloud.ibm.com
16+
- region1.notcloud.ibm.com
17+
- region2.cloud.ibm.com
18+
- region2.notcloud.ibm.com
19+
1320
env:
1421
global:
1522
- GO111MODULE=on

Makefile

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@ test:
1616

1717
lint:
1818
${LINT} run --build-tags=all
19-
DIFF=$$(${FORMATTER} -d core); if [[ -n "$$DIFF" ]]; then printf "\n$$DIFF" && exit 1; fi
19+
DIFF=$$(${FORMATTER} -d core); if [ -n "$$DIFF" ]; then printf "\n$$DIFF" && exit 1; fi
2020

2121
scan-gosec:
2222
${GOSEC} ./...

core/base_service.go

+75-6
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ import (
3838
const (
3939
headerNameUserAgent = "User-Agent"
4040
sdkName = "ibm-go-sdk-core"
41+
maxRedirects = 10
4142
)
4243

4344
// ServiceOptions is a struct of configuration values for a service.
@@ -117,7 +118,7 @@ func (service *BaseService) Clone() *BaseService {
117118
// First, copy the service options struct.
118119
serviceOptions := *service.Options
119120

120-
// Next, make a copy the service struct, then use the copy of the service options.
121+
// Next, make a copy of the service struct, then use the copy of the service options.
121122
// Note, we'll re-use the "Client" instance from the original BaseService instance.
122123
clone := *service
123124
clone.Options = &serviceOptions
@@ -234,7 +235,7 @@ func (service *BaseService) SetDefaultHeaders(headers http.Header) {
234235
// the retryable client; otherwise "client" will be stored
235236
// directly on "service".
236237
func (service *BaseService) SetHTTPClient(client *http.Client) {
237-
setMinimumTLSVersion(client)
238+
setupHTTPClient(client)
238239

239240
if isRetryableClient(service.Client) {
240241
// If "service" is currently holding a retryable client,
@@ -298,15 +299,83 @@ func (service *BaseService) IsSSLDisabled() bool {
298299
return false
299300
}
300301

301-
// setMinimumTLSVersion sets the minimum TLS version required by the client to TLS v1.2
302-
func setMinimumTLSVersion(client *http.Client) {
302+
// setupHTTPClient will configure "client" for use with the BaseService object.
303+
func setupHTTPClient(client *http.Client) {
304+
// Set the minimum TLS version to be 1.2
303305
if tr, ok := client.Transport.(*http.Transport); tr != nil && ok {
304306
if tr.TLSClientConfig == nil {
305307
tr.TLSClientConfig = &tls.Config{} // #nosec G402
306308
}
307309

308310
tr.TLSClientConfig.MinVersion = tls.VersionTLS12
309311
}
312+
313+
// Set our "CheckRedirect" function to allow safe headers to be included
314+
// in redirected requests under certain conditions.
315+
if client.CheckRedirect == nil {
316+
client.CheckRedirect = checkRedirect
317+
}
318+
}
319+
320+
// checkRedirect is used as an override for the default "CheckRedirect" function supplied
321+
// by the net/http package and implements some additional logic required by IBM SDKs.
322+
func checkRedirect(req *http.Request, via []*http.Request) error {
323+
324+
// The net/http module is implemented such that it will only include "safe" headers
325+
// ("Authorization", "WWW-Authenticate", "Cookie", "Cookie2") when redirecting a request
326+
// if the redirected host is the same host or a sub-domain of the original request's host.
327+
// Example: foo.com redirected to foo.com or bar.foo.com would work, but bar.com would not.
328+
// This "CheckRedirect" implementation will propagate "safe" headers in a redirected request
329+
// only in situations where the hosts associated with the original and redirected request URLs
330+
// are both located within the ".cloud.ibm.com" domain.
331+
332+
// First, perform the check that is done by the default CheckRedirect function
333+
// to ensure we don't exhaust our max redirect limit.
334+
if len(via) >= maxRedirects {
335+
GetLogger().Debug("Exceeded max redirects: %d", maxRedirects)
336+
return fmt.Errorf("stopped after %d redirects", maxRedirects)
337+
}
338+
339+
if len(via) > 0 {
340+
GetLogger().Debug("Detected %d prior request(s)", len(via))
341+
originalReq := via[0]
342+
redirectedReq := req
343+
GetLogger().Debug("Redirecting request from %s to %s", originalReq.URL.String(), redirectedReq.URL.String())
344+
redirectedHeader := req.Header
345+
originalHeader := via[0].Header
346+
347+
originalHost := originalReq.URL.Hostname()
348+
redirectedHost := redirectedReq.URL.Hostname()
349+
350+
if shouldCopySafeHeadersOnRedirect(originalHost, redirectedHost) {
351+
352+
// We're only concerned with "safe" headers since these are the ones that are not
353+
// propagated automatically by net/http for a "cross-site" redirect.
354+
for _, headerKey := range []string{"Authorization", "WWW-Authenticate", "Cookie", "Cookie2"} {
355+
// If the original request contains a value for "headerKey"
356+
// *and* this header is not already present in the redirected request,
357+
// then copy the value from the original request to the redirected request.
358+
if v, inOriginalRequest := originalHeader[headerKey]; inOriginalRequest {
359+
if _, inRedirectedRequest := redirectedHeader[headerKey]; !inRedirectedRequest {
360+
redirectedHeader[headerKey] = v
361+
GetLogger().Debug("Propagating header '%s' in redirected request", headerKey)
362+
}
363+
}
364+
}
365+
} else {
366+
GetLogger().Debug("Redirected request is not within the trusted domain.")
367+
}
368+
} else {
369+
GetLogger().Debug("Detected no prior requests!")
370+
}
371+
return nil
372+
}
373+
374+
// shouldCopySafeHeadersOnRedirect returns true iff safe headers should be copied
375+
// to a redirected request.
376+
func shouldCopySafeHeadersOnRedirect(fromHost, toHost string) bool {
377+
GetLogger().Debug("hosts: %s %s", fromHost, toHost)
378+
return strings.HasSuffix(fromHost, ".cloud.ibm.com") && strings.HasSuffix(toHost, ".cloud.ibm.com")
310379
}
311380

312381
// SetEnableGzipCompression sets the service's EnableGzipCompression field
@@ -693,7 +762,7 @@ func (service *BaseService) DisableRetries() {
693762
// DefaultHTTPClient returns a non-retryable http client with default configuration.
694763
func DefaultHTTPClient() *http.Client {
695764
client := cleanhttp.DefaultPooledClient()
696-
setMinimumTLSVersion(client)
765+
setupHTTPClient(client)
697766
return client
698767
}
699768

@@ -731,7 +800,7 @@ func NewRetryableClientWithHTTPClient(httpClient *http.Client) *retryablehttp.Cl
731800
// as our embedded client used to invoke individual requests.
732801
client.HTTPClient = httpClient
733802
} else {
734-
// Otherwise, we'll use construct a default HTTP client and use that
803+
// Otherwise, we'll construct a default HTTP client and use that
735804
client.HTTPClient = DefaultHTTPClient()
736805
}
737806

core/base_service_redirect_test.go

+147
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
//go:build all || slow || auth
2+
// +build all slow auth
3+
4+
package core
5+
6+
// (C) Copyright IBM Corp. 2023.
7+
//
8+
// Licensed under the Apache License, Version 2.0 (the "License");
9+
// you may not use this file except in compliance with the License.
10+
// You may obtain a copy of the License at
11+
//
12+
// http://www.apache.org/licenses/LICENSE-2.0
13+
//
14+
// Unless required by applicable law or agreed to in writing, software
15+
// distributed under the License is distributed on an "AS IS" BASIS,
16+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
17+
// See the License for the specific language governing permissions and
18+
// limitations under the License.
19+
20+
import (
21+
"fmt"
22+
"net/http"
23+
"net/http/httptest"
24+
"strings"
25+
"testing"
26+
27+
"github.com/stretchr/testify/assert"
28+
)
29+
30+
// Note: this unit test depends on some bogus hostnames being defined in /etc/hosts.
31+
// Append this to your /etc/hosts file:
32+
// # for testing
33+
// 127.0.0.1 region1.cloud.ibm.com region2.cloud.ibm.com region1.notcloud.ibm.com region2.notcloud.ibm.com
34+
35+
var (
36+
operationPath string = "/api/redirector"
37+
38+
// To enable debug mode while running these tests, set this to LevelDebug.
39+
redirectTestLogLevel LogLevel = LevelError
40+
)
41+
42+
// Start a mock server that will redirect requests to the second mock server
43+
// located at "redirectServerURL"
44+
func startMockServer1(t *testing.T, redirectServerURL string) *httptest.Server {
45+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
46+
t.Logf(`server1 received request: %s %s`, req.Method, req.URL.String())
47+
48+
// Make sure the Authorization header was sent.
49+
assert.NotEmpty(t, req.Header.Get("Authorization"))
50+
51+
path := req.URL.Path
52+
location := redirectServerURL + path
53+
54+
// Create the response (a 302 redirect).
55+
w.Header().Add("Location", location)
56+
w.WriteHeader(http.StatusFound)
57+
t.Logf(`Sent redirect request to: %s`, location)
58+
}))
59+
return server
60+
}
61+
62+
// Start a second mock server to which redirected requests will be sent.
63+
func startMockServer2(t *testing.T) *httptest.Server {
64+
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
65+
t.Logf(`server2 received request: %s %s`, req.Method, req.URL.String())
66+
67+
// Create the response.
68+
if req.Header.Get("Authorization") != "" {
69+
w.Header().Set("Content-Type", "application/json")
70+
w.WriteHeader(http.StatusOK)
71+
fmt.Fprintf(w, `{"name":"Jason Bourne"}`)
72+
} else {
73+
w.WriteHeader(http.StatusUnauthorized)
74+
}
75+
}))
76+
return server
77+
}
78+
79+
func startServers(t *testing.T, host1 string, host2 string) (server1 *httptest.Server, server1URL string,
80+
server2 *httptest.Server, server2URL string) {
81+
server2 = startMockServer2(t)
82+
server2URL = strings.ReplaceAll(server2.URL, "127.0.0.1", host2)
83+
t.Logf(`Server 2 listening on endpoint: %s (%s)`, server2URL, server2.URL)
84+
85+
server1 = startMockServer1(t, server2URL)
86+
server1URL = strings.ReplaceAll(server1.URL, "127.0.0.1", host1)
87+
t.Logf(`Server 1 listening on endpoint: %s (%s)`, server1URL, server1.URL)
88+
89+
return
90+
}
91+
92+
func testRedirection(t *testing.T, host1 string, host2 string, expectedStatusCode int) {
93+
GetLogger().SetLogLevel(redirectTestLogLevel)
94+
95+
// Both servers within trusted domain.
96+
server1, server1URL, server2, _ := startServers(t, host1, host2)
97+
defer server1.Close()
98+
defer server2.Close()
99+
100+
builder := NewRequestBuilder("GET")
101+
_, err := builder.ResolveRequestURL(server1URL, operationPath, nil)
102+
assert.Nil(t, err)
103+
req, _ := builder.Build()
104+
105+
authenticator, err := NewBasicAuthenticator("xxx", "yyy")
106+
assert.Nil(t, err)
107+
assert.NotNil(t, authenticator)
108+
109+
options := &ServiceOptions{
110+
URL: server1.URL,
111+
Authenticator: authenticator,
112+
}
113+
service, err := NewBaseService(options)
114+
assert.Nil(t, err)
115+
116+
var foo *Foo
117+
detailedResponse, err := service.Request(req, &foo)
118+
assert.NotNil(t, detailedResponse)
119+
assert.Equal(t, expectedStatusCode, detailedResponse.StatusCode)
120+
if expectedStatusCode >= 200 && expectedStatusCode <= 299 {
121+
assert.Nil(t, err)
122+
123+
result, ok := detailedResponse.Result.(*Foo)
124+
assert.Equal(t, true, ok)
125+
assert.NotNil(t, result)
126+
assert.NotNil(t, foo)
127+
assert.Equal(t, "Jason Bourne", *result.Name)
128+
} else {
129+
assert.NotNil(t, err)
130+
}
131+
}
132+
133+
func TestRedirectAuthSuccess(t *testing.T) {
134+
testRedirection(t, "region1.cloud.ibm.com", "region2.cloud.ibm.com", http.StatusOK)
135+
}
136+
137+
func TestRedirectAuthFail1(t *testing.T) {
138+
testRedirection(t, "region1.notcloud.ibm.com", "region2.cloud.ibm.com", http.StatusUnauthorized)
139+
}
140+
141+
func TestRedirectAuthFail2(t *testing.T) {
142+
testRedirection(t, "region1.cloud.ibm.com", "region2.notcloud.ibm.com", http.StatusUnauthorized)
143+
}
144+
145+
func TestRedirectAuthFail3(t *testing.T) {
146+
testRedirection(t, "region1.notcloud.ibm.com", "region2.notcloud.ibm.com", http.StatusUnauthorized)
147+
}

core/base_service_test.go

-4
Original file line numberDiff line numberDiff line change
@@ -1769,23 +1769,20 @@ func TestClientWithRetries(t *testing.T) {
17691769
service.SetHTTPClient(client)
17701770
actualClient := service.GetHTTPClient()
17711771
assert.Equal(t, client, actualClient)
1772-
assert.Equal(t, *client, *actualClient)
17731772
assert.Equal(t, client, service.Client)
17741773

17751774
// Next, enable retries and make sure the client survived.
17761775
service.EnableRetries(4, 90*time.Second)
17771776
assert.True(t, isRetryableClient(service.Client))
17781777
actualClient = service.GetHTTPClient()
17791778
assert.Equal(t, client, actualClient)
1780-
assert.Equal(t, *client, *actualClient)
17811779

17821780
// Finally, disable retries and make sure
17831781
// we're left with the same client instance.
17841782
service.DisableRetries()
17851783
assert.False(t, isRetryableClient(service.Client))
17861784
actualClient = service.GetHTTPClient()
17871785
assert.Equal(t, client, actualClient)
1788-
assert.Equal(t, *client, *actualClient)
17891786
assert.Equal(t, client, service.Client)
17901787

17911788
// Create a new service and perform the steps in a different order.
@@ -1804,7 +1801,6 @@ func TestClientWithRetries(t *testing.T) {
18041801
assert.True(t, isRetryableClient(service.Client))
18051802
actualClient = service.GetHTTPClient()
18061803
assert.Equal(t, client, actualClient)
1807-
assert.Equal(t, *client, *actualClient)
18081804
}
18091805

18101806
func TestSetEnableGzipCompression(t *testing.T) {

0 commit comments

Comments
 (0)