1
1
package csrf
2
2
3
3
import (
4
+ "fmt"
4
5
"net/http"
5
6
"net/http/httptest"
6
7
"strings"
@@ -16,10 +17,7 @@ func TestProtect(t *testing.T) {
16
17
s := http .NewServeMux ()
17
18
s .HandleFunc ("/" , testHandler )
18
19
19
- r , err := http .NewRequest ("GET" , "/" , nil )
20
- if err != nil {
21
- t .Fatal (err )
22
- }
20
+ r := createRequest ("GET" , "/" , false )
23
21
24
22
rr := httptest .NewRecorder ()
25
23
p := Protect (testKey )(s )
@@ -46,10 +44,7 @@ func TestCookieOptions(t *testing.T) {
46
44
s := http .NewServeMux ()
47
45
s .HandleFunc ("/" , testHandler )
48
46
49
- r , err := http .NewRequest ("GET" , "/" , nil )
50
- if err != nil {
51
- t .Fatal (err )
52
- }
47
+ r := createRequest ("GET" , "/" , false )
53
48
54
49
rr := httptest .NewRecorder ()
55
50
p := Protect (testKey , CookieName ("nameoverride" ), Secure (false ), HttpOnly (false ), Path ("/pathoverride" ), Domain ("domainoverride" ), MaxAge (173 ))(s )
@@ -86,10 +81,7 @@ func TestMethods(t *testing.T) {
86
81
87
82
// Test idempontent ("safe") methods
88
83
for _ , method := range safeMethods {
89
- r , err := http .NewRequest (method , "/" , nil )
90
- if err != nil {
91
- t .Fatal (err )
92
- }
84
+ r := createRequest (method , "/" , false )
93
85
94
86
rr := httptest .NewRecorder ()
95
87
p .ServeHTTP (rr , r )
@@ -107,10 +99,7 @@ func TestMethods(t *testing.T) {
107
99
// Test non-idempotent methods (should return a 403 without a cookie set)
108
100
nonIdempotent := []string {"POST" , "PUT" , "DELETE" , "PATCH" }
109
101
for _ , method := range nonIdempotent {
110
- r , err := http .NewRequest (method , "/" , nil )
111
- if err != nil {
112
- t .Fatal (err )
113
- }
102
+ r := createRequest (method , "/" , false )
114
103
115
104
rr := httptest .NewRecorder ()
116
105
p .ServeHTTP (rr , r )
@@ -133,10 +122,7 @@ func TestNoCookie(t *testing.T) {
133
122
p := Protect (testKey )(s )
134
123
135
124
// POST the token back in the header.
136
- r , err := http .NewRequest ("POST" , "http://www.gorillatoolkit.org/" , nil )
137
- if err != nil {
138
- t .Fatal (err )
139
- }
125
+ r := createRequest ("POST" , "/" , false )
140
126
141
127
rr := httptest .NewRecorder ()
142
128
p .ServeHTTP (rr , r )
@@ -158,19 +144,13 @@ func TestBadCookie(t *testing.T) {
158
144
}))
159
145
160
146
// Obtain a CSRF cookie via a GET request.
161
- r , err := http .NewRequest ("GET" , "http://www.gorillatoolkit.org/" , nil )
162
- if err != nil {
163
- t .Fatal (err )
164
- }
147
+ r := createRequest ("GET" , "/" , false )
165
148
166
149
rr := httptest .NewRecorder ()
167
150
p .ServeHTTP (rr , r )
168
151
169
152
// POST the token back in the header.
170
- r , err = http .NewRequest ("POST" , "http://www.gorillatoolkit.org/" , nil )
171
- if err != nil {
172
- t .Fatal (err )
173
- }
153
+ r = createRequest ("POST" , "/" , false )
174
154
175
155
// Replace the cookie prefix
176
156
badHeader := strings .Replace (cookieName + "=" , rr .Header ().Get ("Set-Cookie" ), "_badCookie" , - 1 )
@@ -193,10 +173,7 @@ func TestVaryHeader(t *testing.T) {
193
173
s .HandleFunc ("/" , testHandler )
194
174
p := Protect (testKey )(s )
195
175
196
- r , err := http .NewRequest ("HEAD" , "https://www.golang.org/" , nil )
197
- if err != nil {
198
- t .Fatal (err )
199
- }
176
+ r := createRequest ("GET" , "/" , true )
200
177
201
178
rr := httptest .NewRecorder ()
202
179
p .ServeHTTP (rr , r )
@@ -211,16 +188,13 @@ func TestVaryHeader(t *testing.T) {
211
188
}
212
189
}
213
190
214
- // Requests with no Referer header should fail.
191
+ // TestNoReferer checks that HTTPS requests with no Referer header fail.
215
192
func TestNoReferer (t * testing.T ) {
216
193
s := http .NewServeMux ()
217
194
s .HandleFunc ("/" , testHandler )
218
195
p := Protect (testKey )(s )
219
196
220
- r , err := http .NewRequest ("POST" , "https://golang.org/" , nil )
221
- if err != nil {
222
- t .Fatal (err )
223
- }
197
+ r := createRequest ("POST" , "https://golang.org/" , true )
224
198
225
199
rr := httptest .NewRecorder ()
226
200
p .ServeHTTP (rr , r )
@@ -243,20 +217,12 @@ func TestBadReferer(t *testing.T) {
243
217
}))
244
218
245
219
// Obtain a CSRF cookie via a GET request.
246
- r , err := http .NewRequest ("GET" , "https://www.gorillatoolkit.org/" , nil )
247
- if err != nil {
248
- t .Fatal (err )
249
- }
250
-
220
+ r := createRequest ("GET" , "/" , true )
251
221
rr := httptest .NewRecorder ()
252
222
p .ServeHTTP (rr , r )
253
223
254
224
// POST the token back in the header.
255
- r , err = http .NewRequest ("POST" , "https://www.gorillatoolkit.org/" , nil )
256
- if err != nil {
257
- t .Fatal (err )
258
- }
259
-
225
+ r = createRequest ("POST" , "/" , true )
260
226
setCookie (rr , r )
261
227
r .Header .Set ("X-CSRF-Token" , token )
262
228
@@ -289,50 +255,47 @@ func TestTrustedReferer(t *testing.T) {
289
255
}
290
256
291
257
for _ , item := range testTable {
292
- s := http . NewServeMux ()
258
+ t . Run ( fmt . Sprintf ( "TrustedOrigin: %v" , item . trustedOrigin ), func ( t * testing. T ) {
293
259
294
- p := Protect ( testKey , TrustedOrigins ( item . trustedOrigin ))( s )
260
+ s := http . NewServeMux ( )
295
261
296
- var token string
297
- s .Handle ("/" , http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
298
- token = Token (r )
299
- }))
262
+ p := Protect (testKey , TrustedOrigins (item .trustedOrigin ))(s )
300
263
301
- // Obtain a CSRF cookie via a GET request.
302
- r , err := http .NewRequest ("GET" , "https://www.gorillatoolkit.org/" , nil )
303
- if err != nil {
304
- t .Fatal (err )
305
- }
264
+ var token string
265
+ s .Handle ("/" , http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
266
+ token = Token (r )
267
+ }))
306
268
307
- rr := httptest . NewRecorder ()
308
- p . ServeHTTP ( rr , r )
269
+ // Obtain a CSRF cookie via a GET request.
270
+ r := createRequest ( "GET" , "/" , true )
309
271
310
- // POST the token back in the header.
311
- r , err = http .NewRequest ("POST" , "https://www.gorillatoolkit.org/" , nil )
312
- if err != nil {
313
- t .Fatal (err )
314
- }
272
+ rr := httptest .NewRecorder ()
273
+ p .ServeHTTP (rr , r )
315
274
316
- setCookie ( rr , r )
317
- r . Header . Set ( "X-CSRF-Token " , token )
275
+ // POST the token back in the header.
276
+ r = createRequest ( "POST " , "/" , true )
318
277
319
- // Set a non-matching Referer header.
320
- r .Header .Set ("Referer " , "http://golang.org/" )
278
+ setCookie ( rr , r )
279
+ r .Header .Set ("X-CSRF-Token " , token )
321
280
322
- rr = httptest . NewRecorder ()
323
- p . ServeHTTP ( rr , r )
281
+ // Set a non-matching Referer header.
282
+ r . Header . Set ( "Referer" , "https://golang.org/" )
324
283
325
- if item .shouldPass {
326
- if rr .Code != http .StatusOK {
327
- t .Fatalf ("middleware failed to pass to the next handler: got %v want %v" ,
328
- rr .Code , http .StatusOK )
329
- }
330
- } else {
331
- if rr .Code != http .StatusForbidden {
332
- t .Fatalf ("middleware failed reject a non-matching Referer header: got %v want %v" ,
333
- rr .Code , http .StatusForbidden )
284
+ rr = httptest .NewRecorder ()
285
+ p .ServeHTTP (rr , r )
286
+
287
+ if item .shouldPass {
288
+ if rr .Code != http .StatusOK {
289
+ t .Fatalf ("middleware failed to pass to the next handler: got %v want %v" ,
290
+ rr .Code , http .StatusOK )
291
+ }
292
+ } else {
293
+ if rr .Code != http .StatusForbidden {
294
+ t .Fatalf ("middleware failed reject a non-matching Referer header: got %v want %v" ,
295
+ rr .Code , http .StatusForbidden )
296
+ }
334
297
}
335
- }
298
+ })
336
299
}
337
300
}
338
301
@@ -347,23 +310,16 @@ func TestWithReferer(t *testing.T) {
347
310
}))
348
311
349
312
// Obtain a CSRF cookie via a GET request.
350
- r , err := http .NewRequest ("GET" , "http://www.gorillatoolkit.org/" , nil )
351
- if err != nil {
352
- t .Fatal (err )
353
- }
354
-
313
+ r := createRequest ("GET" , "/" , true )
355
314
rr := httptest .NewRecorder ()
356
315
p .ServeHTTP (rr , r )
357
316
358
317
// POST the token back in the header.
359
- r , err = http .NewRequest ("POST" , "http://www.gorillatoolkit.org/" , nil )
360
- if err != nil {
361
- t .Fatal (err )
362
- }
318
+ r = createRequest ("POST" , "/" , true )
363
319
364
320
setCookie (rr , r )
365
321
r .Header .Set ("X-CSRF-Token" , token )
366
- r .Header .Set ("Referer" , "http ://www.gorillatoolkit.org/" )
322
+ r .Header .Set ("Referer" , "https ://www.gorillatoolkit.org/" )
367
323
368
324
rr = httptest .NewRecorder ()
369
325
p .ServeHTTP (rr , r )
@@ -387,26 +343,19 @@ func TestNoTokenProvided(t *testing.T) {
387
343
s .Handle ("/" , http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
388
344
token = Token (r )
389
345
}))
390
-
391
346
// Obtain a CSRF cookie via a GET request.
392
- r , err := http .NewRequest ("GET" , "http://www.gorillatoolkit.org/" , nil )
393
- if err != nil {
394
- t .Fatal (err )
395
- }
347
+ r := createRequest ("GET" , "/" , true )
396
348
397
349
rr := httptest .NewRecorder ()
398
350
p .ServeHTTP (rr , r )
399
351
400
352
// POST the token back in the header.
401
- r , err = http .NewRequest ("POST" , "http://www.gorillatoolkit.org/" , nil )
402
- if err != nil {
403
- t .Fatal (err )
404
- }
353
+ r = createRequest ("POST" , "/" , true )
405
354
406
355
setCookie (rr , r )
407
356
// By accident we use the wrong header name for the token...
408
357
r .Header .Set ("X-CSRF-nekot" , token )
409
- r .Header .Set ("Referer" , "http ://www.gorillatoolkit.org/" )
358
+ r .Header .Set ("Referer" , "https ://www.gorillatoolkit.org/" )
410
359
411
360
rr = httptest .NewRecorder ()
412
361
p .ServeHTTP (rr , r )
@@ -419,3 +368,177 @@ func TestNoTokenProvided(t *testing.T) {
419
368
func setCookie (rr * httptest.ResponseRecorder , r * http.Request ) {
420
369
r .Header .Set ("Cookie" , rr .Header ().Get ("Set-Cookie" ))
421
370
}
371
+
372
+ func TestProtectScenarios (t * testing.T ) {
373
+ tests := []struct {
374
+ name string
375
+ safeMethod bool
376
+ originUntrusted bool
377
+ originHTTP bool
378
+ originTrusted bool
379
+ secureRequest bool
380
+ refererTrusted bool
381
+ refererUntrusted bool
382
+ refererHTTPDowngrade bool
383
+ refererRelative bool
384
+ tokenValid bool
385
+ tokenInvalid bool
386
+ want bool
387
+ }{
388
+ {
389
+ name : "safe method pass" ,
390
+ safeMethod : true ,
391
+ want : true ,
392
+ },
393
+ {
394
+ name : "cleartext POST with trusted origin & valid token pass" ,
395
+ originHTTP : true ,
396
+ tokenValid : true ,
397
+ want : true ,
398
+ },
399
+ {
400
+ name : "cleartext POST with untrusted origin reject" ,
401
+ originUntrusted : true ,
402
+ tokenValid : true ,
403
+ },
404
+ {
405
+ name : "cleartext POST with HTTP origin & invalid token reject" ,
406
+ originHTTP : true ,
407
+ },
408
+ {
409
+ name : "cleartext POST without origin with valid token pass" ,
410
+ tokenValid : true ,
411
+ want : true ,
412
+ },
413
+ {
414
+ name : "cleartext POST without origin with invalid token reject" ,
415
+ },
416
+ {
417
+ name : "TLS POST with HTTP origin & no referer & valid token reject" ,
418
+ tokenValid : true ,
419
+ secureRequest : true ,
420
+ originHTTP : true ,
421
+ },
422
+ {
423
+ name : "TLS POST without origin and without referer reject" ,
424
+ secureRequest : true ,
425
+ tokenValid : true ,
426
+ },
427
+ {
428
+ name : "TLS POST without origin with untrusted referer reject" ,
429
+ secureRequest : true ,
430
+ refererUntrusted : true ,
431
+ tokenValid : true ,
432
+ },
433
+ {
434
+ name : "TLS POST without origin with trusted referer & valid token pass" ,
435
+ secureRequest : true ,
436
+ refererTrusted : true ,
437
+ tokenValid : true ,
438
+ want : true ,
439
+ },
440
+ {
441
+ name : "TLS POST without origin from _cleartext_ same domain referer with valid token reject" ,
442
+ secureRequest : true ,
443
+ refererHTTPDowngrade : true ,
444
+ tokenValid : true ,
445
+ },
446
+ {
447
+ name : "TLS POST without origin from relative referer with valid token pass" ,
448
+ secureRequest : true ,
449
+ refererRelative : true ,
450
+ tokenValid : true ,
451
+ want : true ,
452
+ },
453
+ {
454
+ name : "TLS POST without origin from relative referer with invalid token reject" ,
455
+ secureRequest : true ,
456
+ refererRelative : true ,
457
+ tokenInvalid : true ,
458
+ },
459
+ }
460
+
461
+ for _ , tt := range tests {
462
+ t .Run (tt .name , func (t * testing.T ) {
463
+ var token string
464
+ var flag bool
465
+ mux := http .NewServeMux ()
466
+ mux .Handle ("/" , http .HandlerFunc (func (_ http.ResponseWriter , r * http.Request ) {
467
+ token = Token (r )
468
+ }))
469
+ mux .Handle ("/submit" , http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
470
+ flag = true
471
+ }))
472
+ p := Protect (testKey )(mux )
473
+
474
+ // Obtain a CSRF cookie via a GET request.
475
+ r := createRequest ("GET" , "/" , tt .secureRequest )
476
+ rr := httptest .NewRecorder ()
477
+ p .ServeHTTP (rr , r )
478
+
479
+ r = createRequest ("POST" , "/submit" , tt .secureRequest )
480
+ if tt .safeMethod {
481
+ r = createRequest ("GET" , "/submit" , tt .secureRequest )
482
+ }
483
+
484
+ // Set the Origin header
485
+ switch {
486
+ case tt .originUntrusted :
487
+ r .Header .Set ("Origin" , "http://www.untrusted-origin.org" )
488
+ case tt .originTrusted :
489
+ r .Header .Set ("Origin" , "https://www.gorillatoolkit.org" )
490
+ case tt .originHTTP :
491
+ r .Header .Set ("Origin" , "http://www.gorillatoolkit.org" )
492
+ }
493
+
494
+ // Set the Referer header
495
+ switch {
496
+ case tt .refererTrusted :
497
+ p = Protect (testKey , TrustedOrigins ([]string {"external-trusted-origin.test" }))(mux )
498
+ r .Header .Set ("Referer" , "https://external-trusted-origin.test/foobar" )
499
+ case tt .refererUntrusted :
500
+ r .Header .Set ("Referer" , "http://www.invalid-referer.org" )
501
+ case tt .refererHTTPDowngrade :
502
+ r .Header .Set ("Referer" , "http://www.gorillatoolkit.org/foobar" )
503
+ case tt .refererRelative :
504
+ r .Header .Set ("Referer" , "/foobar" )
505
+ }
506
+
507
+ // Set the CSRF token & associated cookie
508
+ switch {
509
+ case tt .tokenInvalid :
510
+ setCookie (rr , r )
511
+ r .Header .Set ("X-CSRF-Token" , "this-is-an-invalid-token" )
512
+ case tt .tokenValid :
513
+ setCookie (rr , r )
514
+ r .Header .Set ("X-CSRF-Token" , token )
515
+ }
516
+
517
+ rr = httptest .NewRecorder ()
518
+ p .ServeHTTP (rr , r )
519
+
520
+ if tt .want && rr .Code != http .StatusOK {
521
+ t .Fatalf ("middleware failed to pass to the next handler: got %v want %v" ,
522
+ rr .Code , http .StatusOK )
523
+ }
524
+
525
+ if tt .want && ! flag {
526
+ t .Fatalf ("middleware failed to pass to the next handler: got %v want %v" ,
527
+ flag , true )
528
+
529
+ }
530
+ if ! tt .want && flag {
531
+ t .Fatalf ("middleware failed to reject the request: got %v want %v" , flag , false )
532
+ }
533
+ })
534
+ }
535
+ }
536
+
537
+ func createRequest (method , path string , useTLS bool ) * http.Request {
538
+ r := httptest .NewRequest (method , path , nil )
539
+ r .Host = "www.gorillatoolkit.org"
540
+ if ! useTLS {
541
+ return PlaintextHTTPRequest (r )
542
+ }
543
+ return r
544
+ }
0 commit comments