Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit f968bcf

Browse files
authoredFeb 25, 2025
Merge branch 'main' into add/csrf/origin
Signed-off-by: SangHyuk <[email protected]>
2 parents aaa75a9 + 9dd6af1 commit f968bcf

File tree

3 files changed

+307
-131
lines changed

3 files changed

+307
-131
lines changed
 

‎csrf.go

+75-14
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,12 @@ package csrf
22

33
import (
44
"encoding/base64"
5+
"context"
56
"errors"
67
"fmt"
78
"net/http"
89
"net/url"
10+
"slices"
911

1012
"github.com/gorilla/securecookie"
1113
)
@@ -23,6 +25,14 @@ const (
2325
errorPrefix string = "gorilla/csrf: "
2426
)
2527

28+
type contextKey string
29+
30+
// PlaintextHTTPContextKey is the context key used to store whether the request
31+
// is being served via plaintext HTTP. This is used to signal to the middleware
32+
// that strict Referer checking should not be enforced as is done for HTTPS by
33+
// default.
34+
const PlaintextHTTPContextKey contextKey = "plaintext"
35+
2636
var (
2737
// The name value used in form fields.
2838
fieldName = tokenKey
@@ -42,6 +52,9 @@ var (
4252
// ErrNoReferer is returned when a HTTPS request provides an empty Referer
4353
// header.
4454
ErrNoReferer = errors.New("referer not supplied")
55+
// ErrBadOrigin is returned when the Origin header is present and is not a
56+
// trusted origin.
57+
ErrBadOrigin = errors.New("origin invalid")
4558
// ErrBadReferer is returned when the scheme & host in the URL do not match
4659
// the supplied Referer header.
4760
ErrBadReferer = errors.New("referer invalid")
@@ -248,10 +261,50 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
248261
// HTTP methods not defined as idempotent ("safe") under RFC7231 require
249262
// inspection.
250263
if !contains(safeMethods, r.Method) {
251-
// Enforce an origin check for HTTPS connections. As per the Django CSRF
252-
// implementation (https://goo.gl/vKA7GE) the Referer header is almost
253-
// always present for same-domain HTTP requests.
254-
if r.URL.Scheme == "https" {
264+
var isPlaintext bool
265+
val := r.Context().Value(PlaintextHTTPContextKey)
266+
if val != nil {
267+
isPlaintext, _ = val.(bool)
268+
}
269+
270+
// take a copy of the request URL to avoid mutating the original
271+
// attached to the request.
272+
// set the scheme & host based on the request context as these are not
273+
// populated by default for server requests
274+
// ref: https://pkg.go.dev/net/http#Request
275+
requestURL := *r.URL // shallow clone
276+
277+
requestURL.Scheme = "https"
278+
if isPlaintext {
279+
requestURL.Scheme = "http"
280+
}
281+
if requestURL.Host == "" {
282+
requestURL.Host = r.Host
283+
}
284+
285+
// if we have an Origin header, check it against our allowlist
286+
origin := r.Header.Get("Origin")
287+
if origin != "" {
288+
parsedOrigin, err := url.Parse(origin)
289+
if err != nil {
290+
r = envError(r, ErrBadOrigin)
291+
cs.opts.ErrorHandler.ServeHTTP(w, r)
292+
return
293+
}
294+
if !sameOrigin(&requestURL, parsedOrigin) && !slices.Contains(cs.opts.TrustedOrigins, parsedOrigin.Host) {
295+
r = envError(r, ErrBadOrigin)
296+
cs.opts.ErrorHandler.ServeHTTP(w, r)
297+
return
298+
}
299+
}
300+
301+
// If we are serving via TLS and have no Origin header, prevent against
302+
// CSRF via HTTP machine in the middle attacks by enforcing strict
303+
// Referer origin checks. Consider an attacker who performs a
304+
// successful HTTP Machine-in-the-Middle attack and uses this to inject
305+
// a form and cause submission to our origin. We strictly disallow
306+
// cleartext HTTP origins and evaluate the domain against an allowlist.
307+
if origin == "" && !isPlaintext {
255308
// Fetch the Referer value. Call the error handler if it's empty or
256309
// otherwise fails to parse.
257310
referer, err := url.Parse(r.Referer())
@@ -261,18 +314,17 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
261314
return
262315
}
263316

264-
valid := sameOrigin(r.URL, referer)
265-
266-
if !valid {
267-
for _, trustedOrigin := range cs.opts.TrustedOrigins {
268-
if referer.Host == trustedOrigin {
269-
valid = true
270-
break
271-
}
272-
}
317+
// disallow cleartext HTTP referers when serving via TLS
318+
if referer.Scheme == "http" {
319+
r = envError(r, ErrBadReferer)
320+
cs.opts.ErrorHandler.ServeHTTP(w, r)
321+
return
273322
}
274323

275-
if !valid {
324+
// If the request is being served via TLS and the Referer is not the
325+
// same origin, check the domain against our allowlist. We only
326+
// check when we have host information from the referer.
327+
if referer.Host != "" && referer.Host != r.Host && !slices.Contains(cs.opts.TrustedOrigins, referer.Host) {
276328
r = envError(r, ErrBadReferer)
277329
cs.opts.ErrorHandler.ServeHTTP(w, r)
278330
return
@@ -314,6 +366,15 @@ func (cs *csrf) ServeHTTP(w http.ResponseWriter, r *http.Request) {
314366
contextClear(r)
315367
}
316368

369+
// PlaintextHTTPRequest accepts as input a http.Request and returns a new
370+
// http.Request with the PlaintextHTTPContextKey set to true. This is used to
371+
// signal to the CSRF middleware that the request is being served over plaintext
372+
// HTTP and that Referer-based origin allow-listing checks should be skipped.
373+
func PlaintextHTTPRequest(r *http.Request) *http.Request {
374+
ctx := context.WithValue(r.Context(), PlaintextHTTPContextKey, true)
375+
return r.WithContext(ctx)
376+
}
377+
317378
// unauthorizedhandler sets a HTTP 403 Forbidden status and writes the
318379
// CSRF failure reason to the response.
319380
func unauthorizedHandler(w http.ResponseWriter, r *http.Request) {

‎csrf_test.go

+224-101
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package csrf
22

33
import (
4+
"fmt"
45
"net/http"
56
"net/http/httptest"
67
"strings"
@@ -16,10 +17,7 @@ func TestProtect(t *testing.T) {
1617
s := http.NewServeMux()
1718
s.HandleFunc("/", testHandler)
1819

19-
r, err := http.NewRequest("GET", "/", nil)
20-
if err != nil {
21-
t.Fatal(err)
22-
}
20+
r := createRequest("GET", "/", false)
2321

2422
rr := httptest.NewRecorder()
2523
p := Protect(testKey)(s)
@@ -46,10 +44,7 @@ func TestCookieOptions(t *testing.T) {
4644
s := http.NewServeMux()
4745
s.HandleFunc("/", testHandler)
4846

49-
r, err := http.NewRequest("GET", "/", nil)
50-
if err != nil {
51-
t.Fatal(err)
52-
}
47+
r := createRequest("GET", "/", false)
5348

5449
rr := httptest.NewRecorder()
5550
p := Protect(testKey, CookieName("nameoverride"), Secure(false), HttpOnly(false), Path("/pathoverride"), Domain("domainoverride"), MaxAge(173))(s)
@@ -86,10 +81,7 @@ func TestMethods(t *testing.T) {
8681

8782
// Test idempontent ("safe") methods
8883
for _, method := range safeMethods {
89-
r, err := http.NewRequest(method, "/", nil)
90-
if err != nil {
91-
t.Fatal(err)
92-
}
84+
r := createRequest(method, "/", false)
9385

9486
rr := httptest.NewRecorder()
9587
p.ServeHTTP(rr, r)
@@ -107,10 +99,7 @@ func TestMethods(t *testing.T) {
10799
// Test non-idempotent methods (should return a 403 without a cookie set)
108100
nonIdempotent := []string{"POST", "PUT", "DELETE", "PATCH"}
109101
for _, method := range nonIdempotent {
110-
r, err := http.NewRequest(method, "/", nil)
111-
if err != nil {
112-
t.Fatal(err)
113-
}
102+
r := createRequest(method, "/", false)
114103

115104
rr := httptest.NewRecorder()
116105
p.ServeHTTP(rr, r)
@@ -133,10 +122,7 @@ func TestNoCookie(t *testing.T) {
133122
p := Protect(testKey)(s)
134123

135124
// POST the token back in the header.
136-
r, err := http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
137-
if err != nil {
138-
t.Fatal(err)
139-
}
125+
r := createRequest("POST", "/", false)
140126

141127
rr := httptest.NewRecorder()
142128
p.ServeHTTP(rr, r)
@@ -158,19 +144,13 @@ func TestBadCookie(t *testing.T) {
158144
}))
159145

160146
// Obtain a CSRF cookie via a GET request.
161-
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
162-
if err != nil {
163-
t.Fatal(err)
164-
}
147+
r := createRequest("GET", "/", false)
165148

166149
rr := httptest.NewRecorder()
167150
p.ServeHTTP(rr, r)
168151

169152
// POST the token back in the header.
170-
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
171-
if err != nil {
172-
t.Fatal(err)
173-
}
153+
r = createRequest("POST", "/", false)
174154

175155
// Replace the cookie prefix
176156
badHeader := strings.Replace(cookieName+"=", rr.Header().Get("Set-Cookie"), "_badCookie", -1)
@@ -193,10 +173,7 @@ func TestVaryHeader(t *testing.T) {
193173
s.HandleFunc("/", testHandler)
194174
p := Protect(testKey)(s)
195175

196-
r, err := http.NewRequest("HEAD", "https://www.golang.org/", nil)
197-
if err != nil {
198-
t.Fatal(err)
199-
}
176+
r := createRequest("GET", "/", true)
200177

201178
rr := httptest.NewRecorder()
202179
p.ServeHTTP(rr, r)
@@ -211,16 +188,13 @@ func TestVaryHeader(t *testing.T) {
211188
}
212189
}
213190

214-
// Requests with no Referer header should fail.
191+
// TestNoReferer checks that HTTPS requests with no Referer header fail.
215192
func TestNoReferer(t *testing.T) {
216193
s := http.NewServeMux()
217194
s.HandleFunc("/", testHandler)
218195
p := Protect(testKey)(s)
219196

220-
r, err := http.NewRequest("POST", "https://golang.org/", nil)
221-
if err != nil {
222-
t.Fatal(err)
223-
}
197+
r := createRequest("POST", "https://golang.org/", true)
224198

225199
rr := httptest.NewRecorder()
226200
p.ServeHTTP(rr, r)
@@ -243,20 +217,12 @@ func TestBadReferer(t *testing.T) {
243217
}))
244218

245219
// Obtain a CSRF cookie via a GET request.
246-
r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil)
247-
if err != nil {
248-
t.Fatal(err)
249-
}
250-
220+
r := createRequest("GET", "/", true)
251221
rr := httptest.NewRecorder()
252222
p.ServeHTTP(rr, r)
253223

254224
// POST the token back in the header.
255-
r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil)
256-
if err != nil {
257-
t.Fatal(err)
258-
}
259-
225+
r = createRequest("POST", "/", true)
260226
setCookie(rr, r)
261227
r.Header.Set("X-CSRF-Token", token)
262228

@@ -289,50 +255,47 @@ func TestTrustedReferer(t *testing.T) {
289255
}
290256

291257
for _, item := range testTable {
292-
s := http.NewServeMux()
258+
t.Run(fmt.Sprintf("TrustedOrigin: %v", item.trustedOrigin), func(t *testing.T) {
293259

294-
p := Protect(testKey, TrustedOrigins(item.trustedOrigin))(s)
260+
s := http.NewServeMux()
295261

296-
var token string
297-
s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
298-
token = Token(r)
299-
}))
262+
p := Protect(testKey, TrustedOrigins(item.trustedOrigin))(s)
300263

301-
// Obtain a CSRF cookie via a GET request.
302-
r, err := http.NewRequest("GET", "https://www.gorillatoolkit.org/", nil)
303-
if err != nil {
304-
t.Fatal(err)
305-
}
264+
var token string
265+
s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
266+
token = Token(r)
267+
}))
306268

307-
rr := httptest.NewRecorder()
308-
p.ServeHTTP(rr, r)
269+
// Obtain a CSRF cookie via a GET request.
270+
r := createRequest("GET", "/", true)
309271

310-
// POST the token back in the header.
311-
r, err = http.NewRequest("POST", "https://www.gorillatoolkit.org/", nil)
312-
if err != nil {
313-
t.Fatal(err)
314-
}
272+
rr := httptest.NewRecorder()
273+
p.ServeHTTP(rr, r)
315274

316-
setCookie(rr, r)
317-
r.Header.Set("X-CSRF-Token", token)
275+
// POST the token back in the header.
276+
r = createRequest("POST", "/", true)
318277

319-
// Set a non-matching Referer header.
320-
r.Header.Set("Referer", "http://golang.org/")
278+
setCookie(rr, r)
279+
r.Header.Set("X-CSRF-Token", token)
321280

322-
rr = httptest.NewRecorder()
323-
p.ServeHTTP(rr, r)
281+
// Set a non-matching Referer header.
282+
r.Header.Set("Referer", "https://golang.org/")
324283

325-
if item.shouldPass {
326-
if rr.Code != http.StatusOK {
327-
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
328-
rr.Code, http.StatusOK)
329-
}
330-
} else {
331-
if rr.Code != http.StatusForbidden {
332-
t.Fatalf("middleware failed reject a non-matching Referer header: got %v want %v",
333-
rr.Code, http.StatusForbidden)
284+
rr = httptest.NewRecorder()
285+
p.ServeHTTP(rr, r)
286+
287+
if item.shouldPass {
288+
if rr.Code != http.StatusOK {
289+
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
290+
rr.Code, http.StatusOK)
291+
}
292+
} else {
293+
if rr.Code != http.StatusForbidden {
294+
t.Fatalf("middleware failed reject a non-matching Referer header: got %v want %v",
295+
rr.Code, http.StatusForbidden)
296+
}
334297
}
335-
}
298+
})
336299
}
337300
}
338301

@@ -347,23 +310,16 @@ func TestWithReferer(t *testing.T) {
347310
}))
348311

349312
// Obtain a CSRF cookie via a GET request.
350-
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
351-
if err != nil {
352-
t.Fatal(err)
353-
}
354-
313+
r := createRequest("GET", "/", true)
355314
rr := httptest.NewRecorder()
356315
p.ServeHTTP(rr, r)
357316

358317
// POST the token back in the header.
359-
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
360-
if err != nil {
361-
t.Fatal(err)
362-
}
318+
r = createRequest("POST", "/", true)
363319

364320
setCookie(rr, r)
365321
r.Header.Set("X-CSRF-Token", token)
366-
r.Header.Set("Referer", "http://www.gorillatoolkit.org/")
322+
r.Header.Set("Referer", "https://www.gorillatoolkit.org/")
367323

368324
rr = httptest.NewRecorder()
369325
p.ServeHTTP(rr, r)
@@ -387,26 +343,19 @@ func TestNoTokenProvided(t *testing.T) {
387343
s.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
388344
token = Token(r)
389345
}))
390-
391346
// Obtain a CSRF cookie via a GET request.
392-
r, err := http.NewRequest("GET", "http://www.gorillatoolkit.org/", nil)
393-
if err != nil {
394-
t.Fatal(err)
395-
}
347+
r := createRequest("GET", "/", true)
396348

