Skip to content

Commit a291875

Browse files
committed
Add ability to only check CSRF on specified request types
Refs #28
1 parent 005fca5 commit a291875

File tree

4 files changed

+66
-27
lines changed

4 files changed

+66
-27
lines changed

docs/options.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ The available options are:
3333
will cause this access cookie to be sent in with every request. Should be modified
3434
for only the paths that need the refresh cookie
3535
``JWT_COOKIE_CSRF_PROTECT`` Enable/disable CSRF protection. Only used when sending the JWT in via cookies
36+
``JWT_CSRF_METHODS`` The request types that will use CSRF protection. Defaults to
37+
```['POST', 'PUT', 'PATCH', 'DELETE']```
3638
``JWT_ACCESS_CSRF_COOKIE_NAME`` Name of the CSRF access cookie. Defaults to ``'csrf_access_token'``. Only used
3739
if using cookies with CSRF protection enabled
3840
``JWT_REFRESH_CSRF_COOKIE_NAME`` Name of the CSRF refresh cookie. Defaults to ``'csrf_refresh_token'``. Only used

flask_jwt_extended/config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
# Options for using double submit for verifying CSRF tokens
2020
COOKIE_CSRF_PROTECT = True
21+
CSRF_METHODS = ['POST', 'PUT', 'PATCH', 'DELETE']
2122
ACCESS_CSRF_COOKIE_NAME = 'csrf_access_token'
2223
REFRESH_CSRF_COOKIE_NAME = 'csrf_refresh_token'
2324
CSRF_HEADER_NAME = 'X-CSRF-TOKEN'
@@ -79,6 +80,10 @@ def get_cookie_csrf_protect():
7980
return current_app.config.get('JWT_COOKIE_CSRF_PROTECT', COOKIE_CSRF_PROTECT)
8081

8182

83+
def get_csrf_request_methods():
84+
return current_app.config.get('JWT_CSRF_METHODS', CSRF_METHODS)
85+
86+
8287
def get_access_csrf_cookie_name():
8388
return current_app.config.get('JWT_ACCESS_CSRF_COOKIE_NAME', ACCESS_CSRF_COOKIE_NAME)
8489

flask_jwt_extended/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
get_cookie_csrf_protect, get_access_csrf_cookie_name, \
1919
get_refresh_cookie_name, get_refresh_cookie_path, \
2020
get_refresh_csrf_cookie_name, get_token_location, \
21-
get_csrf_header_name, get_jwt_header_name
21+
get_csrf_header_name, get_jwt_header_name, get_csrf_request_methods
2222
from flask_jwt_extended.exceptions import JWTEncodeError, JWTDecodeError, \
2323
InvalidHeaderError, NoAuthorizationError, WrongTokenError, \
2424
FreshTokenRequired, CSRFError
@@ -195,7 +195,7 @@ def _decode_jwt_from_cookies(type):
195195
algorithm = get_algorithm()
196196
token = _decode_jwt(token, secret, algorithm)
197197

198-
if get_cookie_csrf_protect():
198+
if get_cookie_csrf_protect() and request.method in get_csrf_request_methods():
199199
csrf_header_key = get_csrf_header_name()
200200
csrf_token_from_header = request.headers.get(csrf_header_key, None)
201201
csrf_token_from_cookie = token.get('csrf', None)

tests/test_protected_endpoints.py

Lines changed: 57 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -390,25 +390,11 @@ def refresh():
390390
set_access_cookies(resp, access_token)
391391
return resp, 200
392392

393-
@self.app.route('/api/protected')
393+
@self.app.route('/api/protected', methods=['POST'])
394394
@jwt_required
395395
def protected():
396396
return jsonify({'msg': "hello world"})
397397

398-
def _jwt_post(self, url, jwt):
399-
response = self.client.post(url, content_type='application/json',
400-
headers={'Authorization': 'Bearer {}'.format(jwt)})
401-
status_code = response.status_code
402-
data = json.loads(response.get_data(as_text=True))
403-
return status_code, data
404-
405-
def _jwt_get(self, url, jwt, header_name='Authorization', header_type='Bearer'):
406-
header_type = '{} {}'.format(header_type, jwt).strip()
407-
response = self.client.get(url, headers={header_name: header_type})
408-
status_code = response.status_code
409-
data = json.loads(response.get_data(as_text=True))
410-
return status_code, data
411-
412398
def _login(self):
413399
resp = self.client.post('/auth/login')
414400
index = 1
@@ -491,7 +477,7 @@ def test_endpoints_with_cookies(self):
491477
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = False
492478

493479
# Try access without logging in
494-
response = self.client.get('/api/protected')
480+
response = self.client.post('/api/protected')
495481
status_code = response.status_code
496482
data = json.loads(response.get_data(as_text=True))
497483
self.assertEqual(status_code, 401)
@@ -506,7 +492,7 @@ def test_endpoints_with_cookies(self):
506492

