diff --git a/flask_restful/utils/cors.py b/flask_restful/utils/cors.py index 40674461..523a014e 100644 --- a/flask_restful/utils/cors.py +++ b/flask_restful/utils/cors.py @@ -3,7 +3,7 @@ from functools import update_wrapper -def crossdomain(origin=None, methods=None, headers=None, +def crossdomain(origin=None, methods=None, headers=None, expose_headers=None, max_age=21600, attach_to_all=True, automatic_options=True, credentials=False): """ @@ -13,6 +13,8 @@ def crossdomain(origin=None, methods=None, headers=None, methods = ', '.join(sorted(x.upper() for x in methods)) if headers is not None and not isinstance(headers, str): headers = ', '.join(x.upper() for x in headers) + if expose_headers is not None and not isinstance(expose_headers, str): + expose_headers = ', '.join(x.upper() for x in expose_headers) if not isinstance(origin, str): origin = ', '.join(origin) if isinstance(max_age, timedelta): @@ -43,6 +45,8 @@ def wrapped_function(*args, **kwargs): h['Access-Control-Allow-Credentials'] = 'true' if headers is not None: h['Access-Control-Allow-Headers'] = headers + if expose_headers is not None: + h['Access-Control-Expose-Headers'] = expose_headers return resp f.provide_automatic_options = False diff --git a/tests/test_cors.py b/tests/test_cors.py index 1bf36d1d..2fa395b2 100644 --- a/tests/test_cors.py +++ b/tests/test_cors.py @@ -27,6 +27,24 @@ def get(self): assert_true('OPTIONS' in res.headers['Access-Control-Allow-Methods']) assert_true('GET' in res.headers['Access-Control-Allow-Methods']) + def test_access_control_expose_headers(self): + + class Foo(flask_restful.Resource): + @cors.crossdomain(origin='*', + expose_headers=['X-My-Header', 'X-Another-Header']) + def get(self): + return "data" + + app = Flask(__name__) + api = flask_restful.Api(app) + api.add_resource(Foo, '/') + + with app.test_client() as client: + res = client.get('/') + assert_equals(res.status_code, 200) + assert_true('X-MY-HEADER' in res.headers['Access-Control-Expose-Headers']) + assert_true('X-ANOTHER-HEADER' in res.headers['Access-Control-Expose-Headers']) + def test_no_crossdomain(self): class Foo(flask_restful.Resource):