397349
rr := httptest.NewRecorder()
398350
p.ServeHTTP(rr, r)
399351

400352
// POST the token back in the header.
401-
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", nil)
402-
if err != nil {
403-
t.Fatal(err)
404-
}
353+
r = createRequest("POST", "/", true)
405354

406355
setCookie(rr, r)
407356
// By accident we use the wrong header name for the token...
408357
r.Header.Set("X-CSRF-nekot", token)
409-
r.Header.Set("Referer", "http://www.gorillatoolkit.org/")
358+
r.Header.Set("Referer", "https://www.gorillatoolkit.org/")
410359

411360
rr = httptest.NewRecorder()
412361
p.ServeHTTP(rr, r)
@@ -419,3 +368,177 @@ func TestNoTokenProvided(t *testing.T) {
419368
func setCookie(rr *httptest.ResponseRecorder, r *http.Request) {
420369
r.Header.Set("Cookie", rr.Header().Get("Set-Cookie"))
421370
}
371+
372+
func TestProtectScenarios(t *testing.T) {
373+
tests := []struct {
374+
name string
375+
safeMethod bool
376+
originUntrusted bool
377+
originHTTP bool
378+
originTrusted bool
379+
secureRequest bool
380+
refererTrusted bool
381+
refererUntrusted bool
382+
refererHTTPDowngrade bool
383+
refererRelative bool
384+
tokenValid bool
385+
tokenInvalid bool
386+
want bool
387+
}{
388+
{
389+
name: "safe method pass",
390+
safeMethod: true,
391+
want: true,
392+
},
393+
{
394+
name: "cleartext POST with trusted origin & valid token pass",
395+
originHTTP: true,
396+
tokenValid: true,
397+
want: true,
398+
},
399+
{
400+
name: "cleartext POST with untrusted origin reject",
401+
originUntrusted: true,
402+
tokenValid: true,
403+
},
404+
{
405+
name: "cleartext POST with HTTP origin & invalid token reject",
406+
originHTTP: true,
407+
},
408+
{
409+
name: "cleartext POST without origin with valid token pass",
410+
tokenValid: true,
411+
want: true,
412+
},
413+
{
414+
name: "cleartext POST without origin with invalid token reject",
415+
},
416+
{
417+
name: "TLS POST with HTTP origin & no referer & valid token reject",
418+
tokenValid: true,
419+
secureRequest: true,
420+
originHTTP: true,
421+
},
422+
{
423+
name: "TLS POST without origin and without referer reject",
424+
secureRequest: true,
425+
tokenValid: true,
426+
},
427+
{
428+
name: "TLS POST without origin with untrusted referer reject",
429+
secureRequest: true,
430+
refererUntrusted: true,
431+
tokenValid: true,
432+
},
433+
{
434+
name: "TLS POST without origin with trusted referer & valid token pass",
435+
secureRequest: true,
436+
refererTrusted: true,
437+
tokenValid: true,
438+
want: true,
439+
},
440+
{
441+
name: "TLS POST without origin from _cleartext_ same domain referer with valid token reject",
442+
secureRequest: true,
443+
refererHTTPDowngrade: true,
444+
tokenValid: true,
445+
},
446+
{
447+
name: "TLS POST without origin from relative referer with valid token pass",
448+
secureRequest: true,
449+
refererRelative: true,
450+
tokenValid: true,
451+
want: true,
452+
},
453+
{
454+
name: "TLS POST without origin from relative referer with invalid token reject",
455+
secureRequest: true,
456+
refererRelative: true,
457+
tokenInvalid: true,
458+
},
459+
}
460+
461+
for _, tt := range tests {
462+
t.Run(tt.name, func(t *testing.T) {
463+
var token string
464+
var flag bool
465+
mux := http.NewServeMux()
466+
mux.Handle("/", http.HandlerFunc(func(_ http.ResponseWriter, r *http.Request) {
467+
token = Token(r)
468+
}))
469+
mux.Handle("/submit", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
470+
flag = true
471+
}))
472+
p := Protect(testKey)(mux)
473+
474+
// Obtain a CSRF cookie via a GET request.
475+
r := createRequest("GET", "/", tt.secureRequest)
476+
rr := httptest.NewRecorder()
477+
p.ServeHTTP(rr, r)
478+
479+
r = createRequest("POST", "/submit", tt.secureRequest)
480+
if tt.safeMethod {
481+
r = createRequest("GET", "/submit", tt.secureRequest)
482+
}
483+
484+
// Set the Origin header
485+
switch {
486+
case tt.originUntrusted:
487+
r.Header.Set("Origin", "http://www.untrusted-origin.org")
488+
case tt.originTrusted:
489+
r.Header.Set("Origin", "https://www.gorillatoolkit.org")
490+
case tt.originHTTP:
491+
r.Header.Set("Origin", "http://www.gorillatoolkit.org")
492+
}
493+
494+
// Set the Referer header
495+
switch {
496+
case tt.refererTrusted:
497+
p = Protect(testKey, TrustedOrigins([]string{"external-trusted-origin.test"}))(mux)
498+
r.Header.Set("Referer", "https://external-trusted-origin.test/foobar")
499+
case tt.refererUntrusted:
500+
r.Header.Set("Referer", "http://www.invalid-referer.org")
501+
case tt.refererHTTPDowngrade:
502+
r.Header.Set("Referer", "http://www.gorillatoolkit.org/foobar")
503+
case tt.refererRelative:
504+
r.Header.Set("Referer", "/foobar")
505+
}
506+
507+
// Set the CSRF token & associated cookie
508+
switch {
509+
case tt.tokenInvalid:
510+
setCookie(rr, r)
511+
r.Header.Set("X-CSRF-Token", "this-is-an-invalid-token")
512+
case tt.tokenValid:
513+
setCookie(rr, r)
514+
r.Header.Set("X-CSRF-Token", token)
515+
}
516+
517+
rr = httptest.NewRecorder()
518+
p.ServeHTTP(rr, r)
519+
520+
if tt.want && rr.Code != http.StatusOK {
521+
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
522+
rr.Code, http.StatusOK)
523+
}
524+
525+
if tt.want && !flag {
526+
t.Fatalf("middleware failed to pass to the next handler: got %v want %v",
527+
flag, true)
528+
529+
}
530+
if !tt.want && flag {
531+
t.Fatalf("middleware failed to reject the request: got %v want %v", flag, false)
532+
}
533+
})
534+
}
535+
}
536+
537+
func createRequest(method, path string, useTLS bool) *http.Request {
538+
r := httptest.NewRequest(method, path, nil)
539+
r.Host = "www.gorillatoolkit.org"
540+
if !useTLS {
541+
return PlaintextHTTPRequest(r)
542+
}
543+
return r
544+
}

