Skip to content

Commit 9dd6af1

Browse files
authored
Merge commit from fork
* csrf: use context to determine TLS state r.URL.Scheme is never populated for "server" requests, and so the referer check never runs. Instead we now ask the caller application to signal this explicitly via request conext, and then enforce the check accordingly. Separately, browsers do not always send the full URL as a Referer, especially in the same-origin context meaning we cannot compare its host against our trusted origin list. If the referer does not contain a host we populate r.URL.Host with r.Host which is expected to be sent by all clients as the first header of their request. Add tests against the Origin header before attempting to enforce same-origin restrictions using the Referer header. Matching the Django CSRF behavior: if the Origin is present in either the cleartext or TLS case we will evaluate it. IFF we are in TLS and we have no Origin we will evaluate the Referer against the allowlist. In doing so we take care to permit "path only" Referers that are sent in same-origin context. * add csrf.TLSRequest helper API to set request TLS context Add a csrf.TLSRequest public API method that sets the appropriate TLS context key and signals to the midldeware the need to run the additiontal Referer checks. * Enable Referer-based origin checks by default Reverse the default position and presume that that the server is using TLS either directly or via an upstream proxy and require the user to explicitly disable referer-based checks. This safe default means that users that upgrade the library without making any other code changes will benefit from the Referer checks that they thought were active already. Without this change we risk that some codebases will mistakenly remain vulnerable even while using a patched version of the library.
1 parent a009743 commit 9dd6af1

File tree

3 files changed

+307
-131
lines changed

3 files changed

+307
-131
lines changed

csrf.go

+75-14
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
package csrf
22

33
import (
4+
"context"
45
"errors"
56
"fmt"
67
"net/http"
78
"net/url"
9+
"slices"
810

911
"github.com/gorilla/securecookie"
1012
)
@@ -22,6 +24,14 @@ const (
2224
errorPrefix string = "gorilla/csrf: "
2325
)
2426

27+
type contextKey string
28+
29+
// PlaintextHTTPContextKey is the context key used to store whether the request
30+
// is being served via plaintext HTTP. This is used to signal to the middleware
31+
// that strict Referer checking should not be enforced as is done for HTTPS by
32+
// default.
33+
const PlaintextHTTPContextKey contextKey = "plaintext"
34+
2535
var (
2636
// The name value used in form fields.
2737
fieldName = tokenKey
@@ -41,6 +51,9 @@ var (
4151
// ErrNoReferer is returned when a HTTPS request provides an empty Referer
4252
// header.
4353
ErrNoReferer = errors.New("referer not supplied")
54+
// ErrBadOrigin is returned when the Origin header is present and is not a
55+
// trusted origin.
56+
ErrBadOrigin = errors.New("origin invalid")
4457
// ErrBadReferer is returned when the scheme & host in the URL do not match
4558
// the supplied Referer header.
4659
ErrBadReferer = errors.New("referer invalid")
@@ -242,10 +255,50 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
242255
// HTTP methods not defined as idempotent ("safe") under RFC7231 require
243256
// inspection.
244257
if !contains(safeMethods, r.Method) {
245-
// Enforce an origin check for HTTPS connections. As per the Django CSRF
246-
// implementation (https://goo.gl/vKA7GE) the Referer header is almost
247-
// always present for same-domain HTTP requests.
248-
if r.URL.Scheme == "https" {
258+
var isPlaintext bool
259+
val := r.Context().Value(PlaintextHTTPContextKey)
260+
if val != nil {
261+
isPlaintext, _ = val.(bool)
262+
}
263+
264+
// take a copy of the request URL to avoid mutating the original
265+
// attached to the request.
266+
// set the scheme & host based on the request context as these are not
267+
// populated by default for server requests
268+
// ref: https://pkg.go.dev/net/http#Request
269+
requestURL := *r.URL // shallow clone
270+
271+
requestURL.Scheme = "https"
272+
if isPlaintext {
273+
requestURL.Scheme = "http"
274+
}
275+
if requestURL.Host == "" {
276+
requestURL.Host = r.Host
277+
}
278+
279+
// if we have an Origin header, check it against our allowlist
280+
origin := r.Header.Get("Origin")
281+
if origin != "" {
282+
parsedOrigin, err := url.Parse(origin)
283+
if err != nil {
284+
r = envError(r, ErrBadOrigin)
285+
cs.opts.ErrorHandler.ServeHTTP(w, r)
286+
return
287+
}
288+
if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) {
289+
r = envError(r, ErrBadOrigin)
290+
cs.opts.ErrorHandler.ServeHTTP(w, r)
291+
return
292+
}
293+
}
294+
295+
// If we are serving via TLS and have no Origin header, prevent against
296+
// CSRF via HTTP machine in the middle attacks by enforcing strict
297+
// Referer origin checks. Consider an attacker who performs a
298+
// successful HTTP Machine-in-the-Middle attack and uses this to inject
299+
// a form and cause submission to our origin. We strictly disallow
300+
// cleartext HTTP origins and evaluate the domain against an allowlist.
301+
if origin == "" && !isPlaintext {
249302
// Fetch the Referer value. Call the error handler if it's empty or
250303
// otherwise fails to parse.
251304
referer, err := url.Parse(r.Referer())
@@ -255,18 +308,17 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
255308
return
256309
}
257310

258-
valid := sameOrigin(r.URL, referer)
259-
260-
if !valid {
261-
for _, trustedOrigin := range cs.opts.TrustedOrigins {
262-
if referer.Host == trustedOrigin {
263-
valid = true
264-
break
265-
}
266-
}
311+
// disallow cleartext HTTP referers when serving via TLS
312+
if referer.Scheme == "http" {
313+
r = envError(r, ErrBadReferer)
314+
cs.opts.ErrorHandler.ServeHTTP(w, r)
315+
return
267316
}
268317

269-
if !valid {
318+
// If the request is being served via TLS and the Referer is not the
319+
// same origin, check the domain against our allowlist. We only
320+
// check when we have host information from the referer.
321+
if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) {
270322
r = envError(r, ErrBadReferer)
271323
cs.opts.ErrorHandler.ServeHTTP(w, r)
272324
return
@@ -308,6 +360,15 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
308360
contextClear(r)
309361
}
310362

363+
// PlaintextHTTPRequest accepts as input a http.Request and returns a new
364+
// http.Request with the PlaintextHTTPContextKey set to true. This is used to
365+
// signal to the CSRF middleware that the request is being served over plaintext
366+
// HTTP and that Referer-based origin allow-listing checks should be skipped.
367+
func PlaintextHTTPRequest(r *http.Request) *http.Request {
368+
ctx := context.WithValue(r.Context(), PlaintextHTTPContextKey, true)
369+
return r.WithContext(ctx)
370+
}
371+
311372
// unauthorizedhandler sets a HTTP 403 Forbidden status and writes the
312373
// CSRF failure reason to the response.
313374
func unauthorizedHandler(w http.ResponseWriter, r *http.Request) {

0 commit comments

Comments
 (0)