507493
# Try with logging in
508494
self._login()
509-
response = self.client.get('/api/protected')
495+
response = self.client.post('/api/protected')
510496
status_code = response.status_code
511497
data = json.loads(response.get_data(as_text=True))
512498
self.assertEqual(status_code, 200)
@@ -525,7 +511,7 @@ def test_endpoints_with_cookies(self):
525511
access_cookie_key = access_cookie_str.split('=')[0]
526512
access_cookie_value = "".join(access_cookie_str.split('=')[1:])
527513
self.client.set_cookie('localhost', access_cookie_key, access_cookie_value)
528-
response = self.client.get('/api/protected')
514+
response = self.client.post('/api/protected')
529515
status_code = response.status_code
530516
data = json.loads(response.get_data(as_text=True))
531517
self.assertEqual(status_code, 200)
@@ -535,7 +521,7 @@ def test_access_endpoints_with_cookies_and_csrf(self):
535521
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
536522

537523
# Try without logging in
538-
response = self.client.get('/api/protected')
524+
response = self.client.post('/api/protected')
539525
status_code = response.status_code
540526
data = json.loads(response.get_data(as_text=True))
541527
self.assertEqual(status_code, 401)
@@ -545,30 +531,30 @@ def test_access_endpoints_with_cookies_and_csrf(self):
545531
access_csrf, refresh_csrf = self._login()
546532

547533
# Try with logging in but without double submit csrf protection
548-
response = self.client.get('/api/protected')
534+
response = self.client.post('/api/protected')
549535
status_code = response.status_code
550536
data = json.loads(response.get_data(as_text=True))
551537
self.assertEqual(status_code, 401)
552538
self.assertIn('msg', data)
553539

554540
# Try with logged in and bad header name for double submit token
555-
response = self.client.get('/api/protected',
541+
response = self.client.post('/api/protected',
556542
headers={'bad-header-name': 'banana'})
557543
status_code = response.status_code
558544
data = json.loads(response.get_data(as_text=True))
559545
self.assertEqual(status_code, 401)
560546
self.assertIn('msg', data)
561547

562548
# Try with logged in and bad header data for double submit token
563-
response = self.client.get('/api/protected',
549+
response = self.client.post('/api/protected',
564550
headers={'X-CSRF-TOKEN': 'banana'})
565551
status_code = response.status_code
566552
data = json.loads(response.get_data(as_text=True))
567553
self.assertEqual(status_code, 401)
568554
self.assertIn('msg', data)
569555

570556
# Try with logged in and good double submit token
571-
response = self.client.get('/api/protected',
557+
response = self.client.post('/api/protected',
572558
headers={'X-CSRF-TOKEN': access_csrf})
573559
status_code = response.status_code
574560
data = json.loads(response.get_data(as_text=True))
@@ -582,7 +568,7 @@ def test_access_endpoints_with_cookie_missing_csrf_field(self):
582568
self._login()
583569
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
584570

585-
response = self.client.get('/api/protected')
571+
response = self.client.post('/api/protected')
586572
status_code = response.status_code
587573
data = json.loads(response.get_data(as_text=True))
588574
self.assertEqual(status_code, 422)
@@ -606,12 +592,58 @@ def test_access_endpoints_with_cookie_csrf_claim_not_string(self):
606592
self.client.set_cookie('localhost', access_cookie_key, encoded_token)
607593

608594
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
609-
response = self.client.get('/api/protected')
595+
response = self.client.post('/api/protected')
610596
status_code = response.status_code
611597
data = json.loads(response.get_data(as_text=True))
612598
self.assertEqual(status_code, 422)
613599
self.assertIn('msg', data)
614600

601+
def test_custom_csrf_methods(self):
602+
@self.app.route('/protected-post', methods=['POST'])
603+
@jwt_required
604+
def protected_post():
605+
return jsonify({'msg': "hello world"})
606+
607+
@self.app.route('/protected-get', methods=['GET'])
608+
@jwt_required
609+
def protected_get():
610+
return jsonify({'msg': "hello world"})
611+
612+
# Login (saves jwts in the cookies for the test client
613+
self.app.config['JWT_COOKIE_CSRF_PROTECT'] = True
614+
self._login()
615+
616+
# Test being able to access GET without CSRF protection, and POST with
617+
# CSRF protection
618+
self.app.config['JWT_CSRF_METHODS'] = ['POST']
619+
620+
response = self.client.post('/protected-post')
621+
status_code = response.status_code
622+
data = json.loads(response.get_data(as_text=True))
623+
self.assertEqual(status_code, 401)
624+
self.assertIn('msg', data)
625+
626+
response = self.client.get('/protected-get')
627+
status_code = response.status_code
628+
data = json.loads(response.get_data(as_text=True))
629+
self.assertEqual(status_code, 200)
630+
self.assertEqual(data, {'msg': 'hello world'})
631+
632+
# Now swap it around, and verify the JWT_CRSF_METHODS are being honored
633+
self.app.config['JWT_CSRF_METHODS'] = ['GET']
634+
635+
response = self.client.get('/protected-get')
636+
status_code = response.status_code
637+
data = json.loads(response.get_data(as_text=True))
638+
self.assertEqual(status_code, 401)
639+
self.assertIn('msg', data)
640+
641+
response = self.client.post('/protected-post')
642+
status_code = response.status_code
643+
data = json.loads(response.get_data(as_text=True))
644+
self.assertEqual(status_code, 200)
645+
self.assertEqual(data, {'msg': 'hello world'})
646+
615647

616648
class TestEndpointsWithHeadersAndCookies(unittest.TestCase):
617649

0 commit comments

Comments
 (0)