‎helpers_test.go

+8-16
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,7 @@ func TestMultipartFormToken(t *testing.T) {
8383
}
8484
}))
8585

86-
r, err := http.NewRequest("GET", "/", nil)
87-
if err != nil {
88-
t.Fatal(err)
89-
}
86+
r := createRequest("GET", "/", true)
9087

9188
rr := httptest.NewRecorder()
9289
p := Protect(testKey)(s)
@@ -107,13 +104,13 @@ func TestMultipartFormToken(t *testing.T) {
107104

108105
mp.Close()
109106

110-
r, err = http.NewRequest("POST", "http://www.gorillatoolkit.org/", &b)
111-
if err != nil {
112-
t.Fatal(err)
113-
}
107+
r = httptest.NewRequest("POST", "/", &b)
108+
r.Host = "www.gorillatoolkit.org"
114109

115110
// Add the multipart header.
116111
r.Header.Set("Content-Type", mp.FormDataContentType())
112+
// Add Origin to pass the same-origin check.
113+
r.Header.Set("Origin", "https://www.gorillatoolkit.org")
117114

118115
// Send back the issued cookie.
119116
setCookie(rr, r)
@@ -248,10 +245,8 @@ func TestTemplateField(t *testing.T) {
248245
}))
249246

250247
testFieldName := "custom_field_name"
251-
r, err := http.NewRequest("GET", "/", nil)
252-
if err != nil {
253-
t.Fatal(err)
254-
}
248+
r := createRequest("GET", "/", false)
249+
// r, err := http.NewRequest("GET", "/", nil)
255250

256251
rr := httptest.NewRecorder()
257252
p := Protect(testKey, FieldName(testFieldName))(s)
@@ -301,10 +296,7 @@ func TestUnsafeSkipCSRFCheck(t *testing.T) {
301296
w.WriteHeader(teapot)
302297
}))
303298

304-
r, err := http.NewRequest("POST", "/", nil)
305-
if err != nil {
306-
t.Fatal(err)
307-
}
299+
r := createRequest("POST", "/", false)
308300

309301
// Must be used prior to the CSRF handler being invoked.
310302
p := skipCheck(Protect(testKey)(s))

0 commit comments

Comments
 (0)
Please sign in to